Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/float-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/float-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5570 - (view) (download)

1 : jhr 3556 (* float-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 : jhr 3555
9 : jhr 3556 structure FloatEin : sig
10 :    
11 : jhr 3582 val transform : MidIR.var * Ein.ein * MidIR.var list -> MidIR.assign list
12 : jhr 3555
13 :     end = struct
14 :    
15 :     structure IR = MidIR
16 : jhr 3556 structure V = IR.Var
17 :     structure Ty = MidTypes
18 : jhr 3555 structure E = Ein
19 :    
20 : jhr 4229 fun cut (name, origProbe, params, index, sx, argsOrig, avail, newvx) = let
21 : jhr 4317 (* clean and rewrite current body *)
22 : jhr 4229 (* DEBUG val _ = print(String.concat["\n\n impact cut"]) *)
23 : jhr 4317 val (tshape1, sizes1, body1) = CleanIndex.clean (origProbe, index, sx)
24 :     val id = length params
25 :     val Rparams = params@[E.TEN(true, sizes1)]
26 :     val y = V.new (concat[name, "_l_", Int.toString id], Ty.tensorTy sizes1)
27 :     val IR.EINAPP(ein, args) = CleanParams.clean (body1, Rparams, sizes1, argsOrig@[y])
28 :     (* shift indices in probe body from constant to variable *)
29 :     val Ein.EIN{
30 :     body = E.Probe(E.Conv(V, [c1], h, dx), pos),
31 :     index = index0,
32 :     params = params0
33 :     } = ein
34 :     (* only called with vector fields*)
35 :     val E.IMG(dim,[i]) = List.nth(params0,V)
36 :     val index1 = index0@[i]
37 :     val unshiftedBody = E.Probe(E.Conv(V, [E.V newvx], h, dx), pos)
38 :     (* clean to get body indices in order *)
39 :     val (_ , sizes2, body2) = CleanIndex.clean (unshiftedBody, index1, [])
40 :     val ein2 = E.EIN{params = params0, index = sizes2, body = body2}
41 : jhr 4229 (* DEBUG val _ = print(String.concat["\n\n ein2:",EinPP.toString(ein2)]) *)
42 : jhr 4317 val lhs = AvailRHS.addAssign (avail, "L", Ty.tensorTy sizes2, IR.EINAPP(ein2, args))
43 :     val Rargs = argsOrig @ [lhs]
44 :     (*Probe that tensor at a constant position c1*)
45 :     val nx = List.mapi (fn (i, _) => E.V i) dx
46 :     val Re = E.Tensor(id, c1 :: tshape1)
47 :     val Rparams = params @ [E.TEN(true, sizes2)]
48 :     in
49 :     (Re, Rparams, Rargs)
50 :     end
51 : jhr 3555
52 : jhr 4229 (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
53 :     * lifts expression and returns replacement tensor
54 : jhr 3555 * cleans the index and params of subexpression
55 : jhr 4229 * creates new param and replacement tensor for the original ein_exp
56 : jhr 3555 *)
57 :     fun lift (name, e, params, index, sx, args, avail) = let
58 : jhr 4317 val (tshape, sizes, body) = CleanIndex.clean(e, index, sx)
59 :     val id = length params
60 :     val Re = E.Tensor(id, tshape)
61 :     val einapp = CleanParams.clean (body, params, sizes, args)
62 :     val lhs = AvailRHS.addAssign (
63 :     avail,
64 :     concat[name, "_l_", Int.toString id], Ty.tensorTy sizes,
65 :     CleanParams.clean (body, params, sizes, args))
66 :     in
67 :     (Re, params @ [E.TEN(true, sizes)], args @ [lhs])
68 :     end
69 : jhr 3555
70 :     fun isOp e = (case e
71 : jhr 4317 of E.Op1 _ => true
72 :     | E.Op2 _ => true
73 : cchiw 5241 | E.Op3 _ => true
74 : jhr 4317 | E.Opn _ => true
75 :     | E.Sum _ => true
76 :     | E.Probe _ => true
77 :     | _ => false
78 :     (* end case *))
79 : jhr 3555
80 :     fun transform (y, ein as Ein.EIN{body=E.Probe _, ...}, args) =
81 : jhr 4317 [(y, IR.EINAPP(ein, args))]
82 : jhr 3555 | transform (y, ein as Ein.EIN{body=E.Sum(_, E.Probe _), ...}, args) =
83 : jhr 4317 [(y, IR.EINAPP(ein, args))]
84 : jhr 3555 | transform (y, Ein.EIN{params, index, body}, args) = let
85 : jhr 4317 val avail = AvailRHS.new()
86 :     fun filterOps (es, params, args, index, sx) = let
87 :     fun filter ([], es', params, args) = (rev es', params, args)
88 :     | filter (e::es, es', params, args) = if isOp e
89 :     then let
90 :     val (e', params', args') = lift("op1_e3", e, params, index, sx, args, avail)
91 :     in
92 :     filter (es, e'::es', params', args')
93 :     end
94 :     else filter (es, e::es', params, args)
95 :     in
96 :     filter (es, [], params, args)
97 :     end
98 :     fun rewrite (sx, exp, params, args) = (case exp
99 :     of E.Probe(E.Conv(_, [E.C _], _, []), _) =>
100 :     cut ("cut", exp, params, index, sx, args, avail, 0)
101 :     | E.Probe(E.Conv(_, [E.C _ ], _, [E.V 0]), _) =>
102 :     cut ("cut", exp, params, index, sx, args, avail, 1)
103 :     | E.Probe(E.Conv(_, [E.C _ ], _, [E.V 0, E.V 1]), _) =>
104 :     cut ("cut", exp, params, index, sx, args, avail, 2)
105 :     | E.Probe(E.Conv(_, [E.C _ ], _, [E.V 0, E.V 1, E.V 2]), _) =>
106 :     cut ("cut", exp, params, index, sx, args, avail, 3)
107 :     | E.Probe _ => lift ("probe", exp, params, index, sx, args, avail)
108 : jhr 5570 | E.OField _ => lift ("probe", exp, params, index, sx, args, avail)
109 : jhr 4317 | E.Sum(_, E.Probe _) => lift ("probe", exp, params, index, sx, args, avail)
110 :     | E.Op1(op1, e1) => let
111 :     val (e1', params', args') = rewrite (sx, e1, params, args)
112 :     val ([e1], params', args') = filterOps ([e1'], params', args', index, sx)
113 :     in
114 :     (E.Op1(op1, e1), params', args')
115 :     end
116 :     | E.Op2(op2, e1, e2) => let
117 :     val (e1', params', args') = rewrite (sx, e1, params, args)
118 :     val (e2', params', args') = rewrite (sx, e2, params', args')
119 :     val ([e1', e2'], params', args') =
120 :     filterOps ([e1', e2'], params', args', index, sx)
121 :     in
122 :     (E.Op2(op2, e1', e2'), params', args')
123 :     end
124 : cchiw 5241 | E.Op3(op3, e1, e2, e3) => let
125 :     val (e1', params', args') = rewrite (sx, e1, params, args)
126 :     val (e2', params', args') = rewrite (sx, e2, params', args')
127 :     val (e3', params', args') = rewrite (sx, e3, params', args')
128 :     val ([e1', e2', e3'], params', args') =
129 :     filterOps ([e1', e2', e3'], params', args', index, sx)
130 :     in
131 :     (E.Op3(op3, e1', e2', e3'), params', args')
132 :     end
133 : jhr 4317 | E.Opn(opn, es) => let
134 :     fun iter ([], es, params, args) = (List.rev es, params, args)
135 :     | iter (e::es, es', params, args) = let
136 :     val (e', params', args') = rewrite (sx, e, params, args)
137 :     in
138 :     iter (es, e'::es', params', args')
139 :     end
140 :     val (es, params, args) = iter (es, [], params, args)
141 :     val (es, params, args) = filterOps (es, params, args, index, sx)
142 :     in
143 :     (E.Opn(opn, es), params, args)
144 :     end
145 :     | E.Sum(sx1, e) => let
146 :     val (e', params', args') = rewrite (sx1@sx, e, params, args)
147 :     in
148 :     (E.Sum(sx1, e'), params', args')
149 :     end
150 :     | _ => (exp, params, args)
151 :     (* end case *))
152 :     val (body', params', args') = rewrite ([], body, params, args)
153 : jhr 3576 val einapp = CleanParams.clean (body', params', index, args')
154 : jhr 4317 in
155 :     List.rev ((y, einapp) :: AvailRHS.getAssignments avail)
156 :     end
157 : jhr 3555
158 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0