Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/pure-cfg/src/compiler/cl-target/cl-target.sml
ViewVC logotype

Diff of /branches/pure-cfg/src/compiler/cl-target/cl-target.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 1271, Mon Jun 6 02:45:57 2011 UTC revision 1273, Mon Jun 6 10:46:20 2011 UTC
# Line 12  Line 12 
12      structure Ty = IL.Ty      structure Ty = IL.Ty
13      structure CL = CLang      structure CL = CLang
14      structure RN = RuntimeNames      structure RN = RuntimeNames
15      structure ToC = TreeToCL      structure ToCL = TreeToCL
16    
17      type var = ToC.var      type var = ToCL.var
18      type exp = CL.exp      type exp = CL.exp
19      type stm = CL.stm      type stm = CL.stm
20    
# Line 57  Line 57 
57        | StrandScope of TreeIL.var list  (* strand initialization *)        | StrandScope of TreeIL.var list  (* strand initialization *)
58        | MethodScope of TreeIL.var list  (* method body; vars are state variables *)        | MethodScope of TreeIL.var list  (* method body; vars are state variables *)
59    
60    (* the supprted widths of vectors of reals on the target.  For the GNU vector extensions,    (* the supprted widths of vectors of reals on the target. *)
61     * the supported sizes are powers of two, but float2 is broken.  (* FIXME: for OpenCL 1.1, 3 is also valid *)
62     * NOTE: we should also consider the AVX vector hardware, which has 256-bit registers.      fun vectorWidths () = [2, 4, 8, 16]
    *)  
     fun vectorWidths () = if !RuntimeNames.doublePrecision  
           then [2, 4, 8]  
           else [4, 8]  
63    
64    (* tests for whether various expression forms can appear inline *)    (* tests for whether various expression forms can appear inline *)
65      fun inlineCons n = (n < 2)          (* vectors are inline, but not matrices *)      fun inlineCons n = (n < 2)          (* vectors are inline, but not matrices *)
# Line 73  Line 69 
69      structure Tr =      structure Tr =
70        struct        struct
71          fun fragment (ENV{info, vMap, scope}, blk) = let          fun fragment (ENV{info, vMap, scope}, blk) = let
72                val (vMap, stms) = ToC.trFragment (vMap, blk)                val (vMap, stms) = ToCL.trFragment (vMap, blk)
73                in                in
74                  (ENV{info=info, vMap=vMap, scope=scope}, stms)                  (ENV{info=info, vMap=vMap, scope=scope}, stms)
75                end                end
76          fun saveState cxt stateVars (env, args, stm) = (          fun saveState cxt stateVars (env, args, stm) = (
77                ListPair.foldrEq                ListPair.foldrEq
78                  (fn (x, e, stms) => ToC.trAssign(env, x, e)@stms)                  (fn (x, e, stms) => ToCL.trAssign(env, x, e)@stms)
79                    [stm]                    [stm]
80                      (stateVars, args)                      (stateVars, args)
81                ) handle ListPair.UnequalLengths => (                ) handle ListPair.UnequalLengths => (
82                  print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);                  print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);
83                  raise Fail(concat["saveState ", cxt, ": length mismatch"]))                  raise Fail(concat["saveState ", cxt, ": length mismatch"]))
84          fun block (ENV{vMap, scope, ...}, blk) = (case scope          fun block (ENV{vMap, scope, ...}, blk) = (case scope
85                 of StrandScope stateVars => ToC.trBlock (vMap, saveState "StrandScope" stateVars, blk)                 of StrandScope stateVars => ToCL.trBlock (vMap, saveState "StrandScope" stateVars, blk)
86                  | MethodScope stateVars => ToC.trBlock (vMap, saveState "MethodScope" stateVars, blk)                  | MethodScope stateVars => ToCL.trBlock (vMap, saveState "MethodScope" stateVars, blk)
87                  | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)                  | _ => ToCL.trBlock (vMap, fn (_, _, stm) => [stm], blk)
88                (* end case *))                (* end case *))
89          fun exp (ENV{vMap, ...}, e) = ToC.trExp(vMap, e)          fun exp (ENV{vMap, ...}, e) = ToCL.trExp(vMap, e)
90        end        end
91    
92    (* variables *)    (* variables *)
93      structure Var =      structure Var =
94        struct        struct
95          fun name (ToC.V(_, name)) = name          fun name (ToCL.V(_, name)) = name
96           fun global (Prog{globals,imgGlobals, ...}, name, ty) = let           fun global (Prog{globals,imgGlobals, ...}, name, ty) = let
97                val ty' = ToC.trType ty                val ty' = ToCL.trType ty
98                fun isImgGlobal (imgGlobals, Ty.ImageTy(ImageInfo.ImgInfo{dim, ...}), name) =  imgGlobals  := (name,dim):: !imgGlobals                fun isImgGlobal (imgGlobals, Ty.ImageTy(ImageInfo.ImgInfo{dim, ...}), name) =  imgGlobals  := (name,dim):: !imgGlobals
99                  | isImgGlobal (imgGlobals, _, _) =  ()                  | isImgGlobal (imgGlobals, _, _) =  ()
100                in                in
101                  globals := CL.D_Var([], ty', name, NONE) :: !globals;                  globals := CL.D_Var([], ty', name, NONE) :: !globals;
102                  isImgGlobal(imgGlobals,ty,name);                  isImgGlobal(imgGlobals,ty,name);
103               ToC.V(ty', name)               ToCL.V(ty', name)
104                end                end
105          fun param x = ToC.V(ToC.trType(V.ty x), V.name x)          fun param x = ToCL.V(ToCL.trType(V.ty x), V.name x)
106          fun state (Strand{state, ...}, x) = let          fun state (Strand{state, ...}, x) = let
107                val ty' = ToC.trType(V.ty x)                val ty' = ToCL.trType(V.ty x)
108                val x' = ToC.V(ty', V.name x)                val x' = ToCL.V(ty', V.name x)
109                in                in
110                  state := x' :: !state;                  state := x' :: !state;
111                  x'                  x'
# Line 211  Line 207 
207                      end                      end
208                val baseInit = mapi (fn (i, (_, e, _)) => (i, CL.I_Exp e)) iters                val baseInit = mapi (fn (i, (_, e, _)) => (i, CL.I_Exp e)) iters
209                val sizeInit = mapi                val sizeInit = mapi
210                      (fn (i, (ToC.V(ty, _), lo, hi)) =>                      (fn (i, (ToCL.V(ty, _), lo, hi)) =>
211                          (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, ty))))                          (i, CL.I_Exp(CL.mkBinOp(CL.mkBinOp(hi, CL.#-, lo), CL.#+, CL.E_Int(1, ty))))
212                      ) iters                      ) iters
213                    val numStrandsVar = "numStrandsVar"                    val numStrandsVar = "numStrandsVar"
# Line 221  Line 217 
217                        CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),                        CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
218                        CL.mkDecl(CL.int32,"numDims",SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nDims, CL.int32))))                        CL.mkDecl(CL.int32,"numDims",SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nDims, CL.int32))))
219                            ]                            ]
   
220                                    val numStrandsLoopBody = CL.mkExpStm(CL.mkAssignOp(CL.E_Var numStrandsVar, CL.*=,CL.mkSubscript(CL.E_Var "size",CL.E_Var "i")))                                    val numStrandsLoopBody = CL.mkExpStm(CL.mkAssignOp(CL.E_Var numStrandsVar, CL.*=,CL.mkSubscript(CL.E_Var "size",CL.E_Var "i")))
   
   
221                                    val numStrandsLoop =  CL.mkFor([(CL.intTy, "i", CL.E_Int(0,CL.intTy))],                                    val numStrandsLoop =  CL.mkFor([(CL.intTy, "i", CL.E_Int(0,CL.intTy))],
222                                                           CL.mkBinOp(CL.E_Var "i", CL.#<, CL.E_Var "numDims"),                                                           CL.mkBinOp(CL.E_Var "i", CL.#<, CL.E_Var "numDims"),
223                                                           [CL.mkPostOp(CL.E_Var "i", CL.^++)], numStrandsLoopBody)                                                           [CL.mkPostOp(CL.E_Var "i", CL.^++)], numStrandsLoopBody)
224                                    in                                    in
225                                  numDims := nDims;                                  numDims := nDims;
226                                  initially := allocCode @ [numStrandsLoop]                                  initially := allocCode @ [numStrandsLoop]
   
227                    end                    end
228    
229    
# Line 305  Line 297 
297                                  val fmt = CL.E_Str(                                  val fmt = CL.E_Str(
298                                        String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))                                        String.concatWith " " (List.tabulate(d, fn _ => !RN.gIntFormat))
299                                        ^ "\n")                                        ^ "\n")
300                                  val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))                                  val args = List.tabulate (d, fn i => ToCL.vecIndex(outState, i))
301                                  in                                  in
302                                    fmt :: args                                    fmt :: args
303                                  end                                  end
# Line 314  Line 306 
306                                  val fmt = CL.E_Str(                                  val fmt = CL.E_Str(
307                                        String.concatWith " " (List.tabulate(d, fn _ => "%f"))                                        String.concatWith " " (List.tabulate(d, fn _ => "%f"))
308                                        ^ "\n")                                        ^ "\n")
309                                  val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))                                  val args = List.tabulate (d, fn i => ToCL.vecIndex(outState, i))
310                                  in                                  in
311                                    fmt :: args                                    fmt :: args
312                                  end                                  end
# Line 356  Line 348 
348          fun genStrandTyDef (Strand{tyName, state,...}) =          fun genStrandTyDef (Strand{tyName, state,...}) =
349              (* the type declaration for the strand's state struct *)              (* the type declaration for the strand's state struct *)
350                CL.D_StructDef(                CL.D_StructDef(
351                        List.rev (List.map (fn ToC.V(ty, x) => (ty, x)) (!state)),                        List.rev (List.map (fn ToCL.V(ty, x) => (ty, x)) (!state)),
352                        tyName)                        tyName)
353    
354    
# Line 453  Line 445 
445          (* generates the main function of host code *)          (* generates the main function of host code *)
446          fun genHostMain() = let          fun genHostMain() = let
447                  val setupCall = [CL.mkCall(RN.setupFName,[CL.E_Var RN.globalsVarName])]                  val setupCall = [CL.mkCall(RN.setupFName,[CL.E_Var RN.globalsVarName])]
448                  val globalsDecl = CL.mkDecl(CL.T_Ptr(CL.T_Named RN.globalsTy), RN.globalsVarName,SOME(CL.I_Exp(CL.mkApply("malloc",                val globalsDecl = CL.mkDecl(
449                                                                          [CL.mkApply("sizeof",[CL.E_Var RN.globalsTy])]))))                      CL.T_Ptr(CL.T_Named RN.globalsTy),
450                        RN.globalsVarName,
451                        SOME(CL.I_Exp(CL.mkApply("malloc", [CL.mkApply("sizeof",[CL.E_Var RN.globalsTy])]))))
452                  val initGlobalsCall = CL.mkCall(RN.initGlobals,[CL.E_Var RN.globalsVarName])                  val initGlobalsCall = CL.mkCall(RN.initGlobals,[CL.E_Var RN.globalsVarName])
453                  val returnStm = [CL.mkReturn(SOME(CL.E_Int(0,CL.intTy)))]                  val returnStm = [CL.mkReturn(SOME(CL.E_Int(0,CL.intTy)))]
454                  val params = [                  val params = [
# Line 495  Line 489 
489                  val params = [                  val params = [
490                           CL.PARAM([],CL.T_Named("cl_device_id"), deviceVar)                           CL.PARAM([],CL.T_Named("cl_device_id"), deviceVar)
491                           ]                           ]
492                  val delcarations = [CL.mkDecl(CL.clProgramTy, programVar, NONE),                val declarations = [
493                        CL.mkDecl(CL.clProgramTy, programVar, NONE),
494                            CL.mkDecl(CL.clKernelTy, kernelVar, NONE),                            CL.mkDecl(CL.clKernelTy, kernelVar, NONE),
495                            CL.mkDecl(CL.clCmdQueueTy, cmdVar, NONE),                            CL.mkDecl(CL.clCmdQueueTy, cmdVar, NONE),
496                            CL.mkDecl(CL.clContextTy, contextVar, NONE),                            CL.mkDecl(CL.clContextTy, contextVar, NONE),
# Line 510  Line 505 
505                            CL.mkDecl(CL.clMemoryTy,clOutStateVar,NONE),                            CL.mkDecl(CL.clMemoryTy,clOutStateVar,NONE),
506                            CL.mkDecl(CL.T_Ptr(CL.T_Named tyName), outStateVar,NONE),                            CL.mkDecl(CL.T_Ptr(CL.T_Named tyName), outStateVar,NONE),
507                            CL.mkDecl(CL.charPtr, clFNVar,SOME(CL.I_Exp(CL.E_Str filename))),                            CL.mkDecl(CL.charPtr, clFNVar,SOME(CL.I_Exp(CL.E_Str filename))),
508                            CL.mkDecl(CL.charPtr, headerFNVar,SOME(CL.I_Exp(CL.E_Str "../src/include/Diderot/opencl_types.h"))),  (* FIXME:  use Paths.diderotInclude *)
509                        CL.mkDecl(CL.charPtr, headerFNVar,SOME(CL.I_Exp(CL.E_Str "../src/include/Diderot/cl-types.h"))),
510                            CL.mkDecl(CL.T_Array(CL.charPtr,SOME(2)),sourcesVar,NONE),                            CL.mkDecl(CL.T_Array(CL.charPtr,SOME(2)),sourcesVar,NONE),
511                            CL.mkDecl(CL.T_Array(CL.T_Named "size_t",SOME(nDims)),globalVar,NONE),                            CL.mkDecl(CL.T_Array(CL.T_Named "size_t",SOME(nDims)),globalVar,NONE),
512                            CL.mkDecl(CL.T_Array(CL.T_Named "size_t",SOME(nDims)),localVar,NONE),                            CL.mkDecl(CL.T_Array(CL.T_Named "size_t",SOME(nDims)),localVar,NONE),
513                            CL.mkDecl(CL.intTy,numDevicesVar,SOME(CL.I_Exp(CL.E_Int(~1,CL.intTy)))),                            CL.mkDecl(CL.intTy,numDevicesVar,SOME(CL.I_Exp(CL.E_Int(~1,CL.intTy)))),
514                            CL.mkDecl(CL.T_Array(CL.T_Named "cl_platform_id", SOME(1)), platformsVar, NONE),                            CL.mkDecl(CL.T_Array(CL.T_Named "cl_platform_id", SOME(1)), platformsVar, NONE),
515                            CL.mkDecl(CL.intTy,"num_platforms",SOME(CL.I_Exp(CL.E_Int(~1,CL.intTy))))]                      CL.mkDecl(CL.intTy,"num_platforms",SOME(CL.I_Exp(CL.E_Int(~1,CL.intTy))))
516                    ]
517                  (*Setup Global Variables *)                  (*Setup Global Variables *)
518                  val globalsDecl = CL.mkDecl(CL.T_Ptr(CL.T_Named RN.globalsTy), RN.globalsVarName,SOME(CL.I_Exp(CL.mkApply("malloc",                val globalsDecl = CL.mkDecl(
519                                                                          [CL.mkApply("sizeof",[CL.E_Var RN.globalsTy])]))))                      CL.T_Ptr(CL.T_Named RN.globalsTy),
520                        RN.globalsVarName,
521                        SOME(CL.I_Exp(CL.mkApply("malloc", [CL.mkApply("sizeof",[CL.E_Var RN.globalsTy])]))))
522                  val initGlobalsCall = CL.mkCall(RN.initGlobals,[CL.E_Var RN.globalsVarName])                  val initGlobalsCall = CL.mkCall(RN.initGlobals,[CL.E_Var RN.globalsVarName])
523    
524                  (* Retrieve the platforms                  (* Retrieve the platforms
# Line 786  Line 784 
784    
785    
786                  (* Body put all the statments together *)                  (* Body put all the statments together *)
787                  val body =  delcarations @ [globalsDecl,initGlobalsCall] (*@ platformStm @ devicesStm *) @ contextStm @ commandStm @ !initially @ [strandSize] @                  val body =  declarations @ [globalsDecl,initGlobalsCall] (*@ platformStm @ devicesStm *) @ contextStm @ commandStm @ !initially @ [strandSize] @
788                                     strandsArrays @ globalAndlocalStms @ [widthDel,strands_init]  @ clStrandObjects @ clGlobalBuffers @ sourceStms  @ create_build_stms  (*@                                     strandsArrays @ globalAndlocalStms @ [widthDel,strands_init]  @ clStrandObjects @ clGlobalBuffers @ sourceStms  @ create_build_stms  (*@
789                                     kernelArguments @ clGlobalArguments @ enqueueStm @  [outputStm] @ freeStms @ outputData *)                                     kernelArguments @ clGlobalArguments @ enqueueStm @  [outputStm] @ freeStms @ outputData *)
790    
# Line 900  Line 898 
898                val cFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}                val cFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
899                val clOutS = TextIO.openOut clFileName                val clOutS = TextIO.openOut clFileName
900                val cOutS = TextIO.openOut cFileName                val cOutS = TextIO.openOut cFileName
901    (* FIXME: need to use PrintAsC and PrintAsCL *)
902                val clppStrm = PrintAsC.new clOutS                val clppStrm = PrintAsC.new clOutS
903                val cppStrm = PrintAsC.new cOutS                val cppStrm = PrintAsC.new cOutS
904                fun cppDecl dcl = PrintAsC.output(cppStrm, dcl)                fun cppDecl dcl = PrintAsC.output(cppStrm, dcl)
905                fun clppDecl dcl = PrintAsC.output(clppStrm, dcl)                fun clppDecl dcl = PrintAsC.output(clppStrm, dcl)
906                val strands = AtomTable.listItems strands                val strands = AtomTable.listItems strands
907                val single_strand as Strand{name, tyName, code,init_code, ...}= hd(strands)                val [strand as Strand{name, tyName, code,init_code, ...}] = strands
908                in                in
   
909              (* Generate the OpenCl file *)              (* Generate the OpenCl file *)
910                    clppDecl (CL.D_Verbatim([
911                        if double
912                          then "#define DIDEROT_DOUBLE_PRECISION"
913                          else "#define DIDEROT_SINGLE_PRECISION",
914                        "#define DIDEROT_TARGET_CL",
915                        "#include \"Diderot/cl-types.h\""
916                      ]));
917              List.app clppDecl (List.rev (!globals));              List.app clppDecl (List.rev (!globals));
918              clppDecl (genGlobalStruct (!globals));              clppDecl (genGlobalStruct (!globals));
919              clppDecl (genStrandTyDef single_strand);                  clppDecl (genStrandTyDef strand);
920              List.app clppDecl (!code);              List.app clppDecl (!code);
921              clppDecl (genKernelFun (single_strand,!numDims,globals,imgGlobals));                  clppDecl (genKernelFun (strand,!numDims,globals,imgGlobals));
   
   
922              (* Generate the Host file .c *)              (* Generate the Host file .c *)
923              cppDecl (CL.D_Verbatim([              cppDecl (CL.D_Verbatim([
924                          if double                          if double
925                            then "#define DIDEROT_DOUBLE_PRECISION"                            then "#define DIDEROT_DOUBLE_PRECISION"
926                            else "#define DIDEROT_SINGLE_PRECISION",                            else "#define DIDEROT_SINGLE_PRECISION",
927                           "#include \"Diderot/diderot.h\"",                      "#define DIDEROT_TARGET_CL",
928                           "#include <OpenCL/OpenCL.h>",                      "#include \"Diderot/diderot.h\""
                          "#include <assert.h>"  
929                        ]));                        ]));
   
             (* cppDecl (CL.D_Verbatim([ "#include <OpenCL/OpenCL.h>",  
                                                                  "#include Diderot/diderot.h"])); *)  
930                  List.app cppDecl (List.rev (!globals));                  List.app cppDecl (List.rev (!globals));
931              cppDecl (genGlobalStruct (!globals));              cppDecl (genGlobalStruct (!globals));
932              cppDecl (genStrandTyDef single_strand);                  cppDecl (genStrandTyDef strand);
933                   cppDecl  (!init_code);                   cppDecl  (!init_code);
934                   cppDecl (genStrandInit(single_strand,!numDims));                  cppDecl (genStrandInit(strand,!numDims));
935                   cppDecl (genStrandPrint(single_strand,!numDims));                  cppDecl (genStrandPrint(strand,!numDims));
936             (* cppDecl (genKernelLoader());*)             (* cppDecl (genKernelLoader());*)
937              List.app cppDecl (List.rev (!topDecls));              List.app cppDecl (List.rev (!topDecls));
938              cppDecl (genHostSetupFunc (single_strand,clFileName,!numDims,initially,imgGlobals));                  cppDecl (genHostSetupFunc (strand, clFileName, !numDims, initially, imgGlobals));
   
                 (*List.app (fn strand => List.app ppDecl (genStrand strand)) strands;  
                  genStrandTable (ppStrm, strands);  
                 ppDecl (!initially);*)  
   
939                  PrintAsC.close cppStrm;                  PrintAsC.close cppStrm;
940                  PrintAsC.close clppStrm;                  PrintAsC.close clppStrm;
941                  TextIO.closeOut cOutS;                  TextIO.closeOut cOutS;
# Line 976  Line 970 
970                  RunCC.link (basename, ldOpts)                  RunCC.link (basename, ldOpts)
971                  end                  end
972    
   
   
973        end        end
974    
975    (* strands *)    (* strands *)
976      structure Strand =      structure Strand =
977        struct        struct
# Line 1007  Line 1000 
1000                val fName = RN.strandInit name                val fName = RN.strandInit name
1001                val params =                val params =
1002                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
1003                        List.map (fn (ToC.V(ty, x)) => CL.PARAM([], ty, x)) params                        List.map (fn (ToCL.V(ty, x)) => CL.PARAM([], ty, x)) params
1004                val initFn = CL.D_Func([], CL.voidTy, fName, params, init)                val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
1005                in                in
1006                          init_code := initFn                          init_code := initFn
# Line 1025  Line 1018 
1018                                  code := methFn :: !code                                  code := methFn :: !code
1019                end                end
1020    
1021          fun output (Strand{output, ...}, ty, ToC.V(_, x)) = output := SOME(ty, x)          fun output (Strand{output, ...}, ty, ToCL.V(_, x)) = output := SOME(ty, x)
1022    
1023        end        end
1024    

Legend:
Removed from v.1271  
changed lines
  Added in v.1273

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0