Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-scalar.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-scalar.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5570 - (view) (download)

1 : jhr 3646 (* ein-to-scalar.sml
2 :     *
3 :     * Generate LowIR scalar computations that implement Ein expressions.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2016 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure EinToScalar : sig
12 :    
13 : jhr 3728 (* expand a scalar-valued Ein operator application to LowIR code; return the LowIR
14 :     * vaiable that holds the result of the application (the assignments will be added
15 :     * to avail).
16 :     *)
17 :     val expand : {
18 : jhr 4317 avail : AvailRHS.t, (* the generated LowIR assignments *)
19 :     mapp : int IntRedBlackMap.map, (* mapping from deBruijn indices to argument IDs *)
20 :     body : Ein.ein_exp, (* the EIN operator body *)
21 :     lowArgs : LowIR.var list (* corresponding LowIR arguments *)
22 :     } -> LowIR.var
23 : jhr 3646
24 :     end = struct
25 :    
26 :     structure IR = LowIR
27 :     structure Ty = LowTypes
28 :     structure Op = LowOps
29 :     structure Var = LowIR.Var
30 :     structure E = Ein
31 : jhr 3648 structure Mk = MkLowIR
32 : jhr 3646 structure IMap = IntRedBlackMap
33 :    
34 : jhr 3653 fun mapIndex (mapp, id) = (case IMap.find(mapp, id)
35 : jhr 4317 of SOME x => x
36 :     | NONE => raise Fail(concat["mapIndex(_, V ", Int.toString id, "): out of bounds"])
37 :     (* end case *))
38 : jhr 3649
39 : jhr 5570 fun unwrapPoly (mapp, E.Poly(t, n, dx)) = let
40 :     val zero = E.Const 0
41 :     val one = E.Const 1
42 :     val ec = E.Const n
43 :     val ecc = E.Const(n-1)
44 :     val eccc = E.Const(n-2)
45 :     val ndel = length dx
46 :     fun mkProd es = E.Opn(E.Prod, es)
47 :     (* Take a certain number of derivatives*)
48 :     fun getDel () =
49 :     if (n < ndel) then zero
50 :     else (case (ndel, n)
51 :     of (0, _) => mkProd (List.tabulate (n, fn _ => t))
52 :     | (1, 1) => one
53 :     | (1, _) => mkProd (ec :: List.tabulate (n-1, fn _ => t))
54 :     | (2, 2) => ec
55 :     | (2, _) => mkProd (ec :: ecc :: List.tabulate (n-2, fn _ => t))
56 :     | (3, 3) => mkProd ([ec, ecc])
57 :     | (3, _) => mkProd (ec :: ecc :: eccc :: List.tabulate (n-3, fn _ => t))
58 :     | _ => raise Fail"add more cases for derivative"
59 :     (* end case *))
60 :     val E.Tensor(_, shape) = t
61 :     in
62 :     case (shape, dx)
63 :     of ([],_) => getDel ()
64 :     | ([E.C c], _) => let
65 :     fun iter [] = getDel ()
66 :     | iter (vx::es) = if (Mk.lookupMu(mapp, vx) = c) then iter es else zero
67 :     in
68 :     iter dx
69 :     end
70 :     (* end case *)
71 :     end
72 :    
73 : jhr 3746 fun expand {avail, mapp, body, lowArgs} = let
74 : jhr 4317 fun gen (mapp, body) = let
75 :     (*********sumexpression ********)
76 :     fun tb n = List.tabulate (n, fn e => e)
77 :     fun sumCheck (mapp, (v, lb, ub) :: sumx, e) = let
78 :     fun sumloop mapp = gen (mapp, e)
79 :     fun sumI1 (left, (v, [i], lb1), [], rest) = let
80 :     val mapp = IMap.insert (left, v, lb1+i)
81 :     val vD = gen (mapp, e)
82 :     in
83 :     rest@[vD]
84 :     end
85 :     | sumI1 (left, (v, i::es, lb1), [], rest) = let
86 :     val mapp = IMap.insert (left, v, i+lb1)
87 :     val vD = gen (mapp, e)
88 :     in
89 :     sumI1 (mapp, (v, es, lb1), [], rest@[vD])
90 :     end
91 :     | sumI1 (left, (v, [i], lb1), (a, lb2, ub2) ::sx, rest) =
92 :     sumI1 (IMap.insert (left, v, lb1+i), (a, tb (ub2-lb2+1), lb2), sx, rest)
93 :     | sumI1 (left, (v, s::es, lb1), (a, lb2, ub2) ::sx, rest) = let
94 :     val mapp = IMap.insert (left, v, s+lb1)
95 :     val xx = tb (ub2-lb2+1)
96 :     val rest' = sumI1 (mapp, (a, xx, lb2), sx, rest)
97 :     in
98 :     sumI1 (mapp, (v, es, lb1), (a, lb2, ub2) ::sx, rest')
99 :     end
100 :     | sumI1 _ = raise Fail "None Variable-index in summation"
101 :     in
102 : jhr 3745 sumI1 (mapp, (v, tb (ub-lb+1), lb), sumx, [])
103 : jhr 4317 end
104 :     in
105 :     case body
106 :     of E.Value v => Mk.intToRealLit (avail, mapIndex (mapp, v))
107 :     | E.Const c => Mk.intToRealLit (avail, c)
108 :     | E.Delta(i, j) => Mk.delta (avail, mapp, i, j)
109 :     | E.Epsilon(i, j, k) => Mk.epsilon3 (avail, mapp, i, j, k)
110 :     | E.Eps2(i, j) => Mk.epsilon2 (avail, mapp, i, j)
111 :     | E.Tensor(id, ix) => Mk.tensorIndex (avail, mapp, List.nth(lowArgs, id), ix)
112 : jhr 5570 | E.Zero _ => Mk.intToRealLit (avail, 0)
113 :     | E.Poly _ => gen(mapp, unwrapPoly(mapp, body))
114 : jhr 4317 | E.Op1(op1, e1) => let
115 :     val arg = gen (mapp, e1)
116 :     in
117 :     case op1
118 :     of E.Neg => Mk.realNeg (avail, arg)
119 :     | E.Sqrt => Mk.realSqrt (avail, arg)
120 :     | E.Cosine => Mk.realCos (avail, arg)
121 :     | E.ArcCosine => Mk.realArcCos (avail, arg)
122 :     | E.Sine => Mk.realSin (avail, arg)
123 :     | E.ArcSine => Mk.realArcSin (avail, arg)
124 :     | E.Tangent => Mk.realTan (avail, arg)
125 :     | E.ArcTangent => Mk.realArcTan (avail, arg)
126 : cchiw 3969 | E.Exp => Mk.realExp (avail, arg)
127 : jhr 4317 | E.PowInt n => Mk.intPow (avail, arg, n)
128 : cchiw 5006 | E.Abs => Mk.realAbs(avail, arg)
129 : jhr 5296 | E.Sgn => Mk.realSign(avail, arg)
130 : jhr 5007 (* end case *)
131 : jhr 4317 end
132 :     | E.Op2(E.Sub, e1, e2) => Mk.realSub (avail, gen (mapp, e1), gen (mapp, e2))
133 : jhr 5258 | E.Op3(E.Clamp, e1, e2, e3) =>
134 :     Mk.realClamp(avail, gen(mapp, e1), gen(mapp, e2), gen(mapp, e3))
135 : jhr 4317 | E.Opn(E.Add, es) =>
136 :     Mk.reduce (avail, Mk.realAdd, List.map (fn e => gen(mapp, e)) es)
137 :     | E.Opn(E.Prod, es) =>
138 :     Mk.reduce (avail, Mk.realMul, List.map (fn e => gen(mapp, e)) es)
139 :     | E.Op2(E.Div, e1 as E.Tensor (_, [_]), e2 as E.Tensor (_, [])) =>
140 :     gen (mapp, E.Opn(E.Prod, [E.Op2 (E.Div, E.Const 1, e2), e1]))
141 :     | E.Op2(E.Div, e1, e2) => Mk.realDiv (avail, gen (mapp, e1), gen (mapp, e2))
142 :     | E.Sum(sx, E.Opn(E.Prod, (img as E.Img _) :: (kargs as (E.Krn _ :: _)))) =>
143 :     FieldToLow.expand {
144 :     avail = avail, mapp = mapp,
145 :     sx = sx, img = img, krnargs = kargs,
146 :     args = lowArgs
147 :     }
148 :     | E.Sum(sumx, e) =>
149 :     Mk.reduce (avail, Mk.realAdd, sumCheck (mapp, sumx, e))
150 :     | E.Probe(E.Epsilon e1, e2) => gen(mapp,E.Epsilon e1)
151 :     | E.Probe(E.Eps2 e1, e2) => gen(mapp,E.Eps2 e1)
152 :     | E.Probe(E.Const e1, e2) => gen(mapp, E.Const e1)
153 :     | E.Probe(E.Delta e1, e2) => gen(mapp, E.Delta e1)
154 :     | E.Probe e => raise Fail("probe ein-exp: " ^ EinPP.expToString body)
155 :     | E.Field _ => raise Fail("field should have been replaced: " ^ EinPP.expToString body)
156 :     | _ => raise Fail("unsupported ein-exp: " ^ EinPP.expToString body)
157 :     (*end case*)
158 :     end
159 :     in
160 :     gen (mapp, body)
161 :     end
162 : jhr 3646
163 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0