Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/pure-cfg/src/compiler/cl-target/cl-target.sml
ViewVC logotype

Diff of /branches/pure-cfg/src/compiler/cl-target/cl-target.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 1308, Sat Jun 11 14:21:07 2011 UTC revision 1382, Thu Jun 23 20:03:05 2011 UTC
# Line 1  Line 1 
1  (* c-target.sml  (* cl-target.sml
2   *   *
3   * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4   * All rights reserved.   * All rights reserved.
# Line 15  Line 15 
15      structure ToCL = TreeToCL      structure ToCL = TreeToCL
16      structure N = CNames      structure N = CNames
17    
18    (* variable translation *)    (* translate TreeIL types to shadow types *)
19      structure TrVar =      fun shadowTy ty = (case ty
20        struct             of Ty.BoolTy => CL.T_Named "cl_bool"
21          type env = CL.typed_var TreeIL.Var.Map.map              | Ty.StringTy => raise Fail "unexpected string type"
22          fun lookup (env, x) = (case V.Map.find (env, x)              | Ty.IVecTy 1 => CL.T_Named(RN.shadowIntTy ())
23                 of SOME(CL.V(_, x')) => x'              | Ty.IVecTy n => raise Fail "unexpected int vector type"
24                  | NONE => raise Fail(concat["lookup(_, ", V.name x, ")"])              | Ty.TensorTy[] => CL.T_Named(RN.shadowRealTy ())
25                (* end case *))              | Ty.TensorTy[n] => CL.T_Named(RN.shadowVecTy n)
26        (* translate a variable that occurs in an l-value context (i.e., as the target of an assignment) *)              | Ty.TensorTy[n, m] => CL.T_Named(RN.shadowMatTy(n,m))
27          fun lvalueVar (env, x) = (case V.kind x              | Ty.ImageTy(ImageInfo.ImgInfo{dim, ...}) => CL.T_Named(RN.shadowImageTy dim)
28                 of IL.VK_Global => CL.mkVar(lookup(env, x))              | _ => raise Fail(concat["TreeToC.trType(", Ty.toString ty, ")"])
                 | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, x))  
                 | IL.VK_Local => CL.mkVar(lookup(env, x))  
               (* end case *))  
       (* translate a variable that occurs in an r-value context *)  
         fun rvalueVar (env, x) = (case V.kind x  
                of IL.VK_Global => CL.mkVar(lookup(env, x))  
                 | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))  
                 | IL.VK_Local => CL.mkVar(lookup(env, x))  
29                (* end case *))                (* end case *))
       end  
30    
31          structure ToC = TreeToCFn (TrVar)    (* helper functions for specifying parameters in various address spaces *)
32        fun clParam (spc, ty, x) = CL.PARAM([spc], ty, x)
33        fun globalParam (ty, x) = CL.PARAM(["__global"], ty, x)
34        fun constantParam (ty, x) = CL.PARAM(["__constant"], ty, x)
35        fun localParam (ty, x) = CL.PARAM(["__local"], ty, x)
36        fun privateParam (ty, x) = CL.PARAM(["__private"], ty, x)
37    
38    (* C variable translation *)    (* C variable translation *)
39      structure TrCVar =      structure TrCVar =
# Line 50  Line 46 
46        (* translate a variable that occurs in an l-value context (i.e., as the target of an assignment) *)        (* translate a variable that occurs in an l-value context (i.e., as the target of an assignment) *)
47          fun lvalueVar (env, x) = (case V.kind x          fun lvalueVar (env, x) = (case V.kind x
48                 of IL.VK_Global => CL.mkIndirect(CL.mkVar RN.globalsVarName, lookup(env, x))                 of IL.VK_Global => CL.mkIndirect(CL.mkVar RN.globalsVarName, lookup(env, x))
49                  | IL.VK_State strand => raise Fail "unexpected strand context"                  | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfOut", lookup(env, x))
50                  | IL.VK_Local => CL.mkVar(lookup(env, x))                  | IL.VK_Local => CL.mkVar(lookup(env, x))
51                (* end case *))                (* end case *))
52        (* translate a variable that occurs in an r-value context *)        (* translate a variable that occurs in an r-value context *)
53          val rvalueVar = lvalueVar          fun rvalueVar (env, x) = (case V.kind x
54                   of IL.VK_Global => CL.mkIndirect(CL.mkVar RN.globalsVarName, lookup(env, x))
55                    | IL.VK_State strand => CL.mkIndirect(CL.mkVar "selfIn", lookup(env, x))
56                    | IL.VK_Local => CL.mkVar(lookup(env, x))
57                  (* end case *))
58        end        end
59    
60      structure ToC = TreeToCFn (TrCVar)      structure ToC = TreeToCFn (TrCVar)
# Line 64  Line 64 
64      type stm = CL.stm      type stm = CL.stm
65    
66    (* OpenCL specific types *)    (* OpenCL specific types *)
67        val clIntTy = CL.T_Named "cl_int"
68      val clProgramTy = CL.T_Named "cl_program"      val clProgramTy = CL.T_Named "cl_program"
69      val clKernelTy  = CL.T_Named "cl_kernel"      val clKernelTy  = CL.T_Named "cl_kernel"
70      val clCmdQueueTy = CL.T_Named "cl_command_queue"      val clCmdQueueTy = CL.T_Named "cl_command_queue"
# Line 71  Line 72 
72      val clDeviceIdTy = CL.T_Named "cl_device_id"      val clDeviceIdTy = CL.T_Named "cl_device_id"
73      val clPlatformIdTy = CL.T_Named "cl_platform_id"      val clPlatformIdTy = CL.T_Named "cl_platform_id"
74      val clMemoryTy = CL.T_Named "cl_mem"      val clMemoryTy = CL.T_Named "cl_mem"
75        val globPtrTy = CL.T_Ptr(CL.T_Named RN.globalsTy)
76    
77    (* variable or field that is mirrored between host and GPU *)    (* variable or field that is mirrored between host and GPU *)
78      type mirror_var = {      type mirror_var = {
79              hostTy : CL.ty,             (* variable type on Host (i.e., C type) *)              hostTy : CL.ty,             (* variable type on Host (i.e., C type) *)
80                shadowTy : CL.ty,           (* host-side shadow type of GPU type *)
81              gpuTy : CL.ty,              (* variable's type on GPU (i.e., OpenCL type) *)              gpuTy : CL.ty,              (* variable's type on GPU (i.e., OpenCL type) *)
82              var : CL.var                (* variable name *)              var : CL.var                (* variable name *)
83            }            }
# Line 97  Line 100 
100          topDecls : CL.decl list ref,          topDecls : CL.decl list ref,
101          strands : strand AtomTable.hash_table,          strands : strand AtomTable.hash_table,
102          initially :  CL.decl ref,          initially :  CL.decl ref,
103          numDims: int ref,          numDims: int ref,               (* number of dimensions in initially iteration *)
104          imgGlobals: (string * int) list ref,          imgGlobals: (string * int) list ref,
105          prFn: CL.decl ref          prFn: CL.decl ref
106        }        }
# Line 133  Line 136 
136          fun fragment (ENV{info, vMap, scope}, blk) = let          fun fragment (ENV{info, vMap, scope}, blk) = let
137                val (vMap, stms) = (case scope                val (vMap, stms) = (case scope
138                       of GlobalScope => ToC.trFragment (vMap, blk)                       of GlobalScope => ToC.trFragment (vMap, blk)
139    (* NOTE: if we move strand initialization to the GPU, then we'll have to change the following code! *)
140                          | InitiallyScope => ToC.trFragment (vMap, blk)
141                        | _ => ToCL.trFragment (vMap, blk)                        | _ => ToCL.trFragment (vMap, blk)
142                      (* end case *))                      (* end case *))
143                in                in
144                  (ENV{info=info, vMap=vMap, scope=scope}, stms)                  (ENV{info=info, vMap=vMap, scope=scope}, stms)
145                end                end
146          fun saveState cxt stateVars (env, args, stm) = (          fun block (ENV{vMap, scope, ...}, blk) = let
147                  fun saveState cxt stateVars trAssign (env, args, stm) = (
148                ListPair.foldrEq                ListPair.foldrEq
149                  (fn (x, e, stms) => ToCL.trAssign(env, x, e)@stms)                        (fn (x, e, stms) => trAssign(env, x, e)@stms)
150                    [stm]                    [stm]
151                      (stateVars, args)                      (stateVars, args)
152                ) handle ListPair.UnequalLengths => (                ) handle ListPair.UnequalLengths => (
153                  print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);                  print(concat["saveState ", cxt, ": length mismatch; ", Int.toString(List.length args), " args\n"]);
154                  raise Fail(concat["saveState ", cxt, ": length mismatch"]))                  raise Fail(concat["saveState ", cxt, ": length mismatch"]))
155          fun block (ENV{vMap, scope, ...}, blk) = (case scope                in
156                 of StrandScope stateVars => ToCL.trBlock (vMap, saveState "StrandScope" stateVars, blk)                  case scope
157                  | MethodScope stateVars => ToCL.trBlock (vMap, saveState "MethodScope" stateVars, blk)  (* NOTE: if we move strand initialization to the GPU, then we'll have to change the following code! *)
158                     of StrandScope stateVars =>
159                          ToCL.trBlock (vMap, saveState "StrandScope" stateVars ToCL.trAssign, blk)
160                      | MethodScope stateVars =>
161                          ToCL.trBlock (vMap, saveState "MethodScope" stateVars ToCL.trAssign, blk)
162                  | InitiallyScope => ToCL.trBlock (vMap, fn (_, _, stm) => [stm], blk)                  | InitiallyScope => ToCL.trBlock (vMap, fn (_, _, stm) => [stm], blk)
163                  | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)                  | _ => ToC.trBlock (vMap, fn (_, _, stm) => [stm], blk)
164                (* end case *))                  (* end case *)
165                  end
166          fun exp (ENV{vMap, ...}, e) = ToCL.trExp(vMap, e)          fun exp (ENV{vMap, ...}, e) = ToCL.trExp(vMap, e)
167        end        end
168    
169    (* variables *)    (* variables *)
170      structure Var =      structure Var =
171        struct        struct
172            fun mirror (ty, name) = {
173                    hostTy = ToC.trType ty,
174                    shadowTy = shadowTy ty,
175                    gpuTy = ToCL.trType ty,
176                    var = name
177                  }
178          fun name (ToCL.V(_, name)) = name          fun name (ToCL.V(_, name)) = name
179          fun global (Prog{globals, imgGlobals, ...}, name, ty) = let          fun global (Prog{globals, imgGlobals, ...}, name, ty) = let
180                val x = {hostTy = ToC.trType ty, gpuTy = ToCL.trType ty, var = name}                val x = mirror (ty, name)
181                fun isImgGlobal (Ty.ImageTy(ImageInfo.ImgInfo{dim, ...}), name) =                fun isImgGlobal (Ty.ImageTy(ImageInfo.ImgInfo{dim, ...}), name) =
182                      imgGlobals  := (name,dim) :: !imgGlobals                      imgGlobals  := (name,dim) :: !imgGlobals
183                  | isImgGlobal _ =  ()                  | isImgGlobal _ =  ()
# Line 172  Line 189 
189          fun param x = ToCL.V(ToCL.trType(V.ty x), V.name x)          fun param x = ToCL.V(ToCL.trType(V.ty x), V.name x)
190          fun state (Strand{state, ...}, x) = let          fun state (Strand{state, ...}, x) = let
191                val ty = V.ty x                val ty = V.ty x
192                val x' = {hostTy = ToC.trType ty, gpuTy = ToCL.trType ty, var = V.name x}                val x' = mirror (ty, V.name x)
193                in                in
194                  state := x' :: !state;                  state := x' :: !state;
195                  ToCL.V(#gpuTy x', #var x')                  ToCL.V(#gpuTy x', #var x')
# Line 215  Line 232 
232                    topDecls = ref [],                    topDecls = ref [],
233                    strands = AtomTable.mkTable (16, Fail "strand table"),                    strands = AtomTable.mkTable (16, Fail "strand table"),
234                    initially = ref(CL.D_Comment["missing initially"]),                    initially = ref(CL.D_Comment["missing initially"]),
235                    numDims = ref(0),                    numDims = ref 0,
236                    imgGlobals = ref[],                    imgGlobals = ref[],
237                    prFn = ref(CL.D_Comment(["No Print Function"]))                    prFn = ref(CL.D_Comment(["No Print Function"]))
238                  })                  })
239        (* register the global initialization part of a program *)  
 (* FIXME: unused code; can this be removed??  
           fun globalIndirects (globals,stms) = let  
                 fun getGlobals ({name,target as TargetUtil.TARGET_CL}::rest) =  
                       CL.mkAssign(CL.mkIndirect(CL.mkVar RN.globalsVarName,name),CL.mkVar name)  
                         ::getGlobals rest  
                   | getGlobals [] = []  
                   | getGlobals (_::rest) = getGlobals rest  
                 in  
                   stms @ getGlobals globals  
                 end  
 *)  
240        (* register the code that is used to register command-line options for input variables *)        (* register the code that is used to register command-line options for input variables *)
241          fun inputs (Prog{topDecls, ...}, stm) = let          fun inputs (Prog{topDecls, ...}, stm) = let
242                val inputsFn = CL.D_Func(                val inputsFn = CL.D_Func(
# Line 243  Line 249 
249    
250        (* register the global initialization part of a program *)        (* register the global initialization part of a program *)
251          fun init (Prog{topDecls, ...}, init) = let          fun init (Prog{topDecls, ...}, init) = let
252                val globPtrTy = CL.T_Ptr(CL.T_Named RN.globalsTy)                val globalsDecl = CL.mkAssign(CL.E_Var RN.globalsVarName,
253                        CL.mkApply("malloc", [CL.mkSizeof(CL.T_Named RN.globalsTy)]))
254                val initFn = CL.D_Func(                val initFn = CL.D_Func(
255                      [], CL.voidTy, RN.initGlobals, [CL.PARAM([], globPtrTy, RN.globalsVarName)],                      [], CL.voidTy, RN.initGlobals, [],
256                        CL.mkBlock[
257                            globalsDecl,
258                            CL.mkCall(RN.initGlobalsHelper, [CL.mkVar RN.globalsVarName])
259                          ])
260                  val initHelperFn = CL.D_Func(
261                        [], CL.voidTy, RN.initGlobalsHelper,
262                        [CL.PARAM([], globPtrTy, RN.globalsVarName)],
263                      init)                      init)
264                val shutdownFn = CL.D_Func(                val shutdownFn = CL.D_Func(
265                      [], CL.voidTy, RN.shutdown,                      [], CL.voidTy, RN.shutdown,
266                      [CL.PARAM([], CL.T_Ptr(CL.T_Named RN.worldTy), "wrld")],                      [CL.PARAM([], CL.T_Ptr(CL.T_Named RN.worldTy), "wrld")],
267                      CL.S_Block[])                      CL.S_Block[])
268                in                in
269                  topDecls := shutdownFn :: initFn :: !topDecls                  topDecls := shutdownFn :: initFn :: initHelperFn :: !topDecls
270                end                end
271    
272           (* create and register the initially function for a program *)           (* create and register the initially function for a program *)
273          fun initially {          fun initially {
274                prog = Prog{name=progName, strands, initially, ...},                prog = Prog{name=progName, strands, initially, numDims, ...},
275                isArray : bool,                isArray : bool,
276                iterPrefix : stm list,                iterPrefix : stm list,
277                iters : (var * exp * exp) list,                iters : (var * exp * exp) list,
# Line 285  Line 300 
300                        CL.mkDecl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),                        CL.mkDecl(CL.T_Array(CL.int32, SOME nDims), "base", SOME(CL.I_Array baseInit)),
301                        CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),                        CL.mkDecl(CL.T_Array(CL.uint32, SOME nDims), "size", SOME(CL.I_Array sizeInit)),
302                        CL.mkDecl(worldTy, wrld,                        CL.mkDecl(worldTy, wrld,
303                          SOME(CL.I_Exp(CL.E_Apply(N.allocInitially, [                          SOME(CL.I_Exp(CL.E_Apply(RN.allocInitially, [
304                              CL.mkVar "ProgramName",                              CL.mkVar "ProgramName",
305                              CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)),                              CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)),
306                              CL.E_Bool isArray,                              CL.E_Bool isArray,
# Line 294  Line 309 
309                              CL.E_Var "size"                              CL.E_Var "size"
310                            ]))))                            ]))))
311                      ]                      ]
312              (* create the loop nest for the initially iterations *)              (* create the loop nest for the initially iterations
313                val indexVar = "ix"                val indexVar = "ix"
314                val strandTy = CL.T_Ptr(CL.T_Named(N.strandTy name))                val strandTy = CL.T_Ptr(CL.T_Named(N.strandTy name))
315                fun mkLoopNest [] = CL.mkBlock(createPrefix @ [                fun mkLoopNest [] = CL.mkBlock(createPrefix @ [
# Line 302  Line 317 
317                          SOME(CL.I_Exp(                          SOME(CL.I_Exp(
318                            CL.E_Cast(strandTy,                            CL.E_Cast(strandTy,
319                            CL.E_Apply(N.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),                            CL.E_Apply(N.inState, [CL.E_Var "wrld", CL.E_Var indexVar]))))),
320                        CL.mkCall(N.strandInit name, CL.E_Var "sp" :: args),                        CL.mkCall(N.strandInit name,
321                            CL.E_Var RN.globalsVarName :: CL.E_Var "sp" :: args),
322                        CL.mkAssign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))                        CL.mkAssign(CL.E_Var indexVar, CL.mkBinOp(CL.E_Var indexVar, CL.#+, CL.E_Int(1, CL.uint32)))
323                      ])                      ])
324                  | mkLoopNest ((CL.V(ty, param), lo, hi)::iters) = let                  | mkLoopNest ((CL.V(ty, param), lo, hi)::iters) = let
# Line 318  Line 334 
334                        CL.mkComment["initially"],                        CL.mkComment["initially"],
335                        CL.mkDecl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),                        CL.mkDecl(CL.uint32, indexVar, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
336                        mkLoopNest iters                        mkLoopNest iters
337                      ]                      ] *)
338                val body = CL.mkBlock(                val body = CL.mkBlock(
339                      iterPrefix @                      iterPrefix @
340                      allocCode @                      allocCode @
                     iterCode @  
341                      [CL.mkReturn(SOME(CL.E_Var "wrld"))])                      [CL.mkReturn(SOME(CL.E_Var "wrld"))])
342                val initFn = CL.D_Func([], worldTy, N.initially, [], body)                val initFn = CL.D_Func([], worldTy, N.initially, [], body)
343                in                in
344                    numDims := nDims;
345                  initially := initFn                  initially := initFn
346                end                end
347    
348        (***** OUTPUT *****)        (***** OUTPUT *****)
349          fun genStrandPrint (Strand{name, tyName, state, output, code,...}) = let          fun genStrandPrint (Strand{name, tyName, state, output, code,...}) = let
350              (* the print function *)              (* the print function *)
351                val prFnName = concat[name, "_print"]                val prFnName = concat[name, "Print"]
352                val prFn = let                val prFn = let
353                      val params = [                      val params = [
354                              CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),                              CL.PARAM([], CL.T_Ptr(CL.T_Named "FILE"), "outS"),
355                                CL.PARAM([], CL.T_Ptr(CL.T_Num(RawTypes.RT_UInt8)),"status"),
356                                CL.PARAM([], CL.intTy,"numStrands"),
357                              CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")                              CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "self")
358                            ]                            ]
359                      val SOME(ty, x) = !output                      val SOME(ty, x) = !output
360                      val outState = CL.mkIndirect(CL.mkVar "self", x)                      val outState = CL.mkSelect(CL.mkSubscript(CL.mkVar "self", CL.E_Var "i"), x)
361                      val prArgs = (case ty                      val prArgs = (case ty
362                             of Ty.IVecTy 1 => [CL.E_Str(!N.gIntFormat ^ "\n"), outState]                             of Ty.IVecTy 1 => [CL.E_Str(!N.gIntFormat ^ "\n"), outState]
363                              | Ty.IVecTy d => let                              | Ty.IVecTy d => let
364                                  val fmt = CL.E_Str(                                  val fmt = CL.mkStr(
365                                        String.concatWith " " (List.tabulate(d, fn _ => !N.gIntFormat))                                        String.concatWith " " (List.tabulate(d, fn _ => !N.gIntFormat))
366                                        ^ "\n")                                        ^ "\n")
367                                  val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))                                  val args = List.tabulate (d, fn i => ToC.ivecIndex(outState, d, i))
368                                  in                                  in
369                                    fmt :: args                                    fmt :: args
370                                  end                                  end
371                              | Ty.TensorTy[] => [CL.E_Str "%f\n", outState]                              | Ty.TensorTy[] => [CL.mkStr "%f\n", outState]
372                              | Ty.TensorTy[d] => let                              | Ty.TensorTy[d] => let
373                                  val fmt = CL.E_Str(                                  val fmt = CL.mkStr(
374                                        String.concatWith " " (List.tabulate(d, fn _ => "%f"))                                        String.concatWith " " (List.tabulate(d, fn _ => "%f"))
375                                        ^ "\n")                                        ^ "\n")
376                                  val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))                                  val args = List.tabulate (d, fn i => ToC.vecIndex(outState, d, i))
# Line 361  Line 379 
379                                  end                                  end
380                              | _ => raise Fail("genStrand: unsupported output type " ^ Ty.toString ty)                              | _ => raise Fail("genStrand: unsupported output type " ^ Ty.toString ty)
381                            (* end case *))                            (* end case *))
382                        val forBody = CL.mkIfThen(
383                              CL.mkBinOp(CL.mkSubscript(CL.E_Var "status",CL.E_Var "i"), CL.#==, CL.E_Var "DIDEROT_STABILIZE"),
384                              CL.mkBlock([CL.mkCall("fprintf", CL.mkVar "outS" :: prArgs)]))
385                        val body =  CL.mkFor(
386                            [(CL.intTy, "i", CL.mkInt 0)],
387                            CL.mkBinOp(CL.E_Var "i", CL.#<, CL.E_Var "numStrands"),
388                            [CL.mkPostOp(CL.E_Var "i", CL.^++)],
389                            forBody)
390                      in                      in
391                        CL.D_Func(["static"], CL.voidTy, prFnName, params,                        CL.D_Func(["static"], CL.voidTy, prFnName, params, body)
                         CL.mkCall("fprintf", CL.mkVar "outS" :: prArgs))  
392                      end                      end
393                in                in
394                  prFn                  prFn
# Line 375  Line 400 
400                  List.rev (List.map (fn x => (targetTy x, #var x)) (!state)),                  List.rev (List.map (fn x => (targetTy x, #var x)) (!state)),
401                  tyName)                  tyName)
402    
403        (* generates the load kernel function *)          fun genStrandCopy(Strand{tyName,name,state,...}) = let
404                  val params = [
405                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
406                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Named tyName), "selfOut")
407                      ]
408                  val assignStms = List.rev(List.map(fn x => CL.mkAssign(CL.mkIndirect(CL.E_Var "selfOut", #var x),                                                                                                                          CL.mkIndirect(CL.E_Var "selfIn", #var x))) (!state))
409                  in
410                    CL.D_Func([""], CL.voidTy, RN.strandCopy name, params,CL.mkBlock(assignStms))
411                  end
412    
413        (* generates the opencl buffers for the image data *)        (* generates the opencl buffers for the image data *)
414          fun getGlobalDataBuffers(globals,contextVar,errVar) = let          fun getGlobalDataBuffers(globals,contextVar,errVar) = let
415                  val globalBuffErr = "error creating OpenCL global buffer"
416                  fun errorFn msg = CL.mkIfThen(CL.mkBinOp(CL.E_Var errVar, CL.#!=, CL.E_Var "CL_SUCCESS"),
417                        CL.mkBlock([CL.mkCall("fprintf",[CL.E_Var "stderr", CL.E_Str msg]),
418                        CL.mkCall("exit",[CL.mkInt 1])]))
419                val globalBufferDecl =  CL.mkDecl(clMemoryTy,concat[RN.globalsVarName,"_cl"],NONE)                val globalBufferDecl =  CL.mkDecl(clMemoryTy,concat[RN.globalsVarName,"_cl"],NONE)
420                val globalBuffer = CL.mkAssign(CL.mkVar(concat[RN.globalsVarName,"_cl"]),                val globalBuffer = CL.mkAssign(CL.mkVar(concat[RN.globalsVarName,"_cl"]),
421                      CL.mkApply("clCreateBuffer", [                      CL.mkApply("clCreateBuffer", [
422                          CL.mkVar contextVar,                          CL.mkVar contextVar,
423                          CL.mkVar "CL_MEM_COPY_HOST_PTR",                          CL.mkVar "CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR",
424                          CL.mkApply("sizeof",[CL.mkVar RN.globalsTy]),                          CL.mkSizeof(CL.T_Named RN.globalsTy),
425                          CL.mkVar RN.globalsVarName,                          CL.mkVar RN.globalsVarName,
426                          CL.mkUnOp(CL.%&,CL.mkVar errVar)                          CL.mkUnOp(CL.%&,CL.mkVar errVar)
427                        ]))                        ]))
428                fun genDataBuffers([],_,_) = []                fun genDataBuffers ([],_,_,_) = []
429                  | genDataBuffers((var,nDims)::globals,contextVar,errVar) = let                  | genDataBuffers ((var,nDims)::globals, contextVar, errVar,errFn) = let
430                        val hostVar = CL.mkIndirect(CL.mkVar RN.globalsVarName, var)
431  (* FIXME: use CL constructors to  build expressions (not strings) *)  (* FIXME: use CL constructors to  build expressions (not strings) *)
432                      val size = if nDims = 1                      fun sizeExp i = CL.mkSubscript(CL.mkIndirect(hostVar, "size"), CL.mkInt i)
433                              then CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*,  (* FIXME: there is no reason that images have to be restricted to float elements! *)
434                                CL.mkIndirect(CL.mkVar var, "size[0]"))                      val size = CL.mkBinOp(CL.mkSizeof(CL.float), CL.#*, sizeExp 0)
435                            else if nDims = 2                      val size = if (nDims > 1)
436                              then CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*,                            then CL.mkBinOp(size, CL.#*, sizeExp 1)
437                                CL.mkIndirect(CL.mkVar var, concat["size[0]", " * ", var, "->size[1]"]))                            else size
438                              else CL.mkBinOp(CL.mkApply("sizeof",[CL.mkVar "float"]), CL.#*,                      val size = if (nDims > 2)
439                                CL.mkIndirect(CL.mkVar var,concat["size[0]", " * ", var, "->size[1] * ", var, "->size[2]"]))                            then CL.mkBinOp(size, CL.#*, sizeExp 2)
440                              else size
441                      in                      in
442                        CL.mkDecl(clMemoryTy, RN.addBufferSuffix var ,NONE)::                        CL.mkDecl(clMemoryTy, RN.addBufferSuffix var ,NONE)::
443                        CL.mkDecl(clMemoryTy, RN.addBufferSuffixData var ,NONE)::                        CL.mkDecl(clMemoryTy, RN.addBufferSuffixData var ,NONE)::
444                        CL.mkAssign(CL.mkVar(RN.addBufferSuffix var), CL.mkApply("clCreateBuffer",                        CL.mkAssign(CL.mkVar(RN.addBufferSuffix var),
445                          [CL.mkVar contextVar,                          CL.mkApply("clCreateBuffer", [
446                                CL.mkVar contextVar,
447                          CL.mkVar "CL_MEM_COPY_HOST_PTR",                          CL.mkVar "CL_MEM_COPY_HOST_PTR",
448                          CL.mkApply("sizeof",[CL.mkVar (RN.imageTy nDims)]),                              CL.mkSizeof(CL.T_Named(RN.imageTy nDims)),
449                          CL.mkVar var,                              hostVar,
450                          CL.mkUnOp(CL.%&,CL.mkVar errVar)])) ::                              CL.mkUnOp(CL.%&,CL.mkVar errVar)
451                        CL.mkAssign(CL.mkVar(RN.addBufferSuffixData var), CL.mkApply("clCreateBuffer",                            ])) ::
452                          [CL.mkVar contextVar,                        errFn(concat["error in creating ",RN.addBufferSuffix var, " global buffer"]) ::
453                          CL.mkAssign(CL.mkVar(RN.addBufferSuffixData var),
454                            CL.mkApply("clCreateBuffer", [
455                                CL.mkVar contextVar,
456                          CL.mkVar "CL_MEM_COPY_HOST_PTR",                          CL.mkVar "CL_MEM_COPY_HOST_PTR",
457                          size,                          size,
458                          CL.mkIndirect(CL.mkVar var,"data"),                              CL.mkIndirect(hostVar, "data"),
459                          CL.mkUnOp(CL.%&,CL.mkVar errVar)])):: genDataBuffers(globals,contextVar,errVar)                              CL.mkUnOp(CL.%&,CL.mkVar errVar)
460                              ])) ::
461                            errFn(concat["error in creating ",RN.addBufferSuffixData var, " global buffer"]) ::
462                            genDataBuffers(globals,contextVar,errVar,errFn)
463                      end                      end
464                in                in
465                  globalBufferDecl :: globalBuffer :: genDataBuffers(globals,contextVar,errVar)                  globalBufferDecl
466                    :: globalBuffer
467                    :: errorFn(globalBuffErr)
468                    :: genDataBuffers(globals,contextVar,errVar,errorFn)
469                end                end
470    
   
471  (* generates the kernel arguments for the image data *)  (* generates the kernel arguments for the image data *)
472          fun genGlobalArguments(globals,count,kernelVar,errVar) = let          fun genGlobalArguments(globals,count,kernelVar,errVar) = let
473          val globalArgument = CL.mkExpStm(CL.mkAssignOp(CL.mkVar errVar,CL.|=,CL.mkApply("clSetKernelArg",                val globalArgErr = "error creating OpenCL global argument"
474                  fun errorFn msg = CL.mkIfThen(CL.mkBinOp(CL.E_Var errVar, CL.#!=, CL.E_Var "CL_SUCCESS"),
475                        CL.mkBlock([CL.mkCall("fprintf",[CL.E_Var "stderr", CL.E_Str msg]),
476                        CL.mkCall("exit",[CL.mkInt 1])]))
477                  val globalArgument = CL.mkExpStm(CL.mkAssignOp(CL.mkVar errVar,CL.|=,
478                        CL.mkApply("clSetKernelArg",
479                                                                  [CL.mkVar kernelVar,                                                                  [CL.mkVar kernelVar,
480                                                                   CL.mkPostOp(CL.E_Var count, CL.^++),                                                                   CL.mkPostOp(CL.E_Var count, CL.^++),
481                                                                   CL.mkApply("sizeof",[CL.mkVar "cl_mem"]),                                                                   CL.mkApply("sizeof",[CL.mkVar "cl_mem"]),
482                                                                   CL.mkUnOp(CL.%&,CL.mkVar(concat[RN.globalsVarName,"_cl"]))])))                                                                   CL.mkUnOp(CL.%&,CL.mkVar(concat[RN.globalsVarName,"_cl"]))])))
483                  fun genDataArguments ([],_,_,_,_) = []
484          fun genDataArguments([],_,_,_) = []                  | genDataArguments ((var,nDims)::globals,count,kernelVar,errVar,errFn) =
485            | genDataArguments((var,nDims)::globals,count,kernelVar,errVar) =                      CL.mkExpStm(CL.mkAssignOp(CL.mkVar errVar,CL.$=,
486                          CL.mkApply("clSetKernelArg",
                 CL.mkExpStm(CL.mkAssignOp(CL.mkVar errVar,CL.|=, CL.mkApply("clSetKernelArg",  
487                                  [CL.mkVar kernelVar,                                  [CL.mkVar kernelVar,
488                                   CL.mkPostOp(CL.E_Var count, CL.^++),                                   CL.mkPostOp(CL.E_Var count, CL.^++),
489                                   CL.mkApply("sizeof",[CL.mkVar "cl_mem"]),                                   CL.mkApply("sizeof",[CL.mkVar "cl_mem"]),
490                                   CL.mkUnOp(CL.%&,CL.mkVar(RN.addBufferSuffix var))])))::                                   CL.mkUnOp(CL.%&,CL.mkVar(RN.addBufferSuffix var))])))::
491                             errFn(concat["error in creating ",RN.addBufferSuffix var, " argument"]) ::
492                          CL.mkExpStm(CL.mkAssignOp(CL.mkVar errVar,CL.|=,CL.mkApply("clSetKernelArg",                      CL.mkExpStm(CL.mkAssignOp(CL.mkVar errVar,CL.$=,
493                          CL.mkApply("clSetKernelArg",
494                                  [CL.mkVar kernelVar,                                  [CL.mkVar kernelVar,
495                                   CL.mkPostOp(CL.E_Var count, CL.^++),                                   CL.mkPostOp(CL.E_Var count, CL.^++),
496                                   CL.mkApply("sizeof",[CL.mkVar "cl_mem"]),                                   CL.mkApply("sizeof",[CL.mkVar "cl_mem"]),
497                                   CL.mkUnOp(CL.%&,CL.mkVar(RN.addBufferSuffixData var))]))):: genDataArguments (globals,count,kernelVar,errVar)                           CL.mkUnOp(CL.%&,CL.mkVar(RN.addBufferSuffixData var))]))) ::
498                             errFn(concat["error in creating ",RN.addBufferSuffixData var, " argument"]) ::
499                        genDataArguments (globals,count,kernelVar,errVar,errFn)
500          in          in
501                   [globalArgument,errorFn(globalArgErr)] @ genDataArguments(globals, count, kernelVar, errVar,errorFn)
                 [globalArgument] @ genDataArguments(globals,count,kernelVar,errVar)  
   
502          end          end
503    
504        (* generates the globals buffers and arguments function *)        (* generates the globals buffers and arguments function *)
505          fun genGlobalBuffersArgs (imgGlobals) = let          fun genGlobalBuffersArgs imgGlobals = let
506              (* Delcare opencl setup objects *)              (* Delcare opencl setup objects *)
507                val errVar = "err"                val errVar = "err"
508                val imgDataSizeVar = "image_dataSize"                val imgDataSizeVar = "image_dataSize"
509                val params = [                val params = [
510                        CL.PARAM([],CL.T_Named("cl_context"), "context"),                        CL.PARAM([],CL.T_Named("cl_context"), "context"),
511                        CL.PARAM([],CL.T_Named("cl_kernel"), "kernel"),                        CL.PARAM([],CL.T_Named("cl_kernel"), "kernel"),
512                          CL.PARAM([],CL.T_Named("cl_command_queue"), "cmdQ"),
513                        CL.PARAM([],CL.T_Named("int"), "argStart")                        CL.PARAM([],CL.T_Named("int"), "argStart")
514                      ]                      ]
515                val clGlobalBuffers = getGlobalDataBuffers(!imgGlobals, "context", "err")                val clGlobalBuffers = getGlobalDataBuffers(!imgGlobals, "context", errVar)
516                val clGlobalArguments = genGlobalArguments(!imgGlobals, "argStart", "kernel", "err")                val clGlobalArguments = genGlobalArguments(!imgGlobals, "argStart", "kernel", errVar)
517              (* Body put all the statments together *)              (* Body put all the statments together *)
518                val body =  clGlobalBuffers @ clGlobalArguments                val body = CL.mkDecl(clIntTy, errVar, SOME(CL.I_Exp(CL.mkInt 0)))
519                        :: clGlobalBuffers @ clGlobalArguments
520                in                in
521                  CL.D_Func([],CL.voidTy,RN.globalsSetupName,params,CL.mkBlock(body))                  CL.D_Func([],CL.voidTy,RN.globalsSetupName,params,CL.mkBlock(body))
522                end                end
523    
524        (* generate the data and global parameters *)        (* generate the data and global parameters *)
525          fun genKeneralGlobalParams ((name,tyname)::rest) =          fun genKeneralGlobalParams ((name,tyname)::rest) =
526                CL.PARAM([], CL.T_Ptr(CL.T_Named RN.globalsTy), concat[RN.globalsVarName]) ::                globalParam (CL.T_Ptr(CL.T_Named (RN.imageTy tyname)), RN.addBufferSuffix name) ::
527                CL.PARAM([], CL.T_Ptr(CL.T_Named (RN.imageTy tyname)),RN.addBufferSuffix name) ::                globalParam (CL.T_Ptr(CL.voidTy), RN.addBufferSuffixData name) ::
               CL.PARAM([], CL.T_Ptr(CL.voidTy),RN.addBufferSuffixData name) ::  
528                genKeneralGlobalParams rest                genKeneralGlobalParams rest
529            | genKeneralGlobalParams [] = []            | genKeneralGlobalParams [] = []
530    
531        (*generate code for intilizing kernel global data *)        (*generate code for intilizing kernel global data *)
         fun initKernelGlobals (globals, imgGlobals) = let  
532  (* FIXME: should use List.map here *)  (* FIXME: should use List.map here *)
               fun initGlobalStruct ({hostTy, gpuTy, var}::rest) =  
                     CL.mkAssign(CL.mkVar var, CL.mkIndirect(CL.mkVar RN.globalsVarName, var)) ::  
                     initGlobalStruct rest  
                 | initGlobalStruct [] = []  
533                fun initGlobalImages ((name, tyname)::rest) =                fun initGlobalImages ((name, tyname)::rest) =
534                      CL.mkAssign(CL.mkVar name, CL.mkVar (RN.addBufferSuffix name)) ::                CL.mkAssign(
535                      CL.mkAssign(CL.mkIndirect(CL.mkVar name,"data"),CL.mkVar (RN.addBufferSuffixData name)) ::                  CL.mkIndirect(CL.E_Var RN.globalsVarName, name),
536                    CL.mkVar (RN.addBufferSuffix name)) ::
537                  CL.mkAssign(
538                    CL.mkIndirect(CL.mkIndirect(CL.E_Var RN.globalsVarName, name), "data"),
539                    CL.mkVar (RN.addBufferSuffixData name)) ::
540                      initGlobalImages rest                      initGlobalImages rest
541                    | initGlobalImages [] = []                    | initGlobalImages [] = []
               in  
                 initGlobalStruct globals @ initGlobalImages(imgGlobals)  
               end  
542    
543          (* generate the main kernel function for the .cl file *)          (* generate the main kernel function for the .cl file *)
544          fun genKernelFun (strand, nDims, globals, imgGlobals) = let          fun genKernelFun (strand, nDims, globals, imgGlobals) = let
# Line 497  Line 546 
546                val fName = RN.kernelFuncName;                val fName = RN.kernelFuncName;
547                val inState = "strand_in"                val inState = "strand_in"
548                val outState = "strand_out"                val outState = "strand_out"
549                  val tempVar = "tmp"
550                val params = [                val params = [
551                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Named tyName), "selfIn"),                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Named tyName), "selfIn"),
552                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Named tyName), "selfOut"),                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Named tyName), "selfOut"),
553                        CL.PARAM(["__global"], CL.intTy, "width")                        CL.PARAM(["__global"], CL.T_Ptr(CL.T_Num(RawTypes.RT_UInt8)), "strandStatus"),
554                          CL.PARAM(["__global"], CL.intTy, "width"),
555                          CL.PARAM(["__global"], globPtrTy, RN.globalsVarName)
556                      ] @ genKeneralGlobalParams(!imgGlobals)                      ] @ genKeneralGlobalParams(!imgGlobals)
557                val thread_ids = if nDims = 1                val thread_ids = if nDims = 1
558                      then [                      then [
559                          CL.mkDecl(CL.intTy, "x", SOME(CL.I_Exp(CL.mkInt(0, CL.intTy)))),                            CL.mkDecl(CL.intTy, "x",
560                          CL.mkAssign(CL.mkVar "x",CL.mkApply(RN.getGlobalThreadId,[CL.mkInt(0,CL.intTy)]))                              SOME(CL.I_Exp(CL.mkApply(RN.getGlobalThreadId,[CL.mkInt 0]))))
561                        ]                        ]
562                      else [                      else if nDims = 2
563                          CL.mkDecl(CL.intTy, "x", SOME(CL.I_Exp(CL.mkInt(0, CL.intTy)))),                        then [
564                          CL.mkDecl(CL.intTy, "y", SOME(CL.I_Exp(CL.mkInt(0, CL.intTy)))),                            CL.mkDecl(CL.intTy, "x",
565                          CL.mkAssign(CL.mkVar "x",  CL.mkApply(RN.getGlobalThreadId,[CL.mkInt(0,CL.intTy)])),                              SOME(CL.I_Exp(CL.mkApply(RN.getGlobalThreadId,[CL.mkInt 0])))),
566                          CL.mkAssign(CL.mkVar "y",CL.mkApply(RN.getGlobalThreadId,[CL.mkInt(1,CL.intTy)]))                            CL.mkDecl(CL.intTy, "y",
567                                SOME(CL.I_Exp(CL.mkApply(RN.getGlobalThreadId,[CL.mkInt 1]))))
568                        ]                        ]
569                        else raise Fail "nDims > 2"
570                val strandDecl = [                val strandDecl = [
571                      CL.mkDecl(CL.T_Named tyName, inState, NONE),                        CL.mkDecl(CL.T_Ptr(CL.T_Named (concat["__global ",tyName])), inState, NONE),
572                      CL.mkDecl(CL.T_Named tyName, outState,NONE)]                        CL.mkDecl(CL.T_Ptr(CL.T_Named (concat["__global ",tyName])), outState, NONE),
573                val strandObjects  = if nDims = 1                        CL.mkDecl(CL.T_Ptr(CL.T_Named (concat["__global ",tyName])), tempVar, NONE)
                         then [  
                             CL.mkAssign( CL.mkVar inState, CL.mkSubscript(CL.mkVar "selfIn", CL.mkStr "x")),  
                             CL.mkAssign(CL.mkVar outState,CL.mkSubscript(CL.mkVar "selfOut", CL.mkStr "x"))  
574                            ]                            ]
575                          else let                val barrierCode = CL.mkCall(RN.strandCopy name, [CL.E_Var outState, CL.E_Var inState])
576                                  val index = CL.mkBinOp(CL.mkBinOp(CL.mkVar "x",CL.#*,CL.mkVar "width"),CL.#+,CL.mkVar "y")                val barrierStm = CL.mkCall("barrier",[CL.E_Var "CLK_LOCAL_MEM_FENCE"])
577                                  in                val index = if nDims = 1 then
578                                          [CL.mkAssign(CL.mkVar inState, CL.mkSubscript(CL.mkVar "selfIn",index)),                          CL.mkStr "x"
579                                           CL.mkAssign(CL.mkVar outState,CL.mkSubscript(CL.mkVar "selfOut",index))]                      else
580                                  end                          CL.mkBinOp(
581                val status = CL.mkDecl(CL.intTy, "status", SOME(CL.I_Exp(CL.mkInt(0, CL.intTy))))                              CL.mkBinOp(CL.mkVar "x", CL.#*, CL.mkVar "width"), CL.#+, CL.mkVar "y")
582                val local_vars =  thread_ids @ initKernelGlobals(!globals,!imgGlobals)  @ strandDecl @ strandObjects @ [status]  
583                val while_exp = CL.mkBinOp(CL.mkBinOp(CL.mkVar "status",CL.#!=, CL.mkVar RN.kStabilize),CL.#||,CL.mkBinOp(CL.mkVar "status", CL.#!=, CL.mkVar RN.kDie))                val strandObjects =
584                val whileBody = CL.mkBlock [                       [ CL.mkAssign(CL.mkVar inState,  CL.mkBinOp(CL.mkVar "selfIn",CL.#+,index)),
585                           CL.mkAssign(CL.mkVar outState, CL.mkBinOp(CL.mkVar "selfOut",CL.#+,index))
586                         ]
587    
588                    val stabalizeStm = CL.mkAssign(CL.mkSubscript(CL.mkVar "strandStatus",index),
589                                                                            CL.E_Var "status")
590                  val status = CL.mkDecl(CL.intTy, "status", SOME(CL.I_Exp(CL.mkSubscript(CL.mkVar "strandStatus",index))))
591                  val strandInitStm = CL.mkCall(RN.strandInit name, [
592                          CL.E_Var RN.globalsVarName,
593                          CL.E_Var outState,
594                          CL.E_Var "x",
595    (* FIXME: if nDims = 1, then "y" is not defined! the arguments to this call should really come from
596     * the initially code!
597     *)
598                          CL.E_Var "y"])
599                  val local_vars = thread_ids
600                        @ initGlobalImages(!imgGlobals)
601                        @ strandDecl
602                        @ strandObjects
603                        @ [strandInitStm,status]
604                  val while_exp = CL.mkBinOp(CL.mkVar "status",CL.#==, CL.mkVar RN.kActive)
605                  val whileBody = CL.mkBlock ([barrierCode,barrierStm] @ [
606                        CL.mkAssign(CL.mkVar "status",                        CL.mkAssign(CL.mkVar "status",
607                          CL.mkApply(RN.strandUpdate name,                          CL.mkApply(RN.strandUpdate name,
608                            [CL.mkUnOp(CL.%&,CL.mkVar inState), CL.mkUnOp(CL.%&,CL.mkVar outState)])),                            [CL.mkVar inState, CL.mkVar outState,CL.E_Var RN.globalsVarName]))] )
                       CL.mkCall(RN.strandStabilize name,  
                         [CL.mkUnOp(CL.%&,CL.mkVar inState), CL.mkUnOp(CL.%&,CL.mkVar outState)])  
                     ]  
609                val whileBlock = [CL.mkWhile(while_exp, whileBody)]                val whileBlock = [CL.mkWhile(while_exp, whileBody)]
610                val body = CL.mkBlock(local_vars  @ whileBlock)                val body = CL.mkBlock(local_vars @ whileBlock @ [stabalizeStm])
611                in                in
612                  CL.D_Func(["__kernel"], CL.voidTy, fName, params, body)                  CL.D_Func(["__kernel"], CL.voidTy, fName, params, body)
613                end                end
614        (* generate a global structure from the globals *)  
615          fun genGlobalStruct (targetTy, globals) = let        (* generate a global structure type definition from the list of globals *)
616            fun genGlobalStruct (targetTy, globals, tyName) = let
617                val globs = List.map (fn (x : mirror_var) => (targetTy x, #var x)) globals                val globs = List.map (fn (x : mirror_var) => (targetTy x, #var x)) globals
618                in                in
619                  CL.D_StructDef(globs, RN.globalsTy)                  CL.D_StructDef(globs, tyName)
620                end                end
621    
622          fun genGlobals (declFn, targetTy, globals) = let          fun genGlobals (declFn, targetTy, globals) = let
623                fun doVar (x : mirror_var) = declFn (CL.D_Var([], targetTy x, #var x, NONE))                fun doVar (x : mirror_var) = declFn (CL.D_Var([], targetTy x, #var x, NONE))
624                in                in
625                  List.app doVar globals                  List.app doVar globals
626                end                end
627    
628            fun genStrandDesc (Strand{name, output, ...}) = let
629                (* the strand's descriptor object *)
630                  val descI = let
631                        fun fnPtr (ty, f) = CL.I_Exp(CL.mkCast(CL.T_Named ty, CL.mkVar f))
632                        val SOME(outTy, _) = !output
633                        in
634                          CL.I_Struct[
635                              ("name", CL.I_Exp(CL.mkStr name)),
636                              ("stateSzb", CL.I_Exp(CL.mkSizeof(CL.T_Named(N.strandTy name)))),
637    (*
638                              ("outputSzb", CL.I_Exp(CL.mkSizeof(ToC.trTy outTy))),
639    *)
640                              ("update", fnPtr("update_method_t", "0")),
641                              ("print", fnPtr("print_method_t", name ^ "Print"))
642                            ]
643                        end
644                  val desc = CL.D_Var([], CL.T_Named N.strandDescTy, N.strandDesc name, SOME descI)
645                  in
646                    desc
647                  end
648    
649          (* generate the table of strand descriptors *)
650            fun genStrandTable (declFn, strands) = let
651                  val nStrands = length strands
652                  fun genInit (Strand{name, ...}) = CL.I_Exp(CL.mkUnOp(CL.%&, CL.E_Var(N.strandDesc name)))
653                  fun genInits (_, []) = []
654                    | genInits (i, s::ss) = (i, genInit s) :: genInits(i+1, ss)
655                  in
656                    declFn (CL.D_Var([], CL.int32, N.numStrands,
657                      SOME(CL.I_Exp(CL.E_Int(IntInf.fromInt nStrands, CL.int32)))));
658                    declFn (CL.D_Var([],
659                      CL.T_Array(CL.T_Ptr(CL.T_Named N.strandDescTy), SOME nStrands),
660                      N.strands,
661                      SOME(CL.I_Array(genInits (0, strands)))))
662                  end
663    
664          fun genSrc (baseName, prog) = let          fun genSrc (baseName, prog) = let
665                val Prog{double, globals, topDecls, strands, initially, imgGlobals, numDims, ...} = prog                val Prog{name,double, globals, topDecls, strands, initially, imgGlobals, numDims, ...} = prog
666                val clFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "cl"}                val clFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "cl"}
667                val cFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}                val cFileName = OS.Path.joinBaseExt{base=baseName, ext=SOME "c"}
668                val clOutS = TextIO.openOut clFileName                val clOutS = TextIO.openOut clFileName
669                val cOutS = TextIO.openOut cFileName                val cOutS = TextIO.openOut cFileName
 (* FIXME: need to use PrintAsC and PrintAsCL *)  
670                val clppStrm = PrintAsCL.new clOutS                val clppStrm = PrintAsCL.new clOutS
671                val cppStrm = PrintAsC.new cOutS                val cppStrm = PrintAsC.new cOutS
672                  val progName = name
673                fun cppDecl dcl = PrintAsC.output(cppStrm, dcl)                fun cppDecl dcl = PrintAsC.output(cppStrm, dcl)
674                fun clppDecl dcl = PrintAsCL.output(clppStrm, dcl)                fun clppDecl dcl = PrintAsCL.output(clppStrm, dcl)
675                val strands = AtomTable.listItems strands                val strands = AtomTable.listItems strands
# Line 576  Line 683 
683                      "#define DIDEROT_TARGET_CL",                      "#define DIDEROT_TARGET_CL",
684                      "#include \"Diderot/cl-diderot.h\""                      "#include \"Diderot/cl-diderot.h\""
685                    ]));                    ]));
686                  genGlobals (clppDecl, #gpuTy, !globals);                  clppDecl (genGlobalStruct (#gpuTy, !globals, RN.globalsTy));
                 clppDecl (genGlobalStruct (#gpuTy, !globals));  
687                  clppDecl (genStrandTyDef(#gpuTy, strand));                  clppDecl (genStrandTyDef(#gpuTy, strand));
688                    clppDecl  (!init_code);
689                    clppDecl  (genStrandCopy(strand));
690                  List.app clppDecl (!code);                  List.app clppDecl (!code);
691                  clppDecl (genKernelFun (strand, !numDims, globals, imgGlobals));                  clppDecl (genKernelFun (strand, !numDims, globals, imgGlobals));
   
692                (* Generate the Host C file *)                (* Generate the Host C file *)
693                  cppDecl (CL.D_Verbatim([                  cppDecl (CL.D_Verbatim([
694                      if double                      if double
# Line 590  Line 697 
697                      "#define DIDEROT_TARGET_CL",                      "#define DIDEROT_TARGET_CL",
698                      "#include \"Diderot/diderot.h\""                      "#include \"Diderot/diderot.h\""
699                    ]));                    ]));
700                  genGlobals (cppDecl, #hostTy, !globals);                  cppDecl (CL.D_Var(["static"], CL.charPtr, "ProgramName",
701                  cppDecl (genGlobalStruct (#hostTy, !globals));                    SOME(CL.I_Exp(CL.mkStr progName))));
702                    cppDecl (genGlobalStruct (#hostTy, !globals, RN.globalsTy));
703                    cppDecl (genGlobalStruct (#shadowTy, !globals, RN.shadowGlobalsTy));
704    (* FIXME: does this really need to be a global? *)
705                    cppDecl (CL.D_Var(["static"], globPtrTy, RN.globalsVarName, NONE));
706                  cppDecl (genStrandTyDef (#hostTy, strand));                  cppDecl (genStrandTyDef (#hostTy, strand));
                 cppDecl  (!init_code);  
707                  cppDecl (genStrandPrint strand);                  cppDecl (genStrandPrint strand);
708                  List.app cppDecl (List.rev (!topDecls));                  List.app cppDecl (List.rev (!topDecls));
709                  cppDecl (genGlobalBuffersArgs (imgGlobals));                  cppDecl (genGlobalBuffersArgs imgGlobals);
710                    List.app (fn strand => cppDecl (genStrandDesc strand)) strands;
711                    genStrandTable (cppDecl, strands);
712                  cppDecl (!initially);                  cppDecl (!initially);
713                  PrintAsC.close cppStrm;                  PrintAsC.close cppStrm;
714                  PrintAsCL.close clppStrm;                  PrintAsCL.close clppStrm;
# Line 604  Line 716 
716                  TextIO.closeOut clOutS                  TextIO.closeOut clOutS
717                end                end
718    
719        (* output the code to a file.  The string is the basename of the file, the extension        (* output the code to the filesystem.  The string is the basename of the source file *)
        * is provided by the target.  
        *)  
720          fun generate (basename, prog as Prog{double, parallel, debug, ...}) = let          fun generate (basename, prog as Prog{double, parallel, debug, ...}) = let
721                fun condCons (true, x, xs) = x::xs                fun condCons (true, x, xs) = x::xs
722                  | condCons (false, _, xs) = xs                  | condCons (false, _, xs) = xs
# Line 661  Line 771 
771          fun init (Strand{name, tyName, code, init_code, ...}, params, init) = let          fun init (Strand{name, tyName, code, init_code, ...}, params, init) = let
772                val fName = RN.strandInit name                val fName = RN.strandInit name
773                val params =                val params =
774                      CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut") ::                      globalParam (globPtrTy, RN.globalsVarName) ::
775                        globalParam (CL.T_Ptr(CL.T_Named tyName), "selfOut") ::
776                        List.map (fn (ToCL.V(ty, x)) => CL.PARAM([], ty, x)) params                        List.map (fn (ToCL.V(ty, x)) => CL.PARAM([], ty, x)) params
777                val initFn = CL.D_Func([], CL.voidTy, fName, params, init)                val initFn = CL.D_Func([], CL.voidTy, fName, params, init)
778                in                in
# Line 672  Line 783 
783          fun method (Strand{name, tyName, code,...}, methName, body) = let          fun method (Strand{name, tyName, code,...}, methName, body) = let
784                val fName = concat[name, "_", methName]                val fName = concat[name, "_", methName]
785                val params = [                val params = [
786                        CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfIn"),                        globalParam (CL.T_Ptr(CL.T_Named tyName), "selfIn"),
787                        CL.PARAM([], CL.T_Ptr(CL.T_Named tyName), "selfOut")                        globalParam (CL.T_Ptr(CL.T_Named tyName), "selfOut"),
788                          globalParam (CL.T_Ptr(CL.T_Named (RN.globalsTy)), RN.globalsVarName)
789                      ]                      ]
790                val methFn = CL.D_Func([], CL.int32, fName, params, body)                val methFn = CL.D_Func([], CL.int32, fName, params, body)
791                in                in

Legend:
Removed from v.1308  
changed lines
  Added in v.1382

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0