Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/lift-ein.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/lift-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3382, Sat Nov 7 03:51:29 2015 UTC revision 3383, Mon Nov 9 02:39:26 2015 UTC
# Line 23  Line 23 
23      fun cleanParams e = cleanP.cleanParams e      fun cleanParams e = cleanP.cleanParams e
24      fun cleanIndex e = cleanI.cleanIndex e      fun cleanIndex e = cleanI.cleanIndex e
25      fun toStringBind e= MidToString.toStringBind e      fun toStringBind e= MidToString.toStringBind e
26    fun mkEin e=Ein.mkEin e
27    fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
28      fun itos i = Int.toString i      fun itos i = Int.toString i
29      fun err str = raise Fail str      fun err str = raise Fail str
30      val cnt = ref 0      val cnt = ref 0
# Line 35  Line 37 
37          end          end
38      fun testp n=(case testing      fun testp n=(case testing
39          of 0=> 1          of 0=> 1
40          | _ =>(print(String.concat n);1)          | _ =>( (String.concat n);1)
41          (*end case*))          (*end case*))
42    
43    
44        fun cut(name,e,params,index,sx,argsOrig,fieldset,cntinplace,cntlift,newvx) =let
45    
46            val _ = "\ncutting"
47    val _ = (String.concat["\nto intput to cutting :",Int.toString(length(params)),
48    "args:",Int.toString(length(argsOrig))])
49            (*clean and rewrite current body*)
50            val (tshape,sizes,body)=cleanIndex(e,index,sx)
51            val id=length(params)
52            val Rparams=params@[E.TEN(1,sizes)]
53            val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
54            val Rargs=argsOrig@[M]
55            val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
56            val _= " past first clean Params"
57    
58            (*shift indices in probe body from constant to variable*)
59            val (y, DstIL.EINAPP(ein,args))=einapp
60            val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
61            val index0=Ein.index ein
62            val index1 = index0@[3]
63            val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
64    
65            (* clean to get body indices in order *)
66            val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
67            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
68            val _ = "before first cleanParam"
69    
70    
71            val ein1 = mkEin(Ein.params ein,index1,body1)
72            val code1= (lhs1,mkEinApp(ein1,args))
73            val (fieldset',lhs0,cntinplace',cntlift')=(case LiftSet.rtnVarN(fieldset,code1)
74                of(fieldset,NONE)=>  (fieldset,lhs1,cntinplace+1,cntlift)
75                | (fieldset,SOME v)=>(fieldset,v,cntinplace,cntlift+1)
76                (*end case*))
77    
78    
79            (*Probe that tensor at a constant position  c1*)
80            val param0 = [E.TEN(1,index1)]
81            val nx=List.tabulate(length(dx),fn n=>E.V n)
82            val Re =  E.Tensor(id,[c1]@tshape)
83            val Rparams=params@param0
84            val Rargs=argsOrig@[lhs1]
85    
86    
87            val newbies=[code1]
88    val _ = (String.concat["\nretuning from cutting :",Int.toString(length(Rparams)),
89    "args:",Int.toString(length(Rargs))])
90               val _= " past cut "
91    
92        in
93             (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')
94        end
95    
96    
97      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
98      *lifts expression and returns replacement tensor      *lifts expression and returns replacement tensor
99      * cleans the index and params of subexpression      * cleans the index and params of subexpression
100      *creates new param and replacement tensor for the original ein_exp      *creates new param and replacement tensor for the original ein_exp
101      *)      *)
102      fun lift(name,e,params,index,sx,args,fieldset,cntinplace,cntlift)=let      fun lift(name,e,params,index,sx,args,fieldset,cntinplace,cntlift)=let
103          val _ = "\n ****** start lift***********\n"          val _ = " \n in side lift"
104    val _ = (String.concat["\nto intput to lift :",Int.toString(length(params)),
105    "args:",Int.toString(length(args))])
106    
107          val (tshape,sizes,body)=cleanIndex(e,index,sx)          val (tshape,sizes,body)=cleanIndex(e,index,sx)
108          val id=length(params)          val id=length(params)
109          val Rparams=params@[E.TEN(1,sizes)]          val Rparams=params@[E.TEN(1,sizes)]
110          val Re=E.Tensor(id,tshape)          val Re=E.Tensor(id,tshape)
111          val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)          val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
112          val Rargs=args@[M]          val Rargs=args@[M]
113    val _ = (String.concat["\nto cleanParams:",Int.toString(length(Rparams)),
114    "args:",Int.toString(length(Rargs))])
115          val einapp=cleanParams(M,body,Rparams,sizes,Rargs)          val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
116          val (_,einapp0)=einapp          val (_,einapp0)=einapp
117    
118          val (Rargs,newbies,fieldset',cntinplace',cntlift') =(case numFlag          val (Rargs,newbies,fieldset',cntinplace',cntlift') =(case numFlag
119              of 1=> let              of 1=> let
120                  val (fieldset',var) = einSet.rtnVar(fieldset,M,einapp0)                  val (fieldset',var) = LiftSet.rtnVar(fieldset,M,einapp0)
121                  in (case var                  in (case var
122                      of NONE=> (args@[M],[einapp],fieldset',cntinplace+1,cntlift)                      of NONE=> let
123                            val MidIL.EINAPP(ein0,arg0) =einapp0
124                            in (args@[M],[einapp],fieldset',cntinplace+1,cntlift)
125                            end
126                      | SOME v=> (args@[v],[],fieldset',cntinplace,cntlift+1)                      | SOME v=> (args@[v],[],fieldset',cntinplace,cntlift+1)
127                      (*end case*))                      (*end case*))
128                  end                  end
129              | _=>(args@[M],[einapp],fieldset,cntinplace,cntlift)              | _=>(args@[M],[einapp],fieldset,cntinplace,cntlift)
130              (*end case*))              (*end case*))
131          val _=(String.concat["\n in place",Int.toString(cntinplace'),"- l",Int.toString(cntlift')])          val _ = " \n out side lift"
             val _="\n *********end lift*********\n"  
132          in          in
133              (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')              (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')
134          end          end
135    
136    
137    (*
138    fun ff(name,e,params,index,sx,args,fieldset,cntinplace,cntlift,newvx)=let
139        val (tshape,sizes,body)=cleanIndex(e,index,sx)
140        in (case body
141    of    E.Probe(E.Conv(_,[E.C _ ],_,[]),pos)
142    => liftFieldVec (0,e,fieldset)
143    | E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))
144    => liftFieldVec (1,e,fieldset)
145    | E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))
146    => liftFieldVec (2,e,fieldset)
147    | E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))
148    => liftFieldVec (3,e,fieldset)
149    
150    *)
151      fun liftfields(y,DstIL.EINAPP(ein0,args0))=let      fun liftfields(y,DstIL.EINAPP(ein0,args0))=let
152          val sx = ref []          val sx = ref []
153          val index=Ein.index ein0          val index=Ein.index ein0
# Line 126  Line 204 
204                  in                  in
205                  (E.ArcSine e1',data')                  (E.ArcSine e1',data')
206                  end                  end
207    
208            | E.Probe(E.Conv(_,[E.C _ ],_,[]),pos)=> let
209                val (params,args,code,fieldset,cntinplace,cntlift)=data
210                val (body',params',args',code',fieldset',cntinplace',cntlift')=cut("cut",b,params,index,(!sx),args,fieldset,cntinplace,cntlift,0)
211    
212                val data'=(params',args',code@code',fieldset',cntinplace',cntlift')
213                in
214                (body',data')
215                end
216    | E.Probe(E.Conv(_,[E.C _ ],_,[E.V 0]),pos)=> let
217    val (params,args,code,fieldset,cntinplace,cntlift)=data
218    val (body',params',args',code',fieldset',cntinplace',cntlift')=cut("cut",b,params,index,(!sx),args,fieldset,cntinplace,cntlift,1)
219    
220    val data'=(params',args',code@code',fieldset',cntinplace',cntlift')
221    in
222    (body',data')
223    end
224    | E.Probe(E.Conv(_,[E.C _ ],_,[E.V 0,E.V 1]),pos)=> let
225    val (params,args,code,fieldset,cntinplace,cntlift)=data
226    val (body',params',args',code',fieldset',cntinplace',cntlift')=cut("cut",b,params,index,(!sx),args,fieldset,cntinplace,cntlift,2)
227    
228    val data'=(params',args',code@code',fieldset',cntinplace',cntlift')
229    in
230    (body',data')
231    end
232    
233    | E.Probe(E.Conv(_,[E.C _ ],_,[E.V 0,E.V 1,E.V 2]),pos)=> let
234    val (params,args,code,fieldset,cntinplace,cntlift)=data
235    val (body',params',args',code',fieldset',cntinplace',cntlift')=cut("cut",b,params,index,(!sx),args,fieldset,cntinplace,cntlift,3)
236    
237    val data'=(params',args',code@code',fieldset',cntinplace',cntlift')
238    in
239    (body',data')
240    end
241    | E.Probe(E.Conv(_,[E.C _ ],_,dx),pos)=> let
242    val (params,args,code,fieldset,cntinplace,cntlift)=data
243    val (body',params',args',code',fieldset',cntinplace',cntlift')=cut("cut",b,params,index,(!sx),args,fieldset,cntinplace,cntlift,length(dx))
244    
245    val data'=(params',args',code@code',fieldset',cntinplace',cntlift')
246    in
247    (body',data')
248    end
249    
250    
251    
252              | E.Probe _=> let              | E.Probe _=> let
253                  val (params,args,code,fieldset,cntinplace,cntlift)=data                  val (params,args,code,fieldset,cntinplace,cntlift)=data
254                  val (body',params',args',code',fieldset',cntinplace',cntlift')=lift("probe",b,params,index,(!sx),args,fieldset,cntinplace,cntlift)                  val (body',params',args',code',fieldset',cntinplace',cntlift')=lift("probe",b,params,index,(!sx),args,fieldset,cntinplace,cntlift)
# Line 198  Line 321 
321            | scan(E.Sum(sx,E.Probe p) ,data)=(E.Sum(sx,E.Probe p) ,data)            | scan(E.Sum(sx,E.Probe p) ,data)=(E.Sum(sx,E.Probe p) ,data)
322            | scan e = rewrite e            | scan e = rewrite e
323    
324         val fieldset= einSet.EinSet.empty         val fieldset= LiftSet.LiftSet.empty
325          val data= (Ein.params ein0,args0,[],fieldset,0,0)          val data= (Ein.params ein0,args0,[],fieldset,0,0)
326          val (body',data')=scan(Ein.body ein0,data)          val (body',data')=scan(Ein.body ein0,data)
327          val (params',args',code',fieldset',cntinplace',cntlift')=data'          val (params',args',code',fieldset',cntinplace',cntlift')=data'
328          val k=(toStringBind (y,DstIL.EINAPP(Ein.EIN{params=params',index=index,body=body'},args')))          val k=(toStringBind (y,DstIL.EINAPP(Ein.EIN{params=params',index=index,body=body'},args')))
329          val _=(String.concat["\n in place",Int.toString(cntinplace'),"- l",Int.toString(cntlift')])          val _= print(String.concat["\n in place",Int.toString(cntinplace'),"- l",Int.toString(cntlift')])
330            val _ = "\n **last clean params"
331          val einapp= cleanParams(y,body',params',index,args')          val einapp= cleanParams(y,body',params',index,args')
332          in          in
333              (einapp,code')              (einapp,code')
# Line 211  Line 335 
335    
336    
337      fun testLift e1=  let      fun testLift e1=  let
338          val _ =print"\nUses LIFT"          val _ = "\nUses LIFT"
339          val (einapp1,e2)=liftfields e1          val (einapp1,e2)=liftfields e1
340          val n=length(e2)          val n=length(e2)
341          val _ =(case n          val _ =(case n

Legend:
Removed from v.3382  
changed lines
  Added in v.3383

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0