Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/target-cpu/gen-world.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/target-cpu/gen-world.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4505 - (view) (download)

1 : jhr 3924 (* gen-world.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     structure GenWorld : sig
10 :    
11 : jhr 4349 val genStruct : CodeGenEnv.t * Atom.atom * int -> CLang.decl
12 : jhr 3924
13 : jhr 4407 (* generate the function that creates the initial set of strands *)
14 :     val genCreateFun : CodeGenEnv.t * TreeIR.block * TreeIR.strand * TreeIR.create -> CLang.decl
15 : jhr 3924
16 :     end = struct
17 :    
18 : jhr 3974 structure IR = TreeIR
19 : jhr 3924 structure CL = CLang
20 :     structure RN = CxxNames
21 :     structure Env = CodeGenEnv
22 : jhr 3927 structure Util = CodeGenUtil
23 : jhr 3926 structure ToCxx = TreeToCxx
24 : jhr 3924
25 :     (* generate the struct declaration for the world representation *)
        (* The result is a CL.D_ClassDef named "world" that derives from
         * "public diderot::world_base".  Each member is consed onto the front
         * of `members` as it is added and the list is reversed when the class
         * definition is built, so members are emitted in the order in which
         * they are added below.  Every member goes in the public section; the
         * protected and private sections are empty.
         * NOTE(review): `strandName` is not referenced in this function body.
         *)
26 : jhr 4349 fun genStruct (env : CodeGenEnv.t, strandName, nAxes) = let
27 : jhr 4369 val spec = Env.target env
28 : jhr 4317 fun memberVar (ty, name) = CL.mkVarDcl(ty, name)
        (* the strand array is always a member *)
29 : jhr 4500 val members = [memberVar(CL.T_Named "strand_array", "_strands")]
        (* the globals pointer is added only when the target spec has globals *)
30 : jhr 4317 val members = if #hasGlobals spec
31 :     then memberVar(RN.globalPtrTy, "_globals") :: members
32 :     else members
        (* the defined-inputs member is omitted when `exec` is set or when
         * there are no inputs
         *)
33 :     val members = if #exec spec orelse not(#hasInputs spec)
34 :     then members
35 :     else memberVar(CL.T_Named "defined_inputs", "_definedInp") :: members
        (* the two counters are added only for non-parallel targets; because
         * the list is reversed later, `_nstable` is emitted before `_nactive`
         *)
36 :     val members = if TargetSpec.isParallel spec
37 :     then members
38 :     else memberVar(CL.T_Named "uint32_t", "_nactive") ::
39 :     memberVar(CL.T_Named "uint32_t", "_nstable") :: members
        (* when a spatial dimension d is specified, add a pointer to a
         * diderot::kdtree instantiated with d, the real type, and strand_array
         *)
40 : jhr 4369 val members = (case #spatialDim spec
41 :     of SOME d => memberVar(CL.T_Ptr(CL.T_Template("diderot::kdtree", [
42 : jhr 4500 CL.T_Named(Int.toString d), Env.realTy env, CL.T_Named "strand_array"
43 : jhr 4369 ])), "_tree") :: members
44 :     | NONE => members
45 :     (* end case *))
46 : jhr 4317 (* add world method decls *)
47 :     fun memberFun (ty, name, params) = CL.mkProto(ty, name, params)
        (* constructor and destructor prototypes *)
48 : jhr 4407 val members = CL.mkConstrProto("world", []) :: members
49 :     val members = CL.mkDestrProto "world" :: members
50 :     val members = memberFun (CL.boolTy, "init", []) :: members
        (* alloc takes two arrays of nAxes elements: int32 `base` and uint32 `size` *)
51 :     val members = memberFun (CL.boolTy, "alloc", [
52 :     CL.PARAM([], CL.T_Array(CL.int32, SOME nAxes), "base"),
53 :     CL.PARAM([], CL.T_Array(CL.uint32, SOME nAxes), "size")
54 :     ]) :: members
55 :     val members = memberFun (CL.boolTy, "create_strands", []) :: members
        (* the remaining optional methods are declared only when the
         * corresponding flag is set in the target spec
         *)
56 : jhr 4500 val members = if #hasStartMeth spec
57 :     then memberFun (CL.voidTy, "run_start_methods", []) :: members
58 : jhr 4407 else members
59 :     val members = memberFun (CL.uint32, "run", [CL.PARAM([], CL.uint32, "max_nsteps")]) :: members
60 :     val members = memberFun (CL.voidTy, "swap_state", []) :: members
61 :     (* FIXME: the following three functions should be "const" functions *)
62 :     val members = memberFun (CL.uint32, "num_stable_strands", []) :: members
63 :     val members = memberFun (CL.uint32, "num_all_strands", []) :: members
64 :     val members = memberFun (CL.uint32, "num_active_strands", []) :: members
65 : jhr 4500 val members = if #hasGlobalStart spec
66 :     then memberFun (CL.voidTy, "global_start", []) :: members
67 : jhr 4317 else members
68 :     val members = if #hasGlobalUpdate spec
69 :     then memberFun (CL.voidTy, "global_update", []) :: members
70 :     else members
71 : jhr 4500 val members = if #hasStabilizeAll spec
72 :     then memberFun (CL.voidTy, "stabilize_all", []) :: members
73 :     else members
74 : jhr 4317 in
75 :     CL.D_ClassDef{
76 :     name = "world",
77 :     args = NONE,
78 :     from = SOME "public diderot::world_base",
        (* undo the reverse order produced by consing *)
79 :     public = List.rev members,
80 :     protected = [],
81 :     private = []
82 :     }
83 :     end
84 : jhr 3924
85 : jhr 4407 fun genCreateFun (env : CodeGenEnv.t, globInit, strand, create) = let
        (* Builds the definition of the bool-valued, parameterless function
         * "world::create_strands".  The generated body is, in order: the
         * optional init_globals check, the optional local globals pointer, the
         * translated prefix statements, the loop-bound definitions, the
         * allocation code, the declaration of the strand index, the loop nest
         * that initializes each strand, the assignment of
         * diderot::POST_CREATE to `_stage`, and `return false`.
         * The generated code returns true when `alloc` or `init_globals`
         * yields true (presumably signalling failure -- TODO confirm).
         * NOTE(review): `globInit` is not referenced in this function body.
         *)
86 : jhr 4317 val IR.Strand{name, stateInit=IR.Method{hasG, needsW, ...}, ...} = strand
87 :     val strandName = Atom.toString name
        (* the world pseudo-variable is bound to "this" in the generated code *)
88 :     val env = Env.insert(env, PseudoVars.world, "this")
89 :     val thisV = CL.mkVar "this"
90 :     val spec = Env.target env
91 :     val {dim, locals, prefix, loops, body} = Util.decomposeCreate create
92 :     (* for each loop in the nest, we return the tuple
93 :     * (stms, loExp, hiExp, szExp, mkLoop)
94 :     * where `stms` are the statements needed to define any new variables,
95 :     * `loExp` and `hiExp` are CLang expressions for the low and high loop
96 :     * bounds, `szExp` is the number of loop iterations, and mkLoop is a
97 :     * function for building the CLang representation of the loop.
98 :     *)
99 :     fun doLoop env (Util.ForLoop(i, lo, hi)) = let
100 :     val (loV, loStms) = ToCxx.trExpToVar (env, CL.intTy, "lo", lo)
101 :     val (hiV, hiStms) = ToCxx.trExpToVar (env, CL.intTy, "hi", hi)
        (* the loop test below is `<=`, so the bounds are inclusive and the
         * iteration count is hi - lo + 1
         *)
102 :     val szE = CL.mkBinOp(CL.mkBinOp(hiV, CL.#-, loV), CL.#+, CL.mkInt 1)
103 :     val stms = loStms @ hiStms
        (* build the for loop; `mkBody` receives the environment extended
         * with the binding of the loop variable
         *)
104 :     fun mkLoop (env, mkBody) = let
105 :     val iV = TreeVar.name i
106 :     in
107 :     CL.mkFor(
108 :     CL.intTy, [( iV, loV)],
109 :     CL.mkBinOp(CL.mkVar iV, CL.#<=, hiV),
110 :     [CL.mkPostOp(CL.mkVar iV, CL.^++)],
111 :     mkBody (Env.insert (env, i, iV)))
112 :     end
113 :     in
114 :     (stms, loV, hiV, szE, mkLoop)
115 :     end
116 :     | doLoop env (Util.ForeachLoop(i, seq)) = let
117 :     val seqTy = ToCxx.trType (env, TreeTypes.SeqTy(TreeVar.ty i, NONE))
118 :     val (seqV, stms) = ToCxx.trExpToVar (env, seqTy, "seq", seq)
119 :     val szE = CL.mkDispatch(seqV, "length", [])
        (* NOTE: foreach loops are not implemented; the bounds and size are
         * computed, but calling mkLoop raises Fail
         *)
120 :     fun mkLoop (env, mkBody) = raise Fail "FIXME"
121 :     in
122 :     (stms, CL.mkInt 0, szE, szE, mkLoop)
123 :     end
        (* tr produces the statement list for the function body; it is passed
         * to trWithLocals at the end of genCreateFun
         *)
124 :     fun tr env = let
125 :     val (env, prefixCode) = TreeToCxx.trStms (env, prefix)
126 :     val loopInfo = List.map (doLoop env) loops
127 :     (* collect the statements that define the loop bounds *)
128 :     val bndsStms = List.foldr
129 :     (fn ((stms, _, _, _, _), stms') => stms @ stms')
130 :     [] loopInfo
        (* call `alloc` on `this` with the `base` and `size` arrays declared
         * in allocCode; when the call yields true, return true
         *)
131 :     val allocStm =
132 : jhr 3927 CL.mkIfThen(CL.mkIndirectDispatch(thisV, "alloc", [
133 :     CL.mkVar "base",
134 :     CL.mkVar "size"
135 : jhr 3926 ]),
136 :     (* then *)
137 :     CL.mkBlock [
138 :     CL.mkReturn(SOME(CL.mkVar "true"))
139 :     ])
140 :     (* endif *)
        (* declare an array of `dim` elements with the given initializers *)
141 : jhr 4317 fun mkArrDcl (ty, name, dim, init) = CL.mkDecl(
142 :     CL.T_Array(ty, SOME dim), name,
143 :     SOME(CL.I_Exps(List.map CL.I_Exp init)))
144 :     (* code to allocate strands *)
145 :     val allocCode = (case dim
146 :     of NONE => let (* collection of strands *)
        (* NOTE(review): this binding is non-exhaustive; it raises Bind
         * when `loops` is empty -- confirm that cannot happen here
         *)
147 :     val (sz1::szs) = List.map #4 loopInfo
        (* total number of strands is the product of the loop sizes *)
148 :     val sizeExp = List.foldl
149 :     (fn (sz, lhs) => CL.mkBinOp(lhs, CL.#*, sz))
150 :     sz1 szs
151 :     in [
152 :     mkArrDcl(CL.int32, "base", 1, [CL.mkInt 0]),
153 :     mkArrDcl(CL.uint32, "size", 1, [CL.mkStaticCast(CL.uint32, sizeExp)]),
154 :     allocStm
155 :     ] end
156 :     | SOME d => let (* grid of strands *)
        (* one base (the low bound) and one size entry per loop *)
157 :     val baseInit = List.map #2 loopInfo
158 :     val sizeInit = List.map
159 :     (fn info => CL.mkStaticCast(CL.uint32, #4 info))
160 :     loopInfo
161 :     in [
162 :     mkArrDcl(CL.int32, "base", d, baseInit),
163 :     mkArrDcl(CL.uint32, "size", d, sizeInit),
164 :     allocStm
165 :     ] end
166 :     (* end case *))
167 :     val idx = "ix" (* for indexing into the strand-state array *)
168 :     val loopCode = let
169 :     val idxV = CL.mkVar idx
        (* NOTE(review): statePtr is not referenced in the code below *)
170 :     fun statePtr inout =
171 :     CL.mkAddrOf(CL.mkSubscript(CL.mkIndirect(thisV, inout), idxV))
        (* mkNest builds the loop nest from the outermost loop inward; the
         * innermost body translates the strand-creation statements, calls
         * the strand's init function, and increments the index
         *)
172 :     fun mkNest [] env = ToCxx.trWithLocals (env, #locals body,
173 :     fn env => let
174 :     val (env, stms') = ToCxx.trStms (env, #stms body)
175 :     val (_, args) = #newStm body
176 :     (* NOTE: the args' list must match the parameters in GenStrand *)
177 :     val args' = List.map (fn e => ToCxx.trExp(env, e)) args
        (* arguments are consed on in reverse, so the final order is:
         * `this` (if needsW), `_globals` (if hasG), the strand at
         * index ix, and then the translated creation arguments
         *)
178 : jhr 4500 val args' =
179 : jhr 4505 CL.mkDispatch(RN.strandArray env, "strand", [idxV])
180 : jhr 4500 :: args'
181 : jhr 4317 val args' = if hasG
182 :     then CL.mkIndirect(thisV, "_globals") :: args'
183 :     else args'
184 :     val args' = if needsW
185 :     then thisV :: args'
186 :     else args'
187 :     val newStm = CL.mkCall(strandName ^ "_init", args')
188 :     val incStm = CL.mkExpStm (CL.mkUnOp(CL.%++, idxV))
189 :     in
190 : jhr 4500 stms' @ [newStm, incStm]
191 : jhr 4317 end)
192 :     | mkNest ((_, _, _, _, mkLoop)::r) env = mkLoop (env, mkNest r)
193 :     in
194 :     mkNest loopInfo env
195 :     end
        (* main statement sequence: the index starts at 0, the loop nest
         * runs, `_stage` is set, and the function returns false
         *)
196 :     val stms = prefixCode @ bndsStms @ allocCode @ [
197 :     CL.mkDecl(CL.uint32, idx, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
198 :     loopCode,
199 : jhr 3927 CL.mkAssign(
200 : jhr 4317 CL.mkIndirect(thisV, "_stage"),
201 : jhr 4387 CL.mkVar "diderot::POST_CREATE"),
202 : jhr 3927 CL.mkReturn(SOME(CL.mkVar "false"))
203 :     ]
        (* the two optional statements below are prepended, so when both are
         * present the init_globals check comes before the declaration of the
         * local globals pointer
         *)
204 : jhr 4317 val stms = if #hasGlobals spec
205 :     then CL.mkDeclInit (
206 :     RN.globalPtrTy, RN.globalsVar, CL.mkIndirect(thisV, "_globals")) ::
207 :     stms
208 :     else stms
209 :     val stms = if #hasGlobalInit spec
210 :     then CL.mkIfThen (CL.mkApply ("init_globals", [thisV]),
211 :     CL.mkReturn(SOME(CL.mkVar "true"))
212 :     ) :: stms
213 :     else stms
214 :     in
215 :     stms
216 :     end (* tr *)
217 :     val body = TreeToCxx.trWithLocals (env, locals, tr)
218 :     in
219 : jhr 4407 CL.mkFuncDcl(CL.boolTy, "world::create_strands", [], body)
220 : jhr 4317 end
221 : jhr 3924
222 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0