Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/mk-low-ir.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/mk-low-ir.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5478 - (view) (download)

1 : jhr 3648 (* mk-low-ir.sml
2 :     *
3 :     * Helper code to build LowIR assigments using the AvailRHS infrastructure.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2016 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure MkLowIR : sig
12 :    
13 : jhr 3653 (* an environment that maps De Bruijn indices to their iteration-index value *)
14 :     type index_env = int IntRedBlackMap.map
15 : jhr 3832
16 : jhr 5474 (* FIXME ??? *)
17 : cchiw 3741 val lookupIdx : int IntRedBlackMap.map * int -> int
18 : jhr 5474 (* FIXME ??? *)
19 : jhr 3745 val lookupMu : int IntRedBlackMap.map * Ein.mu -> int
20 :    
21 : jhr 3648 (* make "x := <int-literal>" *)
22 :     val intLit : AvailRHS.t * IntLit.t -> LowIR.var
23 :     (* make "x := <real-literal>" *)
24 :     val realLit : AvailRHS.t * RealLit.t -> LowIR.var
25 :     (* make "x := <real-literal>", where the real literal is specified as an integer *)
26 :     val intToRealLit : AvailRHS.t * int -> LowIR.var
27 : jhr 3661
28 :     (* generate a reduction sequence using the given binary operator *)
29 :     val reduce : AvailRHS.t * (AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var) * LowIR.var list
30 : jhr 4317 -> LowIR.var
31 : jhr 3661
32 : jhr 3754 (* integer arithmetic *)
33 :     val intAdd : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
34 :    
35 : jhr 3661 (* scalar arithmetic *)
36 :     val realAdd : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
37 :     val realSub : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
38 :     val realMul : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
39 :     val realDiv : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
40 :     val realNeg : AvailRHS.t * LowIR.var -> LowIR.var
41 : cchiw 5241 val realClamp : AvailRHS.t * LowIR.var * LowIR.var * LowIR.var -> LowIR.var
42 : jhr 5296 val realSign : AvailRHS.t * LowIR.var -> LowIR.var
43 :    
44 : jhr 5474 (* scalar math functions *)
45 : jhr 5007 val realAbs : AvailRHS.t * LowIR.var -> LowIR.var
46 : jhr 3746 val realSqrt : AvailRHS.t * LowIR.var -> LowIR.var
47 :     val realCos : AvailRHS.t * LowIR.var -> LowIR.var
48 :     val realArcCos : AvailRHS.t * LowIR.var -> LowIR.var
49 :     val realSin : AvailRHS.t * LowIR.var -> LowIR.var
50 :     val realArcSin : AvailRHS.t * LowIR.var -> LowIR.var
51 :     val realTan : AvailRHS.t * LowIR.var -> LowIR.var
52 :     val realArcTan : AvailRHS.t * LowIR.var -> LowIR.var
53 : jhr 5007 val realExp : AvailRHS.t * LowIR.var -> LowIR.var
54 :     val intPow : AvailRHS.t * LowIR.var * int -> LowIR.var
55 : jhr 3746
56 : jhr 3665 (* vector arithmetic *)
57 :     val vecAdd : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
58 :     val vecSub : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
59 :     val vecScale : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
60 :     val vecMul : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
61 :     val vecNeg : AvailRHS.t * int * LowIR.var -> LowIR.var
62 :     val vecSum : AvailRHS.t * int * LowIR.var -> LowIR.var
63 : jhr 3746 val vecDot : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
64 : jhr 3745
65 : jhr 3665 (* tensor operations *)
66 : jhr 3795 val tensorIndex : AvailRHS.t * index_env * LowIR.var * Ein.alpha -> LowIR.var
67 :     val tensorIndexIX : AvailRHS.t * LowIR.var * int list -> LowIR.var
68 : jhr 3665
69 : jhr 3648 (* make "x := [args]" *)
70 :     val cons : AvailRHS.t * int list * LowIR.var list -> LowIR.var
71 : jhr 3653 (* code for δ_{i,j} *)
72 : cchiw 4289 val delta : AvailRHS.t * index_env * Ein.mu * Ein.mu -> LowIR.var
73 : jhr 3653 (* code for ε_{i,j} *)
74 : cchiw 4289 val epsilon2 : AvailRHS.t * index_env * Ein.mu * Ein.mu -> LowIR.var
75 : jhr 3653 (* code for ε_{i,j,k} *)
76 : cchiw 4289 val epsilon3 : AvailRHS.t * index_env * Ein.mu * Ein.mu * Ein.mu -> LowIR.var
77 : jhr 3648
78 : jhr 3653 (* evaluate δ_{i,j} *)
79 :     val evalDelta : index_env * Ein.mu * Ein.mu -> int
80 :    
81 : jhr 3648 end = struct
82 :    
83 :     structure IR = LowIR
84 :     structure V = IR.Var
85 :     structure Ty = LowTypes
86 :     structure Op = LowOps
87 : jhr 3653 structure E = Ein
88 :     structure IMap = IntRedBlackMap
89 : jhr 3648
90 : jhr 3745 (* an environment that maps De Bruijn indices to their iteration-index value *)
91 :     type index_env = int IMap.map
92 : jhr 3665
93 :     fun lookupIdx (mapp, id) = (case IMap.find(mapp, id)
94 : jhr 4317 of SOME x => x
95 :     | NONE => raise Fail(concat["lookupIdx(_, ", Int.toString id, "): out of bounds"])
96 :     (* end case *))
97 : jhr 3665
98 : jhr 3745 fun lookupMu (mapp, E.V id) = lookupIdx (mapp, id)
99 :     | lookupMu (_, E.C i) = i
100 : jhr 3653
101 : jhr 3660 val add = AvailRHS.addAssign
102 : jhr 3648
103 : jhr 3745 fun intLit (avail, n) = add (avail, "intLit", Ty.intTy, IR.LIT(Literal.Int n))
104 : jhr 3660 fun realLit (avail, r) = add (avail, "realLit", Ty.realTy, IR.LIT(Literal.Real r))
105 : jhr 3648 fun intToRealLit (avail, n) = realLit (avail, RealLit.fromInt(IntInf.fromInt n))
106 : jhr 3661
107 : jhr 3665 fun cons (avail, shp, args) =
108 : jhr 4317 add (avail, "tensor", Ty.TensorTy shp, IR.CONS(args, Ty.TensorTy shp))
109 : jhr 3648
110 : jhr 3665 fun reduce (avail, rator, []) = raise Fail "reduction with no arguments"
111 :     | reduce (avail, rator, arg::args) = let
112 : jhr 4317 fun gen (acc, []) = acc
113 :     | gen (acc, x::xs) = gen (rator (avail, acc, x), xs)
114 :     in
115 :     gen (arg, args)
116 :     end
117 : jhr 3665
118 : jhr 3754 (* integer arithmetic *)
119 :     local
120 :     fun scalarOp2 rator (avail, x, y) = add (avail, "i", Ty.IntTy, IR.OP(rator, [x, y]))
121 :     in
122 :     val intAdd = scalarOp2 Op.IAdd
123 :     end
124 :    
125 : jhr 3661 (* scalar arithmetic *)
126 : jhr 3653 local
127 : jhr 3746 fun scalarOp1 rator (avail, x) = add (avail, "r", Ty.realTy, IR.OP(rator, [x]))
128 : jhr 3745 fun scalarOp2 rator (avail, x, y) = add (avail, "r", Ty.realTy, IR.OP(rator, [x, y]))
129 : cchiw 5241 fun scalarOp3 rator (avail, x, y, z) = add(avail, "t", Ty.realTy, IR.OP(rator, [x, y, z]))
130 : jhr 5007 fun scalarOp1R rator (avail, x) = add (avail, "r", Ty.realTy, IR.OP(rator(Ty.realTy), [x]))
131 : jhr 3661 in
132 :     val realAdd = scalarOp2 Op.RAdd
133 :     val realSub = scalarOp2 Op.RSub
134 :     val realMul = scalarOp2 Op.RMul
135 :     val realDiv = scalarOp2 Op.RDiv
136 : jhr 3746 val realNeg = scalarOp1 Op.RNeg
137 : jhr 5007 val realAbs = scalarOp1R Op.Abs
138 : jhr 5296 val realSign = scalarOp1 Op.Sign
139 : jhr 3746 val realSqrt = scalarOp1 Op.Sqrt
140 : cchiw 3969 val realExp = scalarOp1 Op.Exp
141 : jhr 3746 val realCos = scalarOp1 Op.Cos
142 :     val realArcCos = scalarOp1 Op.ArcCos
143 :     val realSin = scalarOp1 Op.Sin
144 :     val realArcSin = scalarOp1 Op.ArcSin
145 :     val realTan = scalarOp1 Op.Tan
146 :     val realArcTan = scalarOp1 Op.ArcTan
147 : cchiw 5241 val realClamp = scalarOp3 Op.RClamp
148 : jhr 3661 end (* local *)
149 :    
150 : jhr 3665 (* vector arithmetic *)
151 : jhr 3661 local
152 : jhr 3665 fun vecOp1 rator (avail, dim, x) =
153 : jhr 4317 add (avail, "v", Ty.TensorTy[dim], IR.OP(rator dim, [x]))
154 : jhr 3665 fun vecOp2 rator (avail, dim, x, y) =
155 : jhr 4317 add (avail, "v", Ty.TensorTy[dim], IR.OP(rator dim, [x, y]))
156 : jhr 3653 in
157 : jhr 3665 val vecAdd = vecOp2 Op.VAdd
158 :     val vecSub = vecOp2 Op.VSub
159 :     val vecScale = vecOp2 Op.VScale
160 :     val vecMul = vecOp2 Op.VMul
161 :     val vecNeg = vecOp1 Op.VNeg
162 : jhr 3754 fun vecSum (avail, dim, v) = add (avail, "vsm", Ty.realTy, IR.OP(Op.VSum dim, [v]))
163 : jhr 4056 fun vecDot (avail, dim, u, v) = add (avail, "vdot", Ty.realTy, IR.OP(Op.VDot dim, [u, v]))
164 : jhr 3665 end (* local *)
165 : jhr 3745
166 : jhr 5007 (* limits *)
167 : cchiw 4169 fun allConst [E.C 0] = true
168 : jhr 5007 | allConst [E.C 0, E.C 0] = true
169 : cchiw 4169
170 : jhr 3665 fun tensorIndex (avail, mapp, arg, []) = arg
171 : cchiw 4169 | tensorIndex (avail, mapp, arg, ixs) =
172 :     (case (V.ty arg)
173 :     of Ty.TensorTy[] =>
174 :     (* are all the indices constant 0? *)
175 :     (*if(allConst ixs) then arg
176 :     else*) raise Fail "indexing a real arg"
177 : jhr 5007 | _ => add (
178 :     avail, "r", Ty.realTy,
179 :     IR.OP(Op.TensorIndex(V.ty arg, List.map (fn ix => lookupMu(mapp, ix)) ixs), [arg]))
180 : cchiw 4169 (* end case *))
181 : jhr 3665
182 : jhr 3795 fun tensorIndexIX (avail, arg, []) = arg
183 : jhr 3891 | tensorIndexIX (avail, arg, [ix]) = let
184 : jhr 4317 val Ty.TensorTy[d] = V.ty arg
185 :     in
186 :     add (avail, "r", Ty.realTy, IR.OP(Op.VIndex(d, ix), [arg]))
187 :     end
188 : jhr 3891 | tensorIndexIX (avail, arg, ixs) =
189 : jhr 4317 add (avail, "r", Ty.realTy, IR.OP(Op.TensorIndex(V.ty arg, ixs), [arg]))
190 : cchiw 3784
191 : jhr 3653 fun evalDelta (mapp, i, j) = let
192 : jhr 4317 val i' = lookupMu (mapp, i)
193 :     val j' = lookupMu (mapp, j)
194 :     in
195 :     if (i' = j') then 1 else 0
196 :     end
197 : jhr 3653
198 : jhr 3980 fun delta (avail, mapp, i, j) = let
199 : cchiw 4289 val i' = lookupMu (mapp, i)
200 :     val j' = lookupMu (mapp, j)
201 : jhr 3980 in
202 :     if (i' = j') then intToRealLit (avail, 1) else intToRealLit (avail, 0)
203 :     end
204 : jhr 4056
205 : jhr 3653 fun epsilon2 (avail, mapp, i, j) = let
206 : jhr 4317 val i' = lookupMu (mapp, i)
207 :     val j' = lookupMu (mapp, j)
208 :     in
209 :     if (i' = j')
210 :     then intToRealLit (avail, 0)
211 : jhr 3653 else if (j' > i')
212 : jhr 4317 then intToRealLit (avail, 1)
213 :     else intToRealLit (avail, ~1)
214 :     end
215 : jhr 3653
216 :     fun epsilon3 (avail, mapp, i, j, k) = let
217 : jhr 4317 val i' = lookupMu (mapp, i)
218 :     val j' = lookupMu (mapp, j)
219 :     val k' = lookupMu (mapp, k)
220 :     in
221 : jhr 3653 if (i' = j' orelse j' = k' orelse i' = k')
222 : jhr 4317 then intToRealLit (avail, 0)
223 : jhr 3665 else if (j' > i')
224 : jhr 4317 then if (j' > k' andalso k' > i')
225 :     then intToRealLit (avail, ~1)
226 :     else intToRealLit (avail, 1)
227 :     else if (i' > k' andalso k' > j')
228 :     then intToRealLit (avail, 1)
229 :     else intToRealLit (avail, ~1)
230 :     end
231 : jhr 5007
232 : jhr 5474 fun intPow (avail, x, pow_n) = let
233 : jhr 5478 fun pow 0 = add (avail, "_Pow_0", Ty.realTy, IR.LIT(Literal.Real RealLit.one))
234 :     | pow 1 = x
235 :     | pow 2 = add (avail, "_Pow_2", Ty.realTy, IR.OP(Op.RMul, [x, x]))
236 :     | pow n = let
237 : jhr 5007 fun half m = let
238 : jhr 5478 val y = pow (m div 2)
239 : jhr 5007 val name = String.concat["_Pow", Int.toString (m), "_"]
240 :     in
241 :     add (avail, name, Ty.realTy, IR.OP(Op.RMul, [y, y]))
242 :     end
243 :     in
244 :     if ((n mod 2) = 0)
245 :     then half n
246 :     else let
247 :     val y = half (n-1)
248 :     val name = String.concat["_Pow", Int.toString (n), "_"]
249 :     in
250 :     add (avail, name, Ty.realTy, IR.OP(Op.RMul, [x, y]))
251 :     end
252 :     end
253 :     in
254 : jhr 5478 if (pow_n < 0)
255 :     then add (avail, "_PowInv", Ty.realTy,
256 :     IR.OP(Op.RDiv, [
257 :     add (avail, "_One", Ty.realTy, IR.LIT(Literal.Real RealLit.one)),
258 :     pow(~pow_n)
259 :     ]))
260 :     else pow pow_n
261 : jhr 5007 end
262 : jhr 3653
263 : jhr 3648 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0