Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-opt/apply.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-opt/apply.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5574 - (view) (download)

1 : jhr 3515 (* apply.sml
2 :     *
3 :     * Apply EIN operator arguments to EIN operator.
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure Apply : sig
12 :    
13 : jhr 5570 val apply : Ein.ein * int * Ein.ein * HighIR.var list * HighIR.var list -> Ein.ein option
14 : jhr 3521
15 : jhr 3515 end = struct
16 :    
17 :     structure E = Ein
18 :    
19 : jhr 3521 structure IMap = IntRedBlackMap
20 : jhr 3515
21 : jhr 3521 fun mapId (i, dict, shift) = (case IMap.find(dict, i)
22 : jhr 4317 of NONE => i + shift
23 : jhr 3521 | SOME j => j
24 : jhr 4317 (* end case *))
25 : jhr 3515
26 : jhr 3521 fun mapIndex (ix, dict, shift) = (case IMap.find(dict, ix)
27 :     of NONE => E.V(ix + shift)
28 :     | SOME j => j
29 :     (* end case *))
30 : jhr 3515
31 : jhr 3521 fun mapId2 (i, dict, shift) = (case IMap.find(dict, i)
32 :     of NONE => (
33 : jhr 4317 print(concat["Error: ", Int.toString i, " is out of range\n"]);
34 :     i+shift)
35 : jhr 3521 | SOME j => j
36 :     (* end case *))
37 : jhr 3515
38 : jhr 5570 fun rewriteSubst (e, subId, mx, paramShift, sumShift, newArgs, done) = let
39 : jhr 4317 fun insertIndex ([], _, dict, shift) = (dict, shift)
40 :     | insertIndex (e::es, n, dict, _) = let
41 :     val shift = (case e of E.V ix => ix - n | E.C i => i - n)
42 :     in
43 :     insertIndex(es, n+1, IMap.insert(dict, n, e), shift)
44 :     end
45 :     val (subMu, shift) = insertIndex(mx, 0, IMap.empty, 0)
46 :     val shift' = Int.max(sumShift, shift)
47 : cchiw 5006 val insideComp = ref(false)
48 : jhr 5007 fun mapMu (E.V i) = if (!insideComp)
49 : jhr 5215 then E.V i
50 :     else mapIndex(i, subMu, shift')
51 : jhr 5007 | mapMu c = c
52 : jhr 3521 fun mapAlpha mx = List.map mapMu mx
53 :     fun mapSingle i = let
54 : jhr 4317 val E.V v = mapIndex(i, subMu, shift')
55 : jhr 3521 in
56 :     v
57 :     end
58 : cchiw 3978 fun mapSum l = List.map (fn (a, b, c) => (mapSingle a, b, c)) l
59 : jhr 5570 fun mapParam id = let
60 :     val vA = List.nth(newArgs, id)
61 :     fun iter ([], _) = mapId2(id, subId, 0)
62 :     | iter (e1::es, n) = if (HighIR.Var.same(e1, vA)) then n else iter(es, n+1)
63 :     in
64 :     iter (done@newArgs, 0)
65 :     end
66 : jhr 3521 fun apply e = (case e
67 : jhr 4317 of E.Const _ => e
68 :     | E.ConstR _ => e
69 :     | E.Tensor(id, mx) => E.Tensor(mapParam id, mapAlpha mx)
70 : cchiw 4555 | E.Zero(mx) => E.Zero(mapAlpha mx)
71 : jhr 4317 | E.Delta(i, j) => E.Delta(mapMu i, mapMu j)
72 :     | E.Epsilon(i, j, k) => E.Epsilon(mapMu i, mapMu j, mapMu k)
73 :     | E.Eps2(i, j) => E.Eps2(mapMu i,mapMu j)
74 :     | E.Field(id, mx) => E.Field(mapParam id, mapAlpha mx)
75 :     | E.Lift e1 => E.Lift(apply e1)
76 :     | E.Conv (v, mx, h, ux) => E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)
77 :     | E.Partial mx => E.Partial (mapAlpha mx)
78 :     | E.Apply(e1, e2) => E.Apply(apply e1, apply e2)
79 :     | E.Probe(f, pos) => E.Probe(apply f, apply pos)
80 :     | E.Value _ => raise Fail "expression before expand"
81 :     | E.Img _ => raise Fail "expression before expand"
82 :     | E.Krn _ => raise Fail "expression before expand"
83 : jhr 5570 | E.OField(E.CFExp es, e2,dx) => let
84 : jhr 5574 val es = List.map (fn (id, inputTy) => (mapParam id, inputTy)) es
85 : jhr 5570 val e2 = apply e2
86 :     val dx = apply dx
87 :     in
88 :     E.OField(E.CFExp es, e2,dx)
89 :     end
90 : jhr 4317 | E.Sum(c, esum) => E.Sum(mapSum c, apply esum)
91 :     | E.Op1(op1, e1) => E.Op1(op1, apply e1)
92 :     | E.Op2(op2, e1, e2) => E.Op2(op2, apply e1, apply e2)
93 : cchiw 5241 | E.Op3(op3, e1, e2, e3) => E.Op3(op3, apply e1, apply e2, apply e3)
94 : jhr 4317 | E.Opn(opn, e1) => E.Opn(opn, List.map apply e1)
95 :     (* end case *))
96 : jhr 3521 in
97 : jhr 4317 apply e
98 : jhr 3521 end
99 : jhr 3515
100 : jhr 3521 (* params subst *)
101 :     fun rewriteParams (params, params2, place) = let
102 : jhr 4317 val beg = List.take(params, place)
103 :     val next = List.drop(params, place+1)
104 :     val params' = beg@params2@next
105 :     val n= length params
106 :     val n2 = length params2
107 :     val nbeg = length beg
108 :     val nnext = length next
109 :     fun createDict (0, shift1, shift2, dict) = dict
110 :     | createDict (n, shift1, shift2, dict) =
111 :     createDict (n-1, shift1, shift2, IMap.insert (dict, n+shift1, n+shift2))
112 :     val origId = createDict (nnext, place, place+n2-1, IMap.empty)
113 :     val subId = createDict (n2, ~1, place-1, IMap.empty)
114 :     in
115 :     (params', origId, subId, nbeg)
116 :     end
117 : jhr 3515
118 : jhr 3535 (* Looks for params id that match substitution *)
119 : jhr 5570 fun apply (e1 as E.EIN{params, index, body}, place, e2, newArgs, done) = let
120 : jhr 4317 val E.EIN{params=params2, index=index2, body=body2} = e2
121 :     val changed = ref false
122 :     val (params', origId, substId, paramShift) = rewriteParams(params, params2, place)
123 : jhr 3521 val sumIndex = ref(length index)
124 :     fun rewrite (id, mx, e) = let
125 :     val x = !sumIndex
126 : jhr 5007 in
127 : jhr 3521 if (id = place)
128 : jhr 4317 then if (length mx = length index2)
129 :     then (
130 :     changed := true;
131 : jhr 5570 rewriteSubst (body2, substId, mx, paramShift, x, newArgs, done))
132 : jhr 4317 else raise Fail "argument/parameter mismatch"
133 :     else (case e
134 :     of E.Tensor(id, mx) => E.Tensor(mapId(id, origId, 0), mx)
135 :     | E.Field(id, mx) => E.Field(mapId(id, origId, 0), mx)
136 :     | _ => raise Fail "term to be replaced is not a Tensor or Fields"
137 :     (* end case *))
138 : jhr 3521 end
139 : jhr 4317 fun sumI e = let val (v,_,_) = List.last e in v end
140 :     fun apply b = (case b
141 :     of E.Tensor(id, mx) => rewrite (id, mx, b)
142 :     | E.Field(id, mx) => rewrite (id, mx, b)
143 : cchiw 4555 | E.Zero(mx) => b
144 : jhr 4317 | E.Lift e1 => E.Lift(apply e1)
145 :     | E.Conv(v, mx, h, ux) => E.Conv(mapId(v, origId, 0), mx, mapId(h, origId, 0), ux)
146 :     | E.Apply(e1, e2) => E.Apply(apply e1, apply e2)
147 :     | E.Probe(f, pos) => E.Probe(apply f, apply pos)
148 :     | E.Value _ => raise Fail "expression before expand"
149 :     | E.Img _ => raise Fail "expression before expand"
150 :     | E.Krn _ => raise Fail "expression before expand"
151 : jhr 5570 | E.OField(E.CFExp es, e2, E.Partial alpha) => let
152 :     val ps = List.map (fn (id, inputTy) => (mapId(id, origId, 0), inputTy)) es
153 :     in
154 :     E.OField(E.CFExp ps, apply e2, E.Partial alpha)
155 :     end
156 :     | E.Poly _ => raise Fail "expression before expand"
157 : jhr 4821 | E.Sum(indices, esum) => let
158 :     val (ix, _, _) = List.last indices
159 :     in
160 :     sumIndex := ix;
161 :     E.Sum(indices, apply esum)
162 :     end
163 : jhr 4317 | E.Op1(op1, e1) => E.Op1(op1, apply e1)
164 :     | E.Op2(op2, e1, e2) => E.Op2(op2, apply e1, apply e2)
165 : cchiw 5241 | E.Op3(op3, e1, e2, e3) => E.Op3(op3, apply e1, apply e2, apply e3)
166 : jhr 4317 | E.Opn(opn, es) => E.Opn(opn, List.map apply es)
167 :     | _ => b
168 :     (* end case *))
169 :     val body'' = apply body
170 :     in
171 : jhr 4821 if (! changed)
172 :     then SOME(E.EIN{params=params', index=index, body=body''})
173 :     else NONE
174 : jhr 4317 end
175 : jhr 3515
176 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0