Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/float-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/float-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3555 - (view) (download)

1 : jhr 3555 structure Float : sig
2 :    
3 :     val transform : MidIR.var * Ein.ein * MidIR.var list -> MidIR.assignment list
4 :    
5 :     end = struct
6 :    
7 :     structure IR = MidIR
8 :     structure E = Ein
9 :    
10 :     fun cut (name, origProbe, params, index, sx, argsOrig, avail, newvx) = let
11 :     (*clean and rewrite current body*)
12 :     val (tshape, sizes, body) = cleanIndex (origProbe, index, sx)
13 :     val id = length params
14 :     val Rparams = params@[E.TEN(true, sizes)]
15 :     val M = DstV.new (concat[name, "_l_", itos id], DstTy.TensorTy sizes)
16 :     val (y, DstIL.EINAPP(ein, args)) = cleanParams(M, body, Rparams, sizes, argsOrig@[M])
17 :     (*shift indices in probe body from constant to variable*)
18 :     val E.Probe(E.Conv(V, [c1], h, dx), pos) = Ein.body ein
19 :     val index0 = Ein.index ein
20 :     (* FIXME: this code is specialized to 3D *)
21 :     val index1 = index0@[3]
22 :     val body1_unshifted = E.Probe(E.Conv(V, [E.V newvx], h, dx), pos)
23 :     (* clean to get body indices in order *)
24 :     val (_ , _, body1) = cleanIndex.cleanIndex (body1_unshifted, index1, [])
25 :     val lhs1 = DstV.new ("L", DstTy.TensorTy index1)
26 :     val ein1 = mkEin(Ein.params ein,index1,body1)
27 :     val lhs2 = AvailRHS.addAssign avail (lhs1, mkEinApp(ein1, args))
28 :     val Rargs = argsOrig @ [lhs2]
29 :     (*Probe that tensor at a constant position c1*)
30 :     val nx = List.mapi (fn (i, _) => E.V i) dx
31 :     val Re = E.Tensor(id, c1 :: tshape)
32 :     val Rparams = params @ [E.TEN(true, index1)]
33 :     in
34 :     (Re, Rparams, Rargs)
35 :     end
36 :    
37 :     (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
38 :     *lifts expression and returns replacement tensor
39 :     * cleans the index and params of subexpression
40 :     *creates new param and replacement tensor for the original ein_exp
41 :     *)
42 :     fun lift (name, e, params, index, sx, args, avail) = let
43 :     val (tshape, sizes, body) = cleanIndex(e, index, sx)
44 :     val id = length params
45 :     val Rparams = params @ [E.TEN(true, sizes)]
46 :     val Re = E.Tensor(id, tshape)
47 :     val M = DstV.new (concat[name, "_l_", itos id], DstTy.TensorTy sizes)
48 :     val (_, einapp) = cleanParams (M, body, Rparams, sizes, args @ [M])
49 :     val var = AvailRHS.add avail (M, einapp)
50 :     val Rargs = argsOrig @ [var]
51 :     in
52 :     (Re, Rparams, Rargs)
53 :     end
54 :    
55 :     fun isOp e = (case e
56 :     of E.Op1 _ => true
57 :     | E.Op2 _ => true
58 :     | E.Opn _ => true
59 :     | E.Sum _ => true
60 :     | E.Probe _ => true
61 :     | _ => false
62 :     (* end case *))
63 :    
64 :     fun transform (y, ein as Ein.EIN{body=E.Probe _, ...}, args) =
65 :     [IR.ASSGN(y, IR.EINAPP(ein, args))]
66 :     | transform (y, ein as Ein.EIN{body=E.Sum(_, E.Probe _), ...}, args) =
67 :     [IR.ASSGN(y, IR.EINAPP(ein, args))]
68 :     | transform (y, Ein.EIN{params, index, body}, args) = let
69 :     val avail = AvailRHS.new()
70 :     fun filterOps (es, params, args, index, sx) = let
71 :     fun filter ([], es', params, args) = (rev es', params, args)
72 :     | filter (e::es, es', params, args) = if isOp e
73 :     then let
74 :     val (e', params', args') = lift("op1_e3", e, params, index, sx, args, avail)
75 :     in
76 :     filterOps (es, e'::es', params', args')
77 :     end
78 :     else filter (es, e::es', params, args)
79 :     in
80 :     filter (es, [], params, args)
81 :     end
82 :     fun rewrite (sx, exp, params, args) = (case exp
83 :     of E.Probe(E.Conv(_, [E.C _], _, []), _) =>
84 :     cut ("cut", exp, params, index, sx, args, avail, 0)
85 :     | E.Probe(E.Conv(_, [E.C _ ], _, [E.V 0]), _) =>
86 :     cut ("cut", exp, params, index, sx, args, avail, 1)
87 :     | E.Probe(E.Conv(_, [E.C _ ], _, [E.V 0, E.V 1]), _) =>
88 :     cut ("cut", exp, params, index, sx, args, avail, 2)
89 :     | E.Probe(E.Conv(_, [E.C _ ], _, [E.V 0, E.V 1, E.V 2]), _) =>
90 :     cut ("cut", exp, params, index, sx, args, avail, 3)
91 :     | E.Probe _ => lift ("probe", exp, params, index, sx, args, avail)
92 :     | E.Sum(_, E.Probe _) => lift ("probe", exp, params, index, sx, args, avail)
93 :     | E.Op1(op1, e1) => let
94 :     val (e1', params', args') = rewrite (sx, e1, params, args)
95 :     val ([e1], params', args') = filterOps ([e1'], params', args', index, sx)
96 :     in
97 :     (E.Op1(op1, e1), params', args')
98 :     end
99 :     | E.Op2(op2, e1, e2) => let
100 :     val (e1', params', args') = rewrite (sx, e1, params, args)
101 :     val (e2', params', args') = rewrite (sx, e2, params', args')
102 :     val ([e1', e2'], params', args') =
103 :     filterOps ([e1', e2'], params', args', index, sx)
104 :     in
105 :     (E.Op2(op2, e1', e2'), params', args')
106 :     end
107 :     | E.Opn(opn, es) => let
108 :     fun iter ([], es, params, args) = (List.rev es, params, args)
109 :     | iter (e::es, es', params, args) = let
110 :     val (e', params', args') = rewrite (e, params, args)
111 :     in
112 :     iter (es, e'::es', params', args')
113 :     end
114 :     val (es, params, args) = iter (es, [], params, args)
115 :     val (es, params, args) = filterOps (es, params, args, index, sx)
116 :     in
117 :     (E.Opn(opn, es), params, args)
118 :     end
119 :     | E.Sum(sx1, e) => let
120 :     val (e', params', args') = rewrite (sx1@sx, e, params, args)
121 :     in
122 :     (E.Sum(sx1, e'), params', args')
123 :     end
124 :     | _ => (exp, params, args)
125 :     (* end case *))
126 :     val (body', params', args') = rewrite ([], body, params, args)
127 :     val einapp = cleanParams (y, body', params', index, args')
128 :     in
129 :     List.rev (IR.ASSGN(y, einapp) :: AvailEin.getAssignments avail)
130 :     end
131 :    
132 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0