SCM Repository
Annotation of /branches/vis15/src/compiler/simplify/util.sml
Parent Directory
|
Revision Log
Revision 3465 - (view) (download)
1 : | jhr | 3464 | (* util.sml |
2 : | * | ||
3 : | * Utility code for Simplification. | ||
4 : | * | ||
5 : | * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu) | ||
6 : | * | ||
7 : | * COPYRIGHT (c) 2015 The University of Chicago | ||
8 : | * All rights reserved. | ||
9 : | *) | ||
10 : | |||
structure Util : sig

  (* return information about a reduction operator: the primitive operator used
   * to combine values, the identity element used to initialize the accumulator,
   * and the meta-variable arguments needed to apply the primitive.  Raises Fail
   * for the reductions that are not yet supported ('mean' and 'variance') and
   * for variables that are not reduction operators.
   *)
    val reductionInfo : Var.t -> {
            rator : Var.t,                      (* primitive operator *)
            init : Literal.t,                   (* identity element to use for initialization *)
            mvs : SimpleTypes.meta_arg list     (* meta-variable arguments for primitive application *)
          }

  (* convert a block into a function by closing over its free variables.  The
   * arguments are a base name for the function (a unique numeric suffix is
   * appended), the body block, and the function's result type.  The result is
   * the new function together with the block's free variables; the i'th free
   * variable corresponds to the function's i'th parameter, so the list gives
   * the actual arguments to use when applying the function.
   *)
    val makeFunction : string * Simple.block * SimpleTypes.ty -> Simple.func * Simple.var list

  end = struct

    structure S = Simple
    structure BV = BasisVars
    structure L = Literal
    structure R = RealLit
    structure VMap = SimpleVar.Map

    fun reductionInfo rator =
          if Var.same(BV.red_all, rator)
            then {rator = BV.op_and, init = L.Bool true, mvs = []}
          else if Var.same(BV.red_exists, rator)
            then {rator = BV.op_or, init = L.Bool false, mvs = []}
          else if Var.same(BV.red_max, rator)
            then {rator = BV.fn_max, init = L.Real R.negInf, mvs = []}
          else if Var.same(BV.red_mean, rator)
            then raise Fail "FIXME: 'mean' reduction not yet supported"
          else if Var.same(BV.red_min, rator)
            then {rator = BV.fn_min, init = L.Real R.posInf, mvs = []}
          else if Var.same(BV.red_product, rator)
            then {rator = BV.mul_rr, init = L.Real R.one, mvs = []}
        (* the identity element for addition is zero; initializing with one
         * would add 1.0 to the result of every sum reduction.
         *)
          else if Var.same(BV.red_sum, rator)
            then {rator = BV.add_tt, init = L.Real(R.zero false), mvs = [SimpleTypes.SHAPE[]]}
          else if Var.same(BV.red_variance, rator)
            then raise Fail "FIXME: 'variance' reduction not yet supported"
            else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")

    local
    (* counter used to give each generated function a unique name *)
      val n = ref 0
      fun mkFuncId (name, ty) = let
            val id = !n
            in
              n := id + 1;
              SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
            end
    in
    fun makeFunction (name, blk, resTy) = let
        (* pairs (x, x') of free variables and their parameter copies, most recent first *)
          val freeVars = ref []
        (* map a variable use to its renamed copy; a variable not yet in the
         * environment is free in the block, so it becomes a new function parameter.
         *)
          fun cvtVar (env, x) = (case VMap.find(env, x)
                 of SOME x' => (env, x')
                  | NONE => let
                      val x' = SimpleVar.copy(x, SimpleVar.FunParam)
                      in
                        freeVars := (x, x') :: !freeVars;
                        (VMap.insert(env, x, x'), x')
                      end
                (* end case *))
          fun cvtVars (env, xs) = let
                fun cvt (x, (env, xs')) = let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, x'::xs')
                      end
                in
                  List.foldr cvt (env, []) xs
                end
        (* introduce a fresh local copy for a variable that is bound inside the block *)
          fun newVar (env, x) = let
                val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
                in
                  (VMap.insert(env, x, x'), x')
                end
        (* NOTE(review): the environment returned from a nested block is threaded
         * into the enclosing scope; this presumably relies on every variable
         * being bound at most once -- confirm.
         *)
          fun cvtBlock (env, S.Block stms) = let
                fun cvtStms (env, [], stms') = (env, S.Block(List.rev stms'))
                  | cvtStms (env, stm::stms, stms') = let
                      val (env, stm') = cvtStm (env, stm)
                      in
                        cvtStms (env, stms, stm'::stms')
                      end
                in
                  cvtStms (env, stms, [])
                end
          and cvtStm (env, stm) = (case stm
                 of S.S_Var(x, NONE) => let
                      val (env, x') = newVar (env, x)
                      in
                        (env, S.S_Var(x', NONE))
                      end
                  | S.S_Var(x, SOME e) => let
                    (* convert the initializer before binding x *)
                      val (env, e') = cvtExp (env, e)
                      val (env, x') = newVar (env, x)
                      in
                        (env, S.S_Var(x', SOME e'))
                      end
                  | S.S_Assign(x, e) => let
                    (* NOTE(review): if x is free in the block, the assignment
                     * targets the parameter copy rather than the original
                     * variable -- confirm callers never rely on such updates.
                     *)
                      val (env, e') = cvtExp (env, e)
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.S_Assign(x', e'))
                      end
                  | S.S_IfThenElse(x, b1, b2) => let
                      val (env, x') = cvtVar (env, x)
                      val (env, b1') = cvtBlock (env, b1)
                      val (env, b2') = cvtBlock (env, b2)
                      in
                        (env, S.S_IfThenElse(x', b1', b2'))
                      end
                  | S.S_Foreach(x, xs, b) => let
                    (* the sequence is evaluated in the enclosing scope, so it is
                     * converted first.  The iteration variable is bound by the
                     * loop, so it gets a fresh local copy; converting it with
                     * cvtVar would record it as a free variable and turn it into
                     * a spurious function parameter.
                     *)
                      val (env, xs') = cvtVar (env, xs)
                      val (env, x') = newVar (env, x)
                      val (env, b') = cvtBlock (env, b)
                      in
                        (env, S.S_Foreach(x', xs', b'))
                      end
                  | S.S_New(name, args) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.S_New(name, args'))
                      end
                  | S.S_Continue => (env, stm)
                  | S.S_Die => (env, stm)
                  | S.S_Stabilize => (env, stm)
                  | S.S_Return x => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.S_Return x')
                      end
                  | S.S_Print xs => let
                      val (env, xs') = cvtVars (env, xs)
                      in
                        (env, S.S_Print xs')
                      end
                  | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
                (* end case *))
          and cvtExp (env, exp) = (case exp
                 of S.E_Var x => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Var x')
                      end
                  | S.E_Lit _ => (env, exp)
                  | S.E_Select(x, fld) => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Select(x', fld))
                      end
                  | S.E_Apply(f, args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Apply(f, args', ty))
                      end
                  | S.E_Prim(f, mvs, args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Prim(f, mvs, args', ty))
                      end
                  | S.E_Tensor(args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Tensor(args', ty))
                      end
                  | S.E_Seq(args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Seq(args', ty))
                      end
                  | S.E_Slice(x, indices, ty) => let
                      fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
                        | cvt (SOME x, (env, idxs)) = let
                            val (env, x') = cvtVar (env, x)
                            in
                              (env, SOME x' :: idxs)
                            end
                      val (env, x') = cvtVar (env, x)
                      val (env, indices') = List.foldr cvt (env, []) indices
                      in
                        (env, S.E_Slice(x', indices', ty))
                      end
                  | S.E_Coerce{srcTy, dstTy, x} => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
                      end
                  | S.E_LoadSeq _ => (env, exp)
                  | S.E_LoadImage _ => (env, exp)
                (* end case *))
          val (env, blk) = cvtBlock (VMap.empty, blk)
        (* restore first-occurrence order; args and params stay aligned pairwise *)
          val (args, params) = ListPair.unzip (List.rev (! freeVars))
          val fnTy = SimpleTypes.T_Fun(List.map SimpleVar.typeOf params, resTy)
          in
            (S.Func{f=mkFuncId(name, fnTy), params=params, body=blk}, args)
          end
    end (* local *)

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |