Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/pure-cfg/src/compiler/cl-target/cl-target.sml
ViewVC logotype

Diff of /branches/pure-cfg/src/compiler/cl-target/cl-target.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 1314, Sat Jun 11 17:02:26 2011 UTC revision 1315, Sat Jun 11 21:10:15 2011 UTC
# Line 1  Line 1 
1  (* c-target.sml  (* cl-target.sml
2   *   *
3   * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4   * All rights reserved.   * All rights reserved.
# Line 15  Line 15 
15      structure ToCL = TreeToCL      structure ToCL = TreeToCL
16      structure N = CNames      structure N = CNames
17    
18    (* variable translation *)    (* C variable translation *)
19      structure TrVar =      structure TrCVar =
20        struct        struct
21          type env = CL.typed_var TreeIL.Var.Map.map          type env = CL.typed_var TreeIL.Var.Map.map
22          fun lookup (env, x) = (case V.Map.find (env, x)          fun lookup (env, x) = (case V.Map.find (env, x)
23                 of SOME(CL.V(_, x')) => x'                 of SOME(CL.V(_, x')) => x'
24                  | NONE => raise Fail(concat["lookup(_, ", V.name x, ")"])                  | NONE => raise Fail(concat["TrCVar.lookup(_, ", V.name x, ")"])
25                (* end case *))                (* end case *))
26        (* translate a variable that occurs in an l-value context (i.e., as the target of an assignment) *)        (* translate a variable that occurs in an l-value context (i.e., as the target of an assignment) *)
27          fun lvalueVar (env, x) = (case V.kind x          fun lvalueVar (env, x) = (case V.kind x
28                 of IL.VK_Global => CL.mkVar(lookup(env, x))                 of IL.VK_Global => CL.mkIndirect(CL.mkVar RN.globalsVarName, lookup(env, x))
29                  | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, x))                  | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, x))
30                  | IL.VK_Local => CL.mkVar(lookup(env, x))                  | IL.VK_Local => CL.mkVar(lookup(env, x))
31                (* end case *))                (* end case *))
32        (* translate a variable that occurs in an r-value context *)        (* translate a variable that occurs in an r-value context *)
33          fun rvalueVar (env, x) = (case V.kind x          fun rvalueVar (env, x) = (case V.kind x
                of IL.VK_Global => CL.mkVar(lookup(env, x))  
                 | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))  
                 | IL.VK_Local => CL.mkVar(lookup(env, x))  
               (* end case *))  
       end  
   
         structure ToC = TreeToCFn (TrVar)  
   
   (* C variable translation *)  
     structure TrCVar =  
       struct  
         type env = CL.typed_var TreeIL.Var.Map.map  
         fun lookup (env, x) = (case V.Map.find (env, x)  
                of SOME(CL.V(_, x')) => x'  
                 | NONE => raise Fail(concat["TrCVar.lookup(_, ", V.name x, ")"])  
               (* end case *))  
       (* translate a variable that occurs in an l-value context (i.e., as the target of an assignment) *)  
         fun lvalueVar (env, x) = (case V.kind x  
34                 of IL.VK_Global => CL.mkIndirect(CL.mkVar RN.globalsVarName, lookup(env, x))                 of IL.VK_Global => CL.mkIndirect(CL.mkVar RN.globalsVarName, lookup(env, x))
35                  | IL.VK_State strand => raise Fail "unexpected strand context"                  | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))
36                  | IL.VK_Local => CL.mkVar(lookup(env, x))                  | IL.VK_Local => CL.mkVar(lookup(env, x))
37                (* end case *))                (* end case *))
       (* translate a variable that occurs in an r-value context *)  
         val rvalueVar = lvalueVar  
38        end        end
39    
40      structure ToC = TreeToCFn (TrCVar)      structure ToC = TreeToCFn (TrCVar)
# Line 135  Line 115 
115          fun fragment (ENV{info, vMap, scope}, blk) = let          fun fragment (ENV{info, vMap, scope}, blk) = let
116                val (vMap, stms) = (case scope                val (vMap, stms) = (case scope
117                       of GlobalScope => ToC.trFragment (vMap, blk)                       of GlobalScope => ToC.trFragment (vMap, blk)
118                          | InitiallyScope => ToC.trFragment (vMap, blk)
119                        | _ => ToCL.trFragment (vMap, blk)                        | _ => ToCL.trFragment (vMap, blk)
120                      (* end case *))                      (* end case *))
121                in                in
122                  (ENV{info=info, vMap=vMap, scope=scope}, stms)                  (ENV{info=info, vMap=vMap, scope=scope}, stms)
123                end                end
124          fun saveState cxt stateVars (env, args, stm) = (          fun block (ENV{vMap, scope, ...}, blk) = let
125                  fun saveState cxt stateVars trAssign (env, args, stm) = (
126                ListPair.foldrEq                ListPair.foldrEq
127                  (fn (x, e, stms) => ToCL.trAssign(env, x, e)@stms)                        (fn (x, e, stms) => trAssign(env, x, e)@stms)
128                    [stm]                    [stm]
129                      (stateVars, args)                      (stateVars, args)
130                ) handle ListPair.UnequalLengths => (                ) handle ListPair.UnequalLengths => (
131                  print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);                  print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);
132                  raise Fail(concat["saveState ", cxt, ": length mismatch"]))                  raise Fail(concat["saveState ", cxt, ": length mismatch"]))
133          fun block (ENV{vMap, scope, ...}, blk) = (case scope                in
134  (* NOTE: if we move strand initialization to the GPU, then we'll have to change the following line! *)                  case scope
135                 of StrandScope stateVars => ToC.trBlock (vMap, saveState "StrandScope" stateVars, blk)  (* NOTE: if we move strand initialization to the GPU, then we'll have to change the following code! *)
136                  | MethodScope stateVars => ToCL.trBlock (vMap, saveState "MethodScope" stateVars, blk)                   of StrandScope stateVars =>
137                  | InitiallyScope => ToCL.trBlock (vMap, fn (_, _, stm) => [stm], blk)                        ToC.trBlock (vMap, saveState "StrandScope" stateVars ToC.trAssign, blk)
138                      | MethodScope stateVars =>
139                          ToCL.trBlock (vMap, saveState "MethodScope" stateVars ToCL.trAssign, blk)
140                      | InitiallyScope => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)
141                  | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)                  | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)
142                (* end case *))                  (* end case *)
143                  end
144          fun exp (ENV{vMap, ...}, e) = ToCL.trExp(vMap, e)          fun exp (ENV{vMap, ...}, e) = ToCL.trExp(vMap, e)
145        end        end
146    
# Line 304  Line 290 
290                          SOME(CL.I_Exp(                          SOME(CL.I_Exp(
291                            CL.E_Cast(strandTy,                            CL.E_Cast(strandTy,
292                            CL.E_Apply(N.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),                            CL.E_Apply(N.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),
293                        CL.mkCall(N.strandInit name, CL.E_Var "sp" :: args),                        CL.mkCall(N.strandInit name,
294                            CL.E_Var RN.globalsVarName :: CL.E_Var "sp" :: args),
295                        CL.mkAssign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))                        CL.mkAssign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
296                      ])                      ])
297                  | mkLoopNest ((CL.V(ty, param), lo, hi)::iters) = let                  | mkLoopNest ((CL.V(ty, param), lo, hi)::iters) = let
# Line 326  Line 313 
313                      allocCode @                      allocCode @
314                      iterCode @                      iterCode @
315                      [CL.mkReturn(SOME(CL.E_Var "wrld"))])                      [CL.mkReturn(SOME(CL.E_Var "wrld"))])
316                val initFn = CL.D_Func([], worldTy, N.initially, [], body)                val initFn = CL.D_Func([], worldTy, N.initially, [CL.PARAM([], globPtrTy, RN.globalsVarName)], body)
317                in                in
318                  initially := initFn                  initially := initFn
319                end                end
# Line 345  Line 332 
332                      val prArgs = (case ty                      val prArgs = (case ty
333                             of Ty.IVecTy 1 => [CL.E_Str(!N.gIntFormat ^ "\n"), outState]                             of Ty.IVecTy 1 => [CL.E_Str(!N.gIntFormat ^ "\n"), outState]
334                              | Ty.IVecTy d => let                              | Ty.IVecTy d => let
335                                  val fmt = CL.E_Str(                                  val fmt = CL.mkStr(
336                                        String.concatWith " " (List.tabulate(d, fn _ => !N.gIntFormat))                                        String.concatWith " " (List.tabulate(d, fn _ => !N.gIntFormat))
337                                        ^ "\n")                                        ^ "\n")
338                                  val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))                                  val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))
339                                  in                                  in
340                                    fmt :: args                                    fmt :: args
341                                  end                                  end
342                              | Ty.TensorTy[] => [CL.E_Str "%f\n", outState]                              | Ty.TensorTy[] => [CL.mkStr "%f\n", outState]
343                              | Ty.TensorTy[d] => let                              | Ty.TensorTy[d] => let
344                                  val fmt = CL.E_Str(                                  val fmt = CL.mkStr(
345                                        String.concatWith " " (List.tabulate(d, fn _ => "%f"))                                        String.concatWith " " (List.tabulate(d, fn _ => "%f"))
346                                        ^ "\n")                                        ^ "\n")
347                                  val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))                                  val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))
# Line 392  Line 379 
379                        ]))                        ]))
380                fun genDataBuffers([],_,_) = []                fun genDataBuffers([],_,_) = []
381                  | genDataBuffers((var,nDims)::globals,contextVar,errVar) = let                  | genDataBuffers((var,nDims)::globals,contextVar,errVar) = let
382                        val hostVar = CL.mkIndirect(CL.mkVar RN.globalsVarName, var)
383  (* FIXME: use CL constructors to build expressions (not strings) *)  (* FIXME: use CL constructors to build expressions (not strings) *)
384                      val size = if nDims = 1                      fun sizeExp i = CL.mkSubscript(CL.mkIndirect(hostVar, "size"), CL.mkInt i)
385                              then CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*,                      val size = CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*, sizeExp 0)
386                                CL.mkIndirect(CL.mkVar var, "size[0]"))                      val size = if (nDims > 1)
387                            else if nDims = 2                            then CL.mkBinOp(size, CL.#*, sizeExp 1)
388                              then CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*,                            else size
389                                CL.mkIndirect(CL.mkVar var, concat["size[0]", " * ", var, "->size[1]"]))                      val size = if (nDims > 2)
390                              else CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*,                            then CL.mkBinOp(size, CL.#*, sizeExp 2)
391                                CL.mkIndirect(CL.mkVar var,concat["size[0]", " * ", var, "->size[1] * ", var, "->size[2]"]))                            else size
392                      in                      in
393                        CL.mkDecl(clMemoryTy, RN.addBufferSuffix var ,NONE)::                        CL.mkDecl(clMemoryTy, RN.addBufferSuffix var ,NONE)::
394                        CL.mkDecl(clMemoryTy, RN.addBufferSuffixData var ,NONE)::                        CL.mkDecl(clMemoryTy, RN.addBufferSuffixData var ,NONE)::
395                        CL.mkAssign(CL.mkVar(RN.addBufferSuffix var), CL.mkApply("clCreateBuffer",                        CL.mkAssign(CL.mkVar(RN.addBufferSuffix var),
396                          [CL.mkVar contextVar,                          CL.mkApply("clCreateBuffer", [
397                                CL.mkVar contextVar,
398                          CL.mkVar "CL_MEM_COPY_HOST_PTR",                          CL.mkVar "CL_MEM_COPY_HOST_PTR",
399                          CL.mkApply("sizeof",[CL.mkVar (RN.imageTy nDims)]),                          CL.mkApply("sizeof",[CL.mkVar (RN.imageTy nDims)]),
400                          CL.mkVar var,                              hostVar,
401                          CL.mkUnOp(CL.%&,CL.mkVar errVar)])) ::                              CL.mkUnOp(CL.%&,CL.mkVar errVar)
402                        CL.mkAssign(CL.mkVar(RN.addBufferSuffixData var), CL.mkApply("clCreateBuffer",                            ])) ::
403                          [CL.mkVar contextVar,                        CL.mkAssign(CL.mkVar(RN.addBufferSuffixData var),
404                            CL.mkApply("clCreateBuffer", [
405                                CL.mkVar contextVar,
406                          CL.mkVar "CL_MEM_COPY_HOST_PTR",                          CL.mkVar "CL_MEM_COPY_HOST_PTR",
407                          size,                          size,
408                          CL.mkIndirect(CL.mkVar var,"data"),                              CL.mkIndirect(hostVar, "data"),
409                          CL.mkUnOp(CL.%&,CL.mkVar errVar)])):: genDataBuffers(globals,contextVar,errVar)                              CL.mkUnOp(CL.%&,CL.mkVar errVar)
410                              ])) :: genDataBuffers(globals,contextVar,errVar)
411                      end                      end
412                in                in
413                  globalBufferDecl :: globalBuffer :: genDataBuffers(globals,contextVar,errVar)                  globalBufferDecl :: globalBuffer :: genDataBuffers(globals,contextVar,errVar)
# Line 454  Line 446 
446                val errVar = "err"                val errVar = "err"
447                val imgDataSizeVar = "image_dataSize"                val imgDataSizeVar = "image_dataSize"
448                val params = [                val params = [
449                          CL.PARAM([], globPtrTy, RN.globalsVarName),
450                        CL.PARAM([],CL.T_Named("cl_context"), "context"),                        CL.PARAM([],CL.T_Named("cl_context"), "context"),
451                        CL.PARAM([],CL.T_Named("cl_kernel"), "kernel"),                        CL.PARAM([],CL.T_Named("cl_kernel"), "kernel"),
452                        CL.PARAM([],CL.T_Named("int"), "argStart")                        CL.PARAM([],CL.T_Named("int"), "argStart")
# Line 558  Line 551 
551                  List.app doVar globals                  List.app doVar globals
552                end                end
553    
554            fun genStrandDesc (Strand{name, output, ...}) = let
555                (* the strand's descriptor object *)
556                  val descI = let
557                        fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
558                        val SOME(outTy, _) = !output
559                        in
560                          CL.I_Struct[
561                              ("name", CL.I_Exp(CL.mkStr name)),
562                              ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(N.strandTy name)))),
563    (*
564                              ("outputSzb", CL.I_Exp(CL.mkSizeof(ToC.trTy outTy))),
565    *)
566                              ("update", fnPtr("update_method_t", "0")),
567                              ("print", fnPtr("print_method_t", name ^ "_print"))
568                            ]
569                        end
570                  val desc = CL.D_Var([], CL.T_Named N.strandDescTy, N.strandDesc name, SOME descI)
571                  in
572                    desc
573                  end
574    
575          (* generate the table of strand descriptors *)
576            fun genStrandTable (declFn, strands) = let
577                  val nStrands = length strands
578                  fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)))
579                  fun genInits (_, []) = []
580                    | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
581                  in
582                    declFn (CL.D_Var([], CL.int32, N.numStrands,
583                      SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
584                    declFn (CL.D_Var([],
585                      CL.T_Array(CL.T_Ptr(CL.T_Named N.strandDescTy), SOME nStrands),
586                      N.strands,
587                      SOME(CL.I_Array(genInits (0, strands)))))
588                  end
589    
590          fun genSrc (baseName, prog) = let          fun genSrc (baseName, prog) = let
591                val Prog{double, globals, topDecls, strands, initially, imgGlobals, numDims, ...} = prog                val Prog{double, globals, topDecls, strands, initially, imgGlobals, numDims, ...} = prog
592                val clFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "cl"}                val clFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "cl"}
# Line 592  Line 621 
621                      "#define DIDEROT_TARGET_CL",                      "#define DIDEROT_TARGET_CL",
622                      "#include \"Diderot/diderot.h\""                      "#include \"Diderot/diderot.h\""
623                    ]));                    ]));
624                    cppDecl (CL.D_Var(["static"], CL.charPtr, "ProgramName",
625                      SOME(CL.I_Exp(CL.mkStr name))));
626    (* FIXME: I don't think that the following is necessary, since we have the global struct. [jhr]
627                  genGlobals (cppDecl, #hostTy, !globals);                  genGlobals (cppDecl, #hostTy, !globals);
628    *)
629                  cppDecl (genGlobalStruct (#hostTy, !globals));                  cppDecl (genGlobalStruct (#hostTy, !globals));
630                  cppDecl (genStrandTyDef (#hostTy, strand));                  cppDecl (genStrandTyDef (#hostTy, strand));
631                  cppDecl  (!init_code);                  cppDecl  (!init_code);
632                  cppDecl (genStrandPrint strand);                  cppDecl (genStrandPrint strand);
633                  List.app cppDecl (List.rev (!topDecls));                  List.app cppDecl (List.rev (!topDecls));
634                  cppDecl (genGlobalBuffersArgs (imgGlobals));                  cppDecl (genGlobalBuffersArgs imgGlobals);
635                    List.app (fn strand => cppDecl (genStrandDesc strand)) strands;
636                    genStrandTable (cppDecl, strands);
637                  cppDecl (!initially);                  cppDecl (!initially);
638                  PrintAsC.close cppStrm;                  PrintAsC.close cppStrm;
639                  PrintAsCL.close clppStrm;                  PrintAsCL.close clppStrm;

Legend:
Removed from v.1314  
changed lines
  Added in v.1315

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0