Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/helper-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/helper-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3624 - (view) (download)

1 : jhr 3624 (* helper-set.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 : cchiw 3553 (*Helper functions*)
10 : cchiw 3602 structure Helper = struct
11 : cchiw 3444 local
12 :    
13 : cchiw 3602 structure IL = LowIL
14 :     structure Ty = LowILTypes
15 :     structure Op = LowOps
16 : cchiw 3444 structure Var = LowIL.Var
17 :     structure E = Ein
18 :     structure IMap = IntRedBlackMap
19 :     in
20 :    
21 : cchiw 3602 fun err str = raise Fail (str)
22 : cchiw 3444 val empty = IMap.empty
23 : cchiw 3602 fun lookup k d = IMap.find (d, k)
24 :     fun insert (k, v) d = IMap.insert (d, k, v)
25 :     fun find (v, mapp) = (case IMap.find (mapp, v)
26 :     of NONE => raise Fail (concat["Outside Bound (", Int.toString v, ") "])
27 : cchiw 3444 | SOME s => s
28 : cchiw 3553 (* end *) )
29 : cchiw 3444
30 : cchiw 3553 (*mapIndex:E.mu * dict-> int
31 : cchiw 3444 * lookup
32 : cchiw 3553 *)
33 : cchiw 3602 fun mapIndex (e1, mapp) = (case e1
34 :     of E.V e => find (e, mapp)
35 : cchiw 3444 | E.C c => c
36 : cchiw 3553 (*end case*) )
37 : cchiw 3602
38 :     (* *************************** IL.LIT **************************** *)
39 : cchiw 3553 (* mkINt:int->Var*code list*)
40 : cchiw 3602 fun mkInt (avail, n) = let
41 :     val lhs = IL.Var.new ("Int", Ty.intTy)
42 :     val rhs = IL.LIT (Literal.Int (IntInf.fromInt n))
43 : cchiw 3542 in
44 : cchiw 3602 (avail,AvailRHS.addAssign avail (lhs, rhs))
45 : cchiw 3542 end
46 : cchiw 3602 fun mkReal (avail, n) = let
47 :     val lhs = IL.Var.new ("real", Ty.TensorTy [])
48 :     val rhs = IL.LIT (Literal.Int (IntInf.fromInt n) )
49 : cchiw 3542 in
50 : cchiw 3602 (avail,AvailRHS.addAssign avail (lhs, rhs))
51 : cchiw 3542 end
52 : cchiw 3602 (* *************************** IL.CONS **************************** *)
53 :     fun assgnCons (avail, pre, shape, args) = let
54 :     val ty = Ty.TensorTy shape
55 :     val lhs = IL.Var.new ("cons"^"_", ty)
56 :     val rhs = IL.CONS (ty, args)
57 : cchiw 3444 in
58 : cchiw 3602 (avail,AvailRHS.addAssign avail (lhs, rhs))
59 : cchiw 3444 end
60 :    
61 : cchiw 3602 (* *************************** IL.OP **************************** *)
62 :     fun assignOP (avail, opss, args, pre, ty) = let
63 :     val lhs = IL.Var.new (pre, ty)
64 :     val rhs = IL.OP (opss, args)
65 : cchiw 3444 in
66 : cchiw 3602 (avail,AvailRHS.addAssign avail (lhs, rhs))
67 : cchiw 3444 end
68 : cchiw 3602 (* *************************** Op.IndexTensor **************************** *)
69 : cchiw 3553 (*getTensorTy:E.params*E.tensor_id-> LowIL.Ty
70 : cchiw 3602 * Integer, or Generic Tensor
71 : cchiw 3553 *)
72 : cchiw 3602 fun getTensorTy (params, id) = case List.nth (params, id)
73 :     of E.TEN (3, [shape]) => Ty.iVecTy (shape) (*FIX HERE*)
74 :     | E.TEN (_, shape) => Ty.TensorTy shape
75 :     |_ => err "NONE Tensor Param"
76 :    
77 : cchiw 3553 (* indexTensor:dict*string*E.params*Var list*E.tensor_id*E.alpha
78 : cchiw 3444 * ->Var*code list
79 :     * Index Tensor at specific indices to give a scalar result
80 : cchiw 3553 *)
81 : cchiw 3602 fun indexTensor (avail, _, (params, args, id, [], ty) ) = (avail, List.nth (args, id))
82 :     | indexTensor (_, _, ( params, args, id, [_, _, _], Ty.TensorTy [_, _, _, _] ) ) = raise Fail "uneven"
83 :     | indexTensor (avail, mapp, ( params, args, id, ix, ty) ) = let
84 :     val ixx = (List.map (fn (e1) => mapIndex (e1, mapp) ) ix)
85 :     val argTy = getTensorTy (params, id)
86 :     val opp = Op.IndexTensor (id, ixx, argTy)
87 :     val nU = List.nth (args, id)
88 :     val name = String.concat["Indx_",String.concat (List.map Int.toString ixx), "_"]
89 : cchiw 3444 in
90 : cchiw 3602 assignOP (avail, opp, [nU], name, ty)
91 : cchiw 3444 end
92 : cchiw 3602 (* *************************** Op._ Shortcuts **************************** *)
93 :     fun mkAddSca (avail, args) = assignOP (avail, Op.addSca, args, "addSca", Ty.TensorTy [])
94 :     fun mkAddInt (avail, args) = assignOP (avail, Op.addSca, args, "addInt", Ty.intTy)
95 :     fun mkAddPtr (avail, args, ty) = assignOP (avail, Op.addSca, args, "addPtr", ty)
96 :     fun mkAddVec (avail, vecIX, args) = assignOP (avail, Op.addVec vecIX, args, "addV", Ty.TensorTy ([vecIX]) )
97 :     fun mkSubSca (avail, args) = assignOP (avail, Op.subSca, args, "subSca", Ty.TensorTy [])
98 :     fun mkProdSca (avail, args) = assignOP (avail, Op.prodSca, args, "prodSca", Ty.TensorTy [])
99 :     fun mkProdInt (avail, args) = assignOP (avail, Op.prodSca, args, "prodInt", Ty.intTy)
100 :     fun mkProdVec (avail, vecIX, args) = assignOP (avail, Op.prodVec vecIX, args, "prodV", Ty.TensorTy ([vecIX]) )
101 :     fun mkDivSca (avail, args) = assignOP (avail, Op.divSca, args, "divSca", Ty.TensorTy [])
102 :     fun mkSumVec (avail, vecIX, args) = assignOP (avail, Op.sumVec vecIX, args, "sumVec", Ty.TensorTy [])
103 :     (* *************************** Op. Other **************************** *)
104 :     fun mkDotVec (avail, vecIX, args) = let
105 :     val (avail, vD) = mkProdVec (avail, vecIX, args)
106 :     in mkSumVec (avail, vecIX, [vD]) end
107 :     fun intToReal (avail, n) = let
108 :     val (avail, vC) = mkReal (avail, n)
109 :     in assignOP (avail, Op.IntToReal, [vC], "cast", Ty.TensorTy []) end
110 :     fun mkPowInt ((avail, nU), nn) = let
111 :     fun pow (1, avail) = (avail, nU)
112 :     | pow (2, avail) = let
113 :     val opp = Op.prodSca
114 : cchiw 3553 val name = String.concat["_Pow2_"]
115 : cchiw 3602 in assignOP (avail, opp, [nU, nU], name, Ty.intTy) end
116 :     | pow (n, avail) = let
117 : cchiw 3553 fun half m = let
118 : cchiw 3602 val (avail, vB) = pow (m div 2, avail)
119 :     val opp = Op.prodSca
120 :     val name = String.concat["_Pow", Int.toString (m), "_"]
121 :     in assignOP (avail, opp, [vB, vB], name, Ty.intTy) end
122 :     in if ((n mod 2) = 0)
123 : cchiw 3444 then half n
124 :     else let
125 : cchiw 3602 val (avail, vC) = half (n-1)
126 :     val opp = Op.prodSca
127 :     val name = String.concat["_Pow", Int.toString (n), "_"]
128 :     in assignOP (avail, opp, [nU, vC], name, Ty.intTy) end
129 : cchiw 3444 end
130 :    
131 :     in
132 : cchiw 3602 pow (nn, avail)
133 : cchiw 3444 end
134 : cchiw 3602
135 :     fun mkOp1 (E.PowInt n, e) = mkPowInt (e, n)
136 :     | mkOp1 (t, e) = let
137 :     fun mkSingle (opp, name, (avail, nU)) = assignOP (avail, opp, [nU], name, Ty.TensorTy [])
138 : cchiw 3553 val opp = (case t
139 : cchiw 3602 of E.Cosine => Op.Cosine
140 :     | E.ArcCosine => Op.ArcCosine
141 :     | E.Sine => Op.Sine
142 :     | E.ArcSine => Op.ArcSine
143 :     | E.Tangent => Op.Tangent
144 :     | E.ArcTangent => Op.ArcTangent
145 :     | E.Sqrt => Op.Sqrt
146 :     | E.Exp => Op.Exp
147 : cchiw 3553 (*end case*) )
148 : cchiw 3602 in mkSingle (opp, "_op1_", e) end
149 :    
150 : cchiw 3553 (*mkMultiple:string*Var list*LowOps.Op *ListIL.Ty -> Var*code list
151 : cchiw 3444 *apply rator between each items on list1
152 : cchiw 3553 *)
153 : cchiw 3602 fun mkMultiple (availM, list1, rator, ty) = let
154 :     fun add (avail, [], _) = err"no element in mkMultiple"
155 :     | add (avail, [e1], _) = (avail, e1)
156 :     | add (avail, [e1, e2], _) = assignOP (avail, rator, [e1, e2], "mult", ty)
157 :     | add (avail, e1::e2::es, count) = let
158 :     val (avail, vA) = assignOP (avail, rator, [e1, e2], String.concat["mult", Int.toString count], ty)
159 :     in add (avail, vA::es, count-1)
160 : cchiw 3444 end
161 :     in
162 : cchiw 3602 add (availM, list1, List.length list1)
163 : cchiw 3444 end
164 : cchiw 3602 (* *************************** Op. Greek **************************** *)
165 : cchiw 3553 (* deltaToInt:dict*E.mu*E.mu->int
166 : cchiw 3444 * delta function
167 : cchiw 3553 *)
168 : cchiw 3602 fun deltaToInt (mapp, a, b) = let
169 :     val i = mapIndex (a, mapp)
170 :     val j = mapIndex (b, mapp)
171 : cchiw 3553 in if (i = j) then 1 else 0 end
172 : cchiw 3602 fun evalDelta (avail, mapp, a, b) = intToReal (avail, deltaToInt (mapp, a, b))
173 : cchiw 3553 (*eval Epsilon-2d*)
174 : cchiw 3602 fun evalEps2 (avail, mapp, a, b) = let
175 :     val i = mapIndex (E.V a, mapp)
176 :     val j = mapIndex (E.V b, mapp)
177 :     in if (i = j) then intToReal (avail, 0)
178 : cchiw 3444 else
179 : cchiw 3602 if (j>i) then intToReal (avail, 1)
180 :     else intToReal (avail, ~1)
181 : cchiw 3444 end
182 : cchiw 3553 (*eval Epsilon-3d*)
183 : cchiw 3602 fun evalEps3 (avail, mapp, a, b, c) = let
184 :     val i = mapIndex (E.V a, mapp)
185 :     val j = mapIndex (E.V b, mapp)
186 :     val k = mapIndex (E.V c, mapp)
187 : cchiw 3444 in
188 : cchiw 3602 if (i = j orelse j = k orelse i = k) then intToReal (avail, 0)
189 : cchiw 3553 else if (j>i)
190 : cchiw 3602 then if (j>k andalso k>i) then intToReal (avail, ~1) else intToReal (avail, 1)
191 :     else if (i>k andalso k>j) then intToReal (avail, 1) else intToReal (avail, ~1)
192 : cchiw 3444 end
193 : cchiw 3602 fun evalG (avail, mapp, b) = (case b
194 :     of E.Epsilon (i, j, k) => evalEps3 (avail, mapp, i, j, k)
195 :     | E.Eps2 (i, j) => evalEps2 (avail, mapp, i, j)
196 :     | E.Delta (i, j) => evalDelta (avail, mapp, i, j)
197 : cchiw 3553 (*end case*) )
198 : cchiw 3444 end
199 :    
200 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0