Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / ProbeEin.sml

# Diff of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

revision 2610, Fri May 2 18:31:56 2014 UTC revision 2611, Mon May 5 21:21:12 2014 UTC
# Line 35  Line 35
35      structure shift=ShiftEin      structure shift=ShiftEin
36      structure split=SplitEin      structure split=SplitEin
37      structure F=Filter      structure F=Filter
38        structure T=TransformEin
39
40  val testing=1  val testing=1
41
# Line 57  Line 58
58      "but found ", SrcIL.vbToString vb])      "but found ", SrcIL.vbToString vb])
59      (* end case *))      (* end case *))
60
(*Transform differentiation index to world-space*)
(*Returns new deltas, summations, and List of Tensor Products*)
fun ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let

val dim'=dim-1
fun setMatrix(imgix,wrdix) = E.Tensor(1,[ E.V imgix, E.V wrdix])

fun changeDel([],n,newdels,sx,rest)=(newdels,sx,rest)
| changeDel((E.C _)::es,_,_,_,_)= raise Fail "unsure what to do with constant differentiation"
| changeDel((E.V v)::es,n,newdels,sx,rest)= let
val P=setMatrix(v, n)
val n'=E.V n
in
changeDel(es,n+1,newdels@[n'],sx@[(n',0,dim')],rest@[P])
end

val n=length(outerShape)
val (newdels,sx,rest)=changeDel(dels,n,[],[],[])
in
(newdels,sx,rest)
end

(*Transformation is Lifted*)
fun ImgtoWorldSpaceLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let
val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)
val tshape=List.tabulate((length(alpha)),fn v=> E.V v)

val newbie'=Ein.EIN{
params=[E.TEN(1,outerShape), E.TEN(1,[dim,dim])],
index=outerShape,
body=E.Sum(sx,E.Prod([E.Tensor(0,tshape@newdels)]@rest))
}

val _ = print(String.concat["\n Transform \n ",(split.printA (newArg,newbie',[oldArg,PArg])) ,"\n"])
val data=assignEin (newArg, newbie', [oldArg,PArg])
val ix=List.tabulate((length(dels)),fn _=> dim)
in
(newdels,ix,[data],[],[],[])
end

fun ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,id)= let
val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)
val params=[E.TEN(id,[dim,dim])]
in
(newdels,[],[],params,sx,rest)
end

(*E.Sum(sx,rest) *)

(*returns final argumentVar, new dels, and assignments*)
fun decideIfTransform(dels,outerShape,alpha,dim,PArg,Pid,lift)=let
val oldArg = DstV.new ("ProbeResult", DstTy.tensorTy outerShape)
in (case (dels,lift)
of ([],_) => (oldArg,oldArg,dels,[],[],[],[],[])
| (_,0) => let  (*need to transform to world-space*)
val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)
val (dels',ix',assigments,_,_,_)=ImgtoWorldSpaceLift(dels,outerShape,alpha,dim,oldArg,newArg,PArg)
in
(oldArg,newArg,dels', ix',assigments,[],[],[])
end
| ( _,_) =>  let
val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)
val (dels',_,_,params,sx,rest)=ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,Pid)
in (oldArg, oldArg, dels',[],[],params,sx,rest)
end
(*end case*))
end
61
62
63  (*Create fractional, and integer position vectors*)  (*Create fractional, and integer position vectors*)
# Line 236  Line 166
166      | mapIndex(E.C c::es,index) = mapIndex(es,index)      | mapIndex(E.C c::es,index) = mapIndex(es,index)
167
168
169    (*Lift probe and Multiply by P*)
170    fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let
171        val _ =print "Lift Probe"
172
(*Currently have three different Functions. One that Lifts Field and Transormation. One that just one lift, and the other does it all in place*)
(*Lift probe and does transformation Lifted*)
fun liftProbeTester(b,(params,args),index, sumIndex,origargs)=let

val E.Probe(E.Conv(V,alpha,H,dels),E.Tensor(t,_))=b
val newId=length(params)
173      val n=length(index)      val n=length(index)

(*Create new tensor replacement*)
val shape=ShapeConv(alpha@dels, n)
val newB=E.Tensor(newId,shape)

(* Create new Param*)
(*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)
val shape'= mapIndex(shape,index)
val newP= E.TEN(1,shape')

val A1=ShapeConv(alpha, n)
val alpha1= mapIndex(A1,index)

(*Expand Probe*)
174      val ns=length sumIndex      val ns=length sumIndex
175        val nshift=length(dx)
176      val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)      val np=length(params)
177      val (oldArg,newArg,dx, ix,assigments,paramsA,sxA,restA) = decideIfTransform(dels,shape',alpha1,dim,PArg,n+1,ns)      val nsumshift =(case ns
178            of 0=>   n
179      val body' =(case ns          |_=>let  val (E.V v,_,_)=List.nth(sumIndex, ns-1)
180          of 0=>    createBody(dim, s,n+length(dx),alpha,dx,0, 1, 3, 2)              in
181          |_=>let                  v+1
val (E.V v,_,_)=List.nth(sumIndex, ns-1)
val v'=v+length(dx) (*Shifted because of the swapped indices *)
val body'=createBody(dim, s,v'+2,alpha,dx,0, 1, 3, 2)
in  E.Sum(sumIndex ,body')
182              end              end
183      (* end case *))      (* end case *))
184
val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]
val (p',i',b',a')=shift.clean(params'@paramsA, index@ix,body', args')
val newbie'=Ein.EIN{params=p', index=i', body=b'}

val data=assignEin (oldArg, newbie', a')

val _ = (case testing
of 0 => 1
| _ => (print(String.concat["\n Lift Probe\n", split.printA(oldArg, newbie', a'),"\n"]);1)
(*end case *))
in (newB, (params@[newP],args@[newArg]) ,code@[data]@assigments)
end

(*Lift probe and Multiply by P*)
fun liftProbe(b,(params,args),index, sumIndex,origargs)=let

val E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_))=b
val newId=length(params)
val n=length(index)

(*Create new tensor replacement*)
val shape=ShapeConv(alpha@dx, n)
val newB=E.Tensor(newId,shape)

(* Create new Param*)
(*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)
val shape'= mapIndex(shape,index)
val newP= E.TEN(1,shape')
185
186  val newArg = DstV.new ("PC", DstTy.tensorTy shape')      (*Outer Index-id Of Probe*)
187        val VShape=ShapeConv(alpha, n)
188        val HShape=ShapeConv(dx, n)
189        val shape=VShape@HShape
190
191  (*Expand Probe*)      (* Bindings for Shape*)
192  val ns=length sumIndex      val shapebind= mapIndex(shape,index)
193        val Vshapebind= mapIndex(VShape,index)
194
195  val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)      (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)
196        val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)
197
198        (*New transformations:params, sx, rest, will be empty if no transformation is made*)
199        val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)
200
201  val body' =(case ns      (*rewriteBody*)
202  of 0=>    createBody(dim, s,n,alpha,dx,0, 1, 3, 2)      val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)
|_=>let
val (E.V v,_,_)=List.nth(sumIndex, ns-1)
203
204  val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)      val sx=sumIndex@sxT
205  in  E.Sum(sumIndex ,body')      val body'=(case sx
206  end          of [] =>E.Prod(restT@[bodyExpanded])
207            | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))
208  (* end case *))  (* end case *))
209
210  val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]      (*create new EIN OPerator*)
211  val (p',i',b',a')=shift.clean(params', index,body', args')      val _ =print("Found this many args ")
212  val newbie'=Ein.EIN{params=p', index=i', body=b'}      val _ =print(Int.toString(length(args')))

213
214  val data=assignEin (newArg, newbie', a')      val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT
215        val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])
216        val newbie'=Ein.EIN{params=p', index=i', body=b'}
217        val data=assignEin (oldArg, newbie', a')
218
219  val _ = (case testing  val _ = (case testing
220  of 0 => 1  of 0 => 1
221  | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)  | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)
222  (*end case *))  (*end case *))
223  in (newB, (params@[newP],args@[newArg]) ,code@[data])      in
224            (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)
225  end  end
226
227
# Line 345  Line 232
232
233      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
234      val fid=length(params)      val fid=length(params)
235        val nid=fid+1
236      val n=length(index)      val n=length(index)

(*Expand Probe*)
237      val ns=length sumIndex      val ns=length sumIndex
238      val (dim,args',code,s,P) = handleArgs(V,h,t,(params,args), origargs,0)      val nshift=length(dx)
239      val nid=fid+1      val nsumshift =(case ns
240      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          of 0=> n
val body' =(case ns
of 0=> createBody(dim, s,n,alpha,dx,V, h, nid, fid)
241          |_=>let          |_=>let
242              val (E.V v,_,_)=List.nth(sumIndex, ns-1)              val (E.V v,_,_)=List.nth(sumIndex, ns-1)
243              in createBody(dim, s,v+1,alpha,dx,V, h, nid, fid)              in v+1
244              end              end
245          (* end case *))          (* end case *))
246
247        (*Outer Index-id Of Probe*)
248        val VShape=ShapeConv(alpha, n)
249        val HShape=ShapeConv(dx, n)
250        val shape=VShape@HShape
251        (* Bindings for Shape*)
252        val shapebind= mapIndex(shape,index)
253        val Vshapebind= mapIndex(VShape,index)
254
255
256        val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)
257        val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)
258
259        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
260        val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)
261        val body' =(case nshift
262            of 0=> body''
263            | _ => E.Sum(sxT, E.Prod(restT@[body'']))
264            (*end case*))
265        val args'=argsA@[PArg]
266      val _ =(case testing      val _ =(case testing
267          of 0=> 1          of 0=> 1
268          | _ =>  let          | _ =>  let
# Line 367  Line 271
271              in 1 end              in 1 end
272          (* end case *))          (* end case *))
273
274
275      in (body',(params',args') ,code)      in (body',(params',args') ,code)
276      end      end
277
278    (*Checks if vairable occurs just once. If it does then we can lift *)
279    fun checkSum([(E.V i,lb,ub)],E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta)),info,index,origargs)=let
280        val n=length(index)
281        val _=print "in check Sum\n "
282        val _=print(P.printbody(E.Sum([(E.V i,lb,ub)],E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)))))
283        in
284         if (i=n) then (case F.findOcc(i,alpha@dx)
285                of  1 => liftProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)),info,index@[ub], [],origargs)
286                | _ =>  replaceProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)), info,index, [(E.V i,lb,ub)],origargs)
287                (*end case*))
288        else replaceProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)), info,index, [(E.V i,lb,ub)],origargs)
289        end
290
291  fun flatten []=[]  fun flatten []=[]
292      | flatten(e1::es)=e1@(flatten es)      | flatten(e1::es)=e1@(flatten es)
# Line 398  Line 315
315                  (sumIndex:=tl(s);(e',infoK,dataK))                  (sumIndex:=tl(s);(e',infoK,dataK))
316              end              end
317
318            fun filter es=let
319                fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)
320                | filterApply(B::es, doneA, infoA,dataA)= let
321                    val (bodyB, infoB,dataB)= rewriteBody(B,infoA)
322                    in
323                        filterApply(es, doneA@[bodyB], infoB,dataA@dataB)
324                    end
325                in filterApply(es, [],info,[])
326                end
327          in (case b          in (case b
328              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let
329                  val ref sx=sumIndex                  val ref sx=sumIndex
330                  in (case sx                  in (case sx
331                      of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)                      of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
332                      | _ =>  replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)                        | [i]=> checkSum(i,b, info,index,origargs)
333                         | _ => let
334                            val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)
335                            in (E.Sum(c,b),m,code)
336                            end
337                  (* end case*))                  (* end case*))
338              end              end
339          | E.Probe(E.Conv _, E.Tensor _) =>let          | E.Probe(E.Conv _, E.Tensor _) =>let
340              val ref sx=sumIndex              val ref sx=sumIndex
341              in (case sx              in (case sx
342                  of []=> liftProbe(b, info,index, [],origargs)                  of []=> liftProbe(b, info,index, [],origargs)
343                    | [i]=> checkSum(i,b, info,index,origargs)
344                  | _=> replaceProbe(b, info,index, flatten sx,origargs)                  | _=> replaceProbe(b, info,index, flatten sx,origargs)
345               (* end case*))               (* end case*))
346              end              end
# Line 434  Line 365
365              val (bodyB, infoB,dataB)= rewriteBody(b,infoA)              val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
366              in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end              in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
368              fun filter([], done, info', data)= let              val (done, info',data')= filter es
370                      in (e, info',data)              in (e, info',data')
371                      end                      end
| filter(e::es, done, info',data)= let
val (body', info'',data')= rewriteBody(e,info')
in filter(es, done@[body'], info'',data@data') end
in filter(es, [],info,[]) end

372          | E.Prod es=> let          | E.Prod es=> let
373              fun filter([], done, info',data)= let              val (done, info',data')= filter es
374                      val (_, e)=F.mkProd done                      val (_, e)=F.mkProd done
375                      in  (e,info', data)              in (e, info',data')
376                      end                      end
| filter(e::es, done, info',data)= let
val (body', info'',data')= rewriteBody(e, info')
in filter(es, done@[body'], info'',data@data') end
in filter(es, [],info,[]) end
377          | _=>  (b,info,[])          | _=>  (b,info,[])
378          (* end case *))          (* end case *))
379          end          end

Legend:
 Removed from v.2610 changed lines Added in v.2611