Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/mk-low-ir.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/mk-low-ir.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3969 - (view) (download)

1 : jhr 3648 (* mk-low-ir.sml
2 :     *
3 :     * Helper code to build LowIR assigments using the AvailRHS infrastructure.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2016 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure MkLowIR : sig
12 :    
13 : jhr 3653 (* an environment that maps De Bruijn indices to their iteration-index value *)
14 :     type index_env = int IntRedBlackMap.map
15 : jhr 3832
16 : jhr 3745 (* ??? *)
17 : cchiw 3741 val lookupIdx : int IntRedBlackMap.map * int -> int
18 : jhr 3745 (* ??? *)
19 :     val lookupMu : int IntRedBlackMap.map * Ein.mu -> int
20 :    
21 : jhr 3648 (* make "x := <int-literal>" *)
22 :     val intLit : AvailRHS.t * IntLit.t -> LowIR.var
23 :     (* make "x := <real-literal>" *)
24 :     val realLit : AvailRHS.t * RealLit.t -> LowIR.var
25 :     (* make "x := <real-literal>", where the real literal is specified as an integer *)
26 :     val intToRealLit : AvailRHS.t * int -> LowIR.var
27 : jhr 3661
28 :     (* generate a reduction sequence using the given binary operator *)
29 :     val reduce : AvailRHS.t * (AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var) * LowIR.var list
30 :     -> LowIR.var
31 :    
32 : jhr 3754 (* integer arithmetic *)
33 :     val intAdd : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
34 :    
35 : jhr 3661 (* scalar arithmetic *)
36 :     val realAdd : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
37 :     val realSub : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
38 :     val realMul : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
39 :     val realDiv : AvailRHS.t * LowIR.var * LowIR.var -> LowIR.var
40 :     val realNeg : AvailRHS.t * LowIR.var -> LowIR.var
41 :    
42 : jhr 3746 (* scalar math functions *)
43 :     val realSqrt : AvailRHS.t * LowIR.var -> LowIR.var
44 :     val realCos : AvailRHS.t * LowIR.var -> LowIR.var
45 :     val realArcCos : AvailRHS.t * LowIR.var -> LowIR.var
46 :     val realSin : AvailRHS.t * LowIR.var -> LowIR.var
47 :     val realArcSin : AvailRHS.t * LowIR.var -> LowIR.var
48 :     val realTan : AvailRHS.t * LowIR.var -> LowIR.var
49 :     val realArcTan : AvailRHS.t * LowIR.var -> LowIR.var
50 : cchiw 3969 val realExp : AvailRHS.t * LowIR.var -> LowIR.var
51 :     val intPow : AvailRHS.t * LowIR.var * int -> LowIR.var
52 : jhr 3746
53 : jhr 3665 (* vector arithmetic *)
54 :     val vecAdd : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
55 :     val vecSub : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
56 :     val vecScale : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
57 :     val vecMul : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
58 :     val vecNeg : AvailRHS.t * int * LowIR.var -> LowIR.var
59 :     val vecSum : AvailRHS.t * int * LowIR.var -> LowIR.var
60 : jhr 3746 val vecDot : AvailRHS.t * int * LowIR.var * LowIR.var -> LowIR.var
61 : jhr 3745
62 : jhr 3665 (* tensor operations *)
63 : jhr 3795 val tensorIndex : AvailRHS.t * index_env * LowIR.var * Ein.alpha -> LowIR.var
64 :     val tensorIndexIX : AvailRHS.t * LowIR.var * int list -> LowIR.var
65 : jhr 3665
66 : jhr 3648 (* make "x := [args]" *)
67 :     val cons : AvailRHS.t * int list * LowIR.var list -> LowIR.var
68 : jhr 3653 (* code for δ_{i,j} *)
69 :     val delta : AvailRHS.t * index_env * Ein.mu * Ein.mu -> LowIR.var
70 :     (* code for ε_{i,j} *)
71 :     val epsilon2 : AvailRHS.t * index_env * Ein.index_id * Ein.index_id -> LowIR.var
72 :     (* code for ε_{i,j,k} *)
73 :     val epsilon3 : AvailRHS.t * index_env * Ein.index_id * Ein.index_id * Ein.index_id -> LowIR.var
74 : jhr 3648
75 : jhr 3653 (* evaluate δ_{i,j} *)
76 :     val evalDelta : index_env * Ein.mu * Ein.mu -> int
77 :    
78 : jhr 3648 end = struct
79 :    
80 :     structure IR = LowIR
81 :     structure V = IR.Var
82 :     structure Ty = LowTypes
83 :     structure Op = LowOps
84 : jhr 3653 structure E = Ein
85 :     structure IMap = IntRedBlackMap
86 : jhr 3648
87 : jhr 3745 (* an environment that maps De Bruijn indices to their iteration-index value *)
88 :     type index_env = int IMap.map
89 : jhr 3665
90 :     fun lookupIdx (mapp, id) = (case IMap.find(mapp, id)
91 :     of SOME x => x
92 : jhr 3745 | NONE => raise Fail(concat["lookupIdx(_, ", Int.toString id, "): out of bounds"])
93 : jhr 3665 (* end case *))
94 :    
95 : jhr 3745 fun lookupMu (mapp, E.V id) = lookupIdx (mapp, id)
96 :     | lookupMu (_, E.C i) = i
97 : jhr 3653
98 : jhr 3660 val add = AvailRHS.addAssign
99 : jhr 3648
100 : jhr 3745 fun intLit (avail, n) = add (avail, "intLit", Ty.intTy, IR.LIT(Literal.Int n))
101 : jhr 3660 fun realLit (avail, r) = add (avail, "realLit", Ty.realTy, IR.LIT(Literal.Real r))
102 : jhr 3648 fun intToRealLit (avail, n) = realLit (avail, RealLit.fromInt(IntInf.fromInt n))
103 : jhr 3661
104 : jhr 3665 fun cons (avail, shp, args) =
105 :     add (avail, "tensor", Ty.TensorTy shp, IR.CONS(args, Ty.TensorTy shp))
106 : jhr 3648
107 : jhr 3665 fun reduce (avail, rator, []) = raise Fail "reduction with no arguments"
108 :     | reduce (avail, rator, arg::args) = let
109 :     fun gen (acc, []) = acc
110 :     | gen (acc, x::xs) = gen (rator (avail, acc, x), xs)
111 :     in
112 :     gen (arg, args)
113 :     end
114 :    
115 : jhr 3754 (* integer arithmetic *)
116 :     local
117 :     fun scalarOp2 rator (avail, x, y) = add (avail, "i", Ty.IntTy, IR.OP(rator, [x, y]))
118 :     in
119 :     val intAdd = scalarOp2 Op.IAdd
120 :     end
121 :    
122 : jhr 3661 (* scalar arithmetic *)
123 : jhr 3653 local
124 : jhr 3746 fun scalarOp1 rator (avail, x) = add (avail, "r", Ty.realTy, IR.OP(rator, [x]))
125 : jhr 3745 fun scalarOp2 rator (avail, x, y) = add (avail, "r", Ty.realTy, IR.OP(rator, [x, y]))
126 : jhr 3661 in
127 :     val realAdd = scalarOp2 Op.RAdd
128 :     val realSub = scalarOp2 Op.RSub
129 :     val realMul = scalarOp2 Op.RMul
130 :     val realDiv = scalarOp2 Op.RDiv
131 : jhr 3746 val realNeg = scalarOp1 Op.RNeg
132 :     val realSqrt = scalarOp1 Op.Sqrt
133 : cchiw 3969 val realExp = scalarOp1 Op.Exp
134 : jhr 3746 val realCos = scalarOp1 Op.Cos
135 :     val realArcCos = scalarOp1 Op.ArcCos
136 :     val realSin = scalarOp1 Op.Sin
137 :     val realArcSin = scalarOp1 Op.ArcSin
138 :     val realTan = scalarOp1 Op.Tan
139 :     val realArcTan = scalarOp1 Op.ArcTan
140 : cchiw 3969
141 : jhr 3661 end (* local *)
142 :    
143 : jhr 3665 (* vector arithmetic *)
144 : jhr 3661 local
145 : jhr 3665 fun vecOp1 rator (avail, dim, x) =
146 :     add (avail, "v", Ty.TensorTy[dim], IR.OP(rator dim, [x]))
147 :     fun vecOp2 rator (avail, dim, x, y) =
148 :     add (avail, "v", Ty.TensorTy[dim], IR.OP(rator dim, [x, y]))
149 : jhr 3653 in
150 : jhr 3665 val vecAdd = vecOp2 Op.VAdd
151 :     val vecSub = vecOp2 Op.VSub
152 :     val vecScale = vecOp2 Op.VScale
153 :     val vecMul = vecOp2 Op.VMul
154 :     val vecNeg = vecOp1 Op.VNeg
155 : jhr 3754 fun vecSum (avail, dim, v) = add (avail, "vsm", Ty.realTy, IR.OP(Op.VSum dim, [v]))
156 : jhr 3665 end (* local *)
157 : jhr 3745
158 : cchiw 3741 fun vecDot (avail, vecIX, a, b) =
159 : jhr 3745 vecSum (avail, vecIX, vecMul (avail, vecIX, a, b))
160 : jhr 3665
161 :     fun tensorIndex (avail, mapp, arg, []) = arg
162 : jhr 3891 | tensorIndex (avail, mapp, arg, [ix]) = let
163 :     val Ty.TensorTy[d] = V.ty arg
164 :     in
165 :     add (
166 :     avail, "r", Ty.realTy,
167 :     IR.OP(Op.VIndex(d, lookupMu(mapp, ix)), [arg]))
168 :     end
169 :     | tensorIndex (avail, mapp, arg, ixs) =
170 : jhr 3665 add (
171 : jhr 3745 avail, "r", Ty.realTy,
172 : jhr 3891 IR.OP(Op.TensorIndex(V.ty arg, List.map (fn ix => lookupMu(mapp, ix)) ixs), [arg]))
173 : jhr 3665
174 : jhr 3795 fun tensorIndexIX (avail, arg, []) = arg
175 : jhr 3891 | tensorIndexIX (avail, arg, [ix]) = let
176 :     val Ty.TensorTy[d] = V.ty arg
177 :     in
178 :     add (avail, "r", Ty.realTy, IR.OP(Op.VIndex(d, ix), [arg]))
179 :     end
180 :     | tensorIndexIX (avail, arg, ixs) =
181 :     add (avail, "r", Ty.realTy, IR.OP(Op.TensorIndex(V.ty arg, ixs), [arg]))
182 : cchiw 3784
183 : jhr 3653 fun evalDelta (mapp, i, j) = let
184 :     val i' = lookupMu (mapp, i)
185 :     val j' = lookupMu (mapp, j)
186 :     in
187 :     if (i' = j') then 1 else 0
188 :     end
189 :    
190 :     fun delta (avail, mapp, i, j) = intToRealLit (avail, evalDelta (mapp, i, j))
191 :    
192 :     fun epsilon2 (avail, mapp, i, j) = let
193 :     val i' = lookupIdx (mapp, i)
194 :     val j' = lookupIdx (mapp, j)
195 :     in
196 :     if (i' = j')
197 : jhr 3665 then intToRealLit (avail, 0)
198 : jhr 3653 else if (j' > i')
199 : jhr 3665 then intToRealLit (avail, 1)
200 :     else intToRealLit (avail, ~1)
201 : jhr 3653 end
202 :    
203 :     fun epsilon3 (avail, mapp, i, j, k) = let
204 :     val i' = lookupIdx (mapp, i)
205 :     val j' = lookupIdx (mapp, j)
206 :     val k' = lookupIdx (mapp, k)
207 :     in
208 :     if (i' = j' orelse j' = k' orelse i' = k')
209 : jhr 3665 then intToRealLit (avail, 0)
210 :     else if (j' > i')
211 : jhr 3653 then if (j' > k' andalso k' > i')
212 :     then intToRealLit (avail, ~1)
213 : jhr 3665 else intToRealLit (avail, 1)
214 : jhr 3653 else if (i' > k' andalso k' > j')
215 :     then intToRealLit (avail, 1)
216 : jhr 3665 else intToRealLit (avail, ~1)
217 : jhr 3653 end
218 : cchiw 3969 fun intPow(avail, x, pow_n) = let
219 :     fun pow (1, avail) = x
220 :     | pow (2, avail) = add (avail, "_Pow_2", Ty.realTy, IR.OP(Op.RMul, [x, x]))
221 :     | pow (n, avail) = let
222 :     fun half m = let
223 :     val y = pow (m div 2, avail)
224 :     val name = String.concat["_Pow", Int.toString (m), "_"]
225 :     in add (avail, name, Ty.realTy, IR.OP(Op.RMul, [y, y])) end
226 :     in if ((n mod 2) = 0)
227 :     then half n
228 :     else let
229 :     val y = half (n-1)
230 :     val name = String.concat["_Pow", Int.toString (n), "_"]
231 :     in add (avail, name, Ty.realTy, IR.OP(Op.RMul, [x, y])) end
232 :     end
233 :     in
234 :     pow (pow_n, avail)
235 :     end
236 : jhr 3653
237 : jhr 3648 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0