Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/target-cpu/gen-world.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/target-cpu/gen-world.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4349 - (view) (download)

1 : jhr 3924 (* gen-world.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure GenWorld : sig

  (* `genStruct (env, strandName, nAxes)` generates the C++ class definition of the
   * `world` type for a program whose strand is named `strandName`.  `nAxes` is the
   * length of the `base`/`size` array parameters of `alloc` and determines the
   * type of the `pos` parameter of `sphere_query`.
   *)
    val genStruct : CodeGenEnv.t * Atom.atom * int -> CLang.decl

  (* `genInitiallyFun (env, globInit, strand, create)` generates the definition of
   * the `world::initially` member function, which allocates the initial strands
   * described by `create` and initializes each of them.  The `TreeIR.block`
   * argument is not used by the current implementation.
   *)
    val genInitiallyFun : CodeGenEnv.t * TreeIR.block * TreeIR.strand * TreeIR.create -> CLang.decl

  end = struct

    structure IR = TreeIR
    structure CL = CLang
    structure RN = CxxNames
    structure Env = CodeGenEnv
    structure Util = CodeGenUtil
    structure ToCxx = TreeToCxx

  (* generate the class declaration for the world representation.  Members are
   * accumulated in reverse declaration order in `members`, which is reversed when
   * the class definition is built.
   *)
    fun genStruct (env : CodeGenEnv.t, strandName, nAxes) = let
          val spec = Env.target env
          fun memberVar (ty, name) = CL.mkVarDcl(ty, name)
          fun memberFun (ty, name, params) = CL.mkProto(ty, name, params)
          val statePtrTy = CL.T_Ptr(CL.T_Named(Atom.toString strandName ^ "_strand"))
        (* data members *)
          val members = []
        (* dual-state targets have separate `_inState` and `_outState` pointers *)
          val members = if TargetSpec.dualState spec
                then memberVar(statePtrTy, "_outState")
                  :: memberVar(statePtrTy, "_inState")
                  :: members
                else memberVar(statePtrTy, "_state") :: members
          val members = memberVar(CL.T_Ptr CL.uint8, "_status") :: members
          val members = if #hasGlobals spec
                then memberVar(RN.globalPtrTy, "_globals") :: members
                else members
          val members = if #exec spec orelse not(#hasInputs spec)
                then members
                else memberVar(CL.T_Named "defined_inputs", "_definedInp") :: members
        (* the active/stable counters are only declared for non-parallel targets *)
          val members = if TargetSpec.isParallel spec
                then members
                else memberVar(CL.T_Named "uint32_t", "_nactive")
                  :: memberVar(CL.T_Named "uint32_t", "_nstable")
                  :: members
        (* world method declarations *)
          val members =
                memberFun (CL.voidTy, "swap_state", []) ::
                memberFun (CL.uint32, "run", [CL.PARAM([], CL.uint32, "max_nsteps")]) ::
                memberFun (CL.boolTy, "initially", []) ::
                memberFun (CL.boolTy, "alloc", [
                    CL.PARAM([], CL.T_Array(CL.int32, SOME nAxes), "base"),
                    CL.PARAM([], CL.T_Array(CL.uint32, SOME nAxes), "size")
                  ]) ::
                memberFun (CL.boolTy, "init", []) ::
                CL.mkDestrProto "world" ::
                CL.mkConstrProto("world", []) ::
                  members
          val members = if #hasGlobalInitially spec
                then memberFun (CL.voidTy, "global_initially", []) :: members
                else members
          val members = if #hasGlobalUpdate spec
                then memberFun (CL.voidTy, "global_update", []) :: members
                else members
        (* when `#hasCom spec` holds, declare `sphere_query`; its `pos` parameter is a
         * tensor reference when there is more than one axis and a real otherwise
         *)
          val members = if #hasCom spec
                then let
                  val realTy = Env.realTy env
                  val seqTy = CL.T_Template("diderot::dynseq", [CL.uint32])
                  val posTy = if nAxes > 1
                        then TypeToCxx.trType(env, TreeTypes.TensorRefTy[nAxes])
                        else realTy
                  in
                    memberFun (seqTy, "sphere_query", [
                        CL.PARAM([], posTy, "pos"),
                        CL.PARAM([], realTy, "radius")
                      ]) :: members
                  end
                else members
          in
            CL.D_ClassDef{
                name = "world",
                args = NONE,
                from = SOME "public diderot::world_base",
                public = List.rev members,
                protected = [],
                private = []
              }
          end

  (* generate the `world::initially` member function.  The generated function:
   *   - returns `true` immediately if `init_globals(this)` (when the target has a
   *     global initializer) or `this->alloc(base, size)` returns true;
   *   - otherwise runs the creation loop nest, calling `<strand>_init` once per
   *     strand with the strand index `ix` incremented after each call;
   *   - then sets `this->_stage` to `diderot::POST_INITIALLY` and returns `false`.
   * NOTE: the `globInit` argument is currently unused.
   *)
    fun genInitiallyFun (env : CodeGenEnv.t, globInit, strand, create) = let
          val IR.Strand{name, stateInit=IR.Method{hasG, needsW, ...}, ...} = strand
          val strandName = Atom.toString name
          val env = Env.insert(env, PseudoVars.world, "this")
          val thisV = CL.mkVar "this"
          val spec = Env.target env
          val {dim, locals, prefix, loops, body} = Util.decomposeCreate create
        (* for each loop in the nest, we return the tuple
         *
         *      (stms, loExp, hiExp, szExp, mkLoop)
         *
         * where `stms` are the statements needed to define any new variables,
         * `loExp` and `hiExp` are CLang expressions for the (inclusive) low and
         * high loop bounds, `szExp` is the number of loop iterations, and `mkLoop`
         * is a function for building the CLang representation of the loop.
         *)
          fun doLoop env (Util.ForLoop(i, lo, hi)) = let
                val (loV, loStms) = ToCxx.trExpToVar (env, CL.intTy, "lo", lo)
                val (hiV, hiStms) = ToCxx.trExpToVar (env, CL.intTy, "hi", hi)
              (* the loop runs from lo to hi inclusive, so it has hi - lo + 1 iterations *)
                val szE = CL.mkBinOp(CL.mkBinOp(hiV, CL.#-, loV), CL.#+, CL.mkInt 1)
                fun mkLoop (env, mkBody) = let
                      val iV = TreeVar.name i
                      in
                        CL.mkFor(
                          CL.intTy, [(iV, loV)],
                          CL.mkBinOp(CL.mkVar iV, CL.#<=, hiV),
                          [CL.mkPostOp(CL.mkVar iV, CL.^++)],
                          mkBody (Env.insert (env, i, iV)))
                      end
                in
                  (loStms @ hiStms, loV, hiV, szE, mkLoop)
                end
            | doLoop env (Util.ForeachLoop(i, seq)) = let
                val seqTy = ToCxx.trType (env, TreeTypes.SeqTy(TreeVar.ty i, NONE))
                val (seqV, stms) = ToCxx.trExpToVar (env, seqTy, "seq", seq)
                val szE = CL.mkDispatch(seqV, "length", [])
              (* FIXME: code generation for foreach loops is not implemented yet *)
                fun mkLoop (env, mkBody) =
                      raise Fail "GenWorld.genInitiallyFun: foreach loops are not supported (FIXME)"
                in
                (* indices run from 0 to length-1, matching the inclusive bounds above *)
                  (stms, CL.mkInt 0, CL.mkBinOp(szE, CL.#-, CL.mkInt 1), szE, mkLoop)
                end
          fun tr env = let
                val (env, prefixCode) = TreeToCxx.trStms (env, prefix)
                val loopInfo = List.map (doLoop env) loops
              (* collect the statements that define the loop bounds *)
                val bndsStms = List.foldr
                      (fn ((stms, _, _, _, _), stms') => stms @ stms')
                        [] loopInfo
              (* if (this->alloc(base, size)) { return true; } *)
                val allocStm = CL.mkIfThen(
                      CL.mkIndirectDispatch(thisV, "alloc", [
                          CL.mkVar "base",
                          CL.mkVar "size"
                        ]),
                      CL.mkBlock [
                          CL.mkReturn(SOME(CL.mkVar "true"))
                        ])
              (* declare a fixed-length array `name` with initializer list `init` *)
                fun mkArrDcl (ty, name, len, init) = CL.mkDecl(
                      CL.T_Array(ty, SOME len), name,
                      SOME(CL.I_Exps(List.map CL.I_Exp init)))
              (* code to allocate strands *)
                val allocCode = (case dim
                       of NONE => let (* collection of strands *)
                          (* the collection size is the product of the loop sizes *)
                            val sizeExp = (case List.map #4 loopInfo
                                   of sz1::szs => List.foldl
                                        (fn (sz, lhs) => CL.mkBinOp(lhs, CL.#*, sz))
                                          sz1 szs
                                    | [] => raise Fail
                                        "GenWorld.genInitiallyFun: collection has no iterators"
                                  (* end case *))
                            in [
                              mkArrDcl(CL.int32, "base", 1, [CL.mkInt 0]),
                              mkArrDcl(CL.uint32, "size", 1, [CL.mkStaticCast(CL.uint32, sizeExp)]),
                              allocStm
                            ] end
                        | SOME d => let (* grid of strands *)
                            val baseInit = List.map #2 loopInfo
                            val sizeInit = List.map
                                  (fn info => CL.mkStaticCast(CL.uint32, #4 info))
                                    loopInfo
                            in [
                              mkArrDcl(CL.int32, "base", d, baseInit),
                              mkArrDcl(CL.uint32, "size", d, sizeInit),
                              allocStm
                            ] end
                      (* end case *))
                val idx = "ix" (* for indexing into the strand-state array *)
                val loopCode = let
                      val idxV = CL.mkVar idx
                    (* &this->inout[ix] *)
                      fun statePtr inout =
                            CL.mkAddrOf(CL.mkSubscript(CL.mkIndirect(thisV, inout), idxV))
                    (* innermost body: initialize strand `ix`, then increment `ix` *)
                      fun mkNest [] env = ToCxx.trWithLocals (env, #locals body, fn env => let
                            val (env, stms') = ToCxx.trStms (env, #stms body)
                            val (_, args) = #newStm body
                          (* NOTE: the args' list must match the parameters in GenStrand;
                           * the resulting order is [world,] [globals,] state, args...
                           *)
                            val args' = List.map (fn e => ToCxx.trExp(env, e)) args
                            val state = if TargetSpec.dualState spec
                                  then "_inState"
                                  else "_state"
                            val args' = statePtr state :: args'
                            val args' = if hasG
                                  then CL.mkIndirect(thisV, "_globals") :: args'
                                  else args'
                            val args' = if needsW
                                  then thisV :: args'
                                  else args'
                            val newStm = CL.mkCall(strandName ^ "_init", args')
                            val incStm = CL.mkExpStm (CL.mkUnOp(CL.%++, idxV))
                          (* for dual-state targets, copy the initialized `_inState`
                           * entry into the corresponding `_outState` entry
                           *)
                            val newStms = if TargetSpec.dualState spec
                                  then [
                                      newStm,
                                      CL.mkCall("memcpy", [
                                          statePtr "_outState", statePtr "_inState",
                                          CL.mkSizeof(CL.T_Named(strandName ^ "_strand"))
                                        ]),
                                      incStm
                                    ]
                                  else [newStm, incStm]
                            in
                              stms' @ newStms
                            end)
                        | mkNest ((_, _, _, _, mkLoop)::r) env = mkLoop (env, mkNest r)
                      in
                        mkNest loopInfo env
                      end
                val stms = prefixCode @ bndsStms @ allocCode @ [
                        CL.mkDecl(CL.uint32, idx, SOME(CL.I_Exp(CL.E_Int(0, CL.uint32)))),
                        loopCode,
                        CL.mkAssign(
                          CL.mkIndirect(thisV, "_stage"),
                          CL.mkVar "diderot::POST_INITIALLY"),
                        CL.mkReturn(SOME(CL.mkVar "false"))
                      ]
              (* bind the globals pointer before any of the above code uses it *)
                val stms = if #hasGlobals spec
                      then CL.mkDeclInit (
                          RN.globalPtrTy, RN.globalsVar, CL.mkIndirect(thisV, "_globals"))
                        :: stms
                      else stms
              (* run the global initializer first; a true result aborts initially *)
                val stms = if #hasGlobalInit spec
                      then CL.mkIfThen (CL.mkApply ("init_globals", [thisV]),
                          CL.mkReturn(SOME(CL.mkVar "true"))
                        ) :: stms
                      else stms
                in
                  stms
                end (* tr *)
          val body = TreeToCxx.trWithLocals (env, locals, tr)
          in
            CL.mkFuncDcl(CL.boolTy, "world::initially", [], body)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0