Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/util.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3501 - (view) (download)

1 : jhr 3464 (* util.sml
2 :     *
3 :     * Utility code for Simplification.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure Util : sig
12 :    
13 :     (* return information about a reduction operator *)
14 :     val reductionInfo : Var.t -> {
15 :     rator : Var.t, (* primitive operator *)
16 :     init : Literal.t, (* identity element to use for initialization *)
17 :     mvs : SimpleTypes.meta_arg list (* meta-variable arguments for primitive application *)
18 :     }
19 :    
20 : jhr 3465 (* convert a block into a function by closing over its free variables *)
21 :     val makeFunction : string * Simple.block * SimpleTypes.ty -> Simple.func * Simple.var list
22 :    
23 : jhr 3464 end = struct
24 :    
25 : jhr 3465 structure S = Simple
26 : jhr 3464 structure BV = BasisVars
27 :     structure L = Literal
28 :     structure R = RealLit
29 : jhr 3465 structure VMap = SimpleVar.Map
30 : jhr 3464
31 :     fun reductionInfo rator =
32 :     if Var.same(BV.red_all, rator)
33 :     then {rator = BV.op_and, init = L.Bool true, mvs = []}
34 :     else if Var.same(BV.red_exists, rator)
35 :     then {rator = BV.op_or, init = L.Bool false, mvs = []}
36 :     else if Var.same(BV.red_max, rator)
37 : jhr 3482 then {rator = BV.fn_max_r, init = L.Real R.negInf, mvs = []}
38 : jhr 3464 else if Var.same(BV.red_mean, rator)
39 :     then raise Fail "FIXME: 'mean' reduction not yet supported"
40 :     else if Var.same(BV.red_min, rator)
41 : jhr 3482 then {rator = BV.fn_min_r, init = L.Real R.posInf, mvs = []}
42 : jhr 3464 else if Var.same(BV.red_product, rator)
43 :     then {rator = BV.mul_rr, init = L.Real R.one, mvs = []}
44 :     else if Var.same(BV.red_sum, rator)
45 :     then {rator = BV.add_tt, init = L.Real R.one, mvs = [SimpleTypes.SHAPE[]]}
46 :     else if Var.same(BV.red_variance, rator)
47 :     then raise Fail "FIXME: 'variance' reduction not yet supported"
48 :     else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")
49 :    
50 : jhr 3465 local
51 :     val n = ref 0
52 :     fun mkFuncId (name, ty) = let val id = !n
53 :     in
54 :     n := id + 1;
55 :     SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
56 :     end
57 :     in
58 :     fun makeFunction (name, blk, resTy) = let
59 :     val freeVars = ref []
60 :     fun cvtVar (env, x) = (case VMap.find(env, x)
61 :     of SOME x' => (env, x')
62 :     | NONE => let
63 :     val x' = SimpleVar.copy(x, SimpleVar.FunParam)
64 :     in
65 :     freeVars := (x, x') :: !freeVars;
66 :     (VMap.insert(env, x, x'), x')
67 :     end
68 :     (* end case *))
69 :     fun cvtVars (env, xs) = let
70 :     fun cvt (x, (env, xs')) = let
71 :     val (env, x') = cvtVar (env, x)
72 :     in
73 :     (env, x'::xs')
74 :     end
75 :     in
76 :     List.foldr cvt (env, []) xs
77 :     end
78 :     fun newVar (env, x) = let
79 :     val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
80 :     in
81 :     (VMap.insert(env, x, x'), x')
82 :     end
83 : jhr 3501 fun cvtBlock (env, S.Block{props, code}) = let
84 :     fun cvtStms (env, [], stms') = (env, S.Block{props = props, code = List.rev stms'})
85 : jhr 3465 | cvtStms (env, stm::stms, stms') = let
86 :     val (env, stm') = cvtStm (env, stm)
87 :     in
88 :     cvtStms (env, stms, stm'::stms')
89 :     end
90 :     in
91 : jhr 3501 cvtStms (env, code, [])
92 : jhr 3465 end
93 :     and cvtStm (env, stm) = (case stm
94 :     of S.S_Var(x, NONE) => let
95 :     val (env, x') = newVar (env, x)
96 :     in
97 :     (env, S.S_Var(x', NONE))
98 :     end
99 :     | S.S_Var(x, SOME e) => let
100 :     val (env, e') = cvtExp (env, e)
101 :     val (env, x') = newVar (env, x)
102 :     in
103 :     (env, S.S_Var(x', SOME e'))
104 :     end
105 :     | S.S_Assign(x, e) => let
106 :     val (env, e') = cvtExp (env, e)
107 :     val (env, x') = cvtVar (env, x)
108 :     in
109 :     (env, S.S_Assign(x', e'))
110 :     end
111 :     | S.S_IfThenElse(x, b1, b2) => let
112 :     val (env, x') = cvtVar (env, x)
113 :     val (env, b1') = cvtBlock (env, b1)
114 :     val (env, b2') = cvtBlock (env, b2)
115 :     in
116 :     (env, S.S_IfThenElse(x', b1', b2'))
117 :     end
118 :     | S.S_Foreach(x, xs, b) => let
119 :     val (env, x') = cvtVar (env, x)
120 :     val (env, xs') = cvtVar (env, xs)
121 :     val (env, b') = cvtBlock (env, b)
122 :     in
123 :     (env, S.S_Foreach(x', xs', b'))
124 :     end
125 :     | S.S_New(name, args) => let
126 :     val (env, args') = cvtVars (env, args)
127 :     in
128 :     (env, S.S_New(name, args'))
129 :     end
130 :     | S.S_Continue => (env, stm)
131 :     | S.S_Die => (env, stm)
132 :     | S.S_Stabilize => (env, stm)
133 :     | S.S_Return x => let
134 :     val (env, x') = cvtVar (env, x)
135 :     in
136 :     (env, S.S_Return x')
137 :     end
138 :     | S.S_Print xs => let
139 :     val (env, xs') = cvtVars (env, xs)
140 :     in
141 :     (env, S.S_Print xs')
142 :     end
143 :     | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
144 :     (* end case *))
145 :     and cvtExp (env, exp) = (case exp
146 :     of S.E_Var x => let
147 :     val (env, x') = cvtVar (env, x)
148 :     in
149 :     (env, S.E_Var x')
150 :     end
151 :     | S.E_Lit _ => (env, exp)
152 :     | S.E_Select(x, fld) => let
153 :     val (env, x') = cvtVar (env, x)
154 :     in
155 :     (env, S.E_Select(x', fld))
156 :     end
157 :     | S.E_Apply(f, args, ty) => let
158 :     val (env, args') = cvtVars (env, args)
159 :     in
160 :     (env, S.E_Apply(f, args', ty))
161 :     end
162 :     | S.E_Prim(f, mvs, args, ty) => let
163 :     val (env, args') = cvtVars (env, args)
164 :     in
165 :     (env, S.E_Prim(f, mvs, args', ty))
166 :     end
167 :     | S.E_Tensor(args, ty) => let
168 :     val (env, args') = cvtVars (env, args)
169 :     in
170 :     (env, S.E_Tensor(args', ty))
171 :     end
172 :     | S.E_Seq(args, ty) => let
173 :     val (env, args') = cvtVars (env, args)
174 :     in
175 :     (env, S.E_Seq(args', ty))
176 :     end
177 :     | S.E_Slice(x, indices, ty) => let
178 :     fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
179 :     | cvt (SOME x, (env, idxs)) = let
180 :     val (env, x') = cvtVar (env, x)
181 :     in
182 :     (env, SOME x' :: idxs)
183 :     end
184 :     val (env, x') = cvtVar (env, x)
185 :     val (env, indices') = List.foldr cvt (env, []) indices
186 :     in
187 :     (env, S.E_Slice(x', indices', ty))
188 :     end
189 :     | S.E_Coerce{srcTy, dstTy, x} => let
190 :     val (env, x') = cvtVar (env, x)
191 :     in
192 :     (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
193 :     end
194 :     | S.E_LoadSeq _ => (env, exp)
195 :     | S.E_LoadImage _ => (env, exp)
196 :     (* end case *))
197 :     val (env, blk) = cvtBlock (VMap.empty, blk)
198 :     val (args, params) = ListPair.unzip (List.rev (! freeVars))
199 :     val fnTy = SimpleTypes.T_Fun(List.map SimpleVar.typeOf params, resTy)
200 :     in
201 :     (S.Func{f=mkFuncId(name, fnTy), params=params, body=blk}, args)
202 :     end
203 :     end (* local *)
204 :    
205 : jhr 3464 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0