SCM Repository
[diderot] / branches / charisee / src / compiler / high-to-mid / ProbeEin.sml |
View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
Parent Directory
|
Revision Log
Revision 2606 -
(download)
(annotate)
Wed Apr 30 16:05:25 2014 UTC (5 years, 7 months ago) by cchiw
File size: 11486 byte(s)
Wed Apr 30 16:05:25 2014 UTC (5 years, 7 months ago) by cchiw
File size: 11486 byte(s)
added files
(* ProbeEin.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(* Expansion of field probes in EIN expressions (HighIL -> MidIL).
 *
 * There are a couple of possible approaches.  One is to find all the
 * Probe(Conv) terms, generate an expression for each one, and then use a
 * substitution function to splice it in; that takes care of index matching.
 *
 * This module instead creates the probe-expanded terms directly and appends
 * the new parameters to the end of the EIN parameter list.
 *)

structure ProbeEin = struct

  local

    structure E = Ein
    structure mk = mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P = Printer
    structure shift = ShiftEin
    structure split = SplitEin
    structure F = Filter

    (* set to a non-zero value to print diagnostic output *)
    val testing = 0

    (* tagged view of a MidIL right-hand side (not referenced in this structure) *)
    datatype peanut = O of DstOp.rator | E of Ein.ein | C of DstTy.ty | S of int

    (* tagged view of a HighIL right-hand side, as returned by getRHS *)
    datatype peanut2 = O2 of SrcOp.rator | E2 of Ein.ein | C2 of SrcTy.ty | S2 of int

  in

  (* assign (x, rator, args): the MidIL assignment  x = rator(args) *)
  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))

  (* assignEin (x, ein, args): the MidIL assignment  x = EINAPP(ein, args) *)
  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))

  (* getRHS x: the tagged right-hand side that defines the HighIL variable x.
   * VAR bindings are followed transitively; a variable with no binding yields
   * (S2 2, []).  Raises Fail for any other kind of binding.
   *)
  fun getRHS x = (case SrcIL.Var.binding x
         of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
          | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
          | SrcIL.VB_RHS(SrcIL.EINAPP(e, args)) => (E2 e, args)
          | SrcIL.VB_RHS(SrcIL.CONS(ty, args)) => (C2 ty, args)
          | SrcIL.VB_NONE => (S2 2, [])
          | vb => raise Fail(concat[
                "expected rhs operator for ", SrcIL.Var.toString x,
                " but found ", SrcIL.vbToString vb
              ])
        (* end case *))

  (* transformToImgSpace (dim, v, posx): generate MidIL code that maps the
   * world-space position posx into the image space of image v, and splits the
   * result into an integer part n (= RealToInt(Floor x)) and a fractional part
   * f (= x - Floor x).  Returns ([n, f], code).
   *)
  fun transformToImgSpace (dim, v, posx) = let
        val translate = DstOp.Translate v
        val transform = DstOp.Transform v
        val M = DstV.new ("M", DstTy.tensorTy [dim, dim])   (* transform matrix *)
        (* NOTE(review): T holds the translation but is typed as a dim x dim
         * tensor; confirm it should not be DstTy.vecTy dim.
         *)
        val T = DstV.new ("T", DstTy.tensorTy [dim, dim])
        val x = DstV.new ("x", DstTy.vecTy dim)             (* image-space position *)
        val f = DstV.new ("f", DstTy.vecTy dim)             (* fractional part *)
        val nd = DstV.new ("nd", DstTy.vecTy dim)           (* floor, as reals *)
        val n = DstV.new ("n", DstTy.iVecTy dim)            (* integer position *)
        val PosToImgSpace = mk.transform(dim, dim)
        val code = [
                assign(M, transform, []),
                assign(T, translate, []),
                assignEin(x, PosToImgSpace, [M, posx, T]),  (* MX+T *)
                assign(nd, DstOp.Floor dim, [x]),
                assignEin(f, mk.subTen([dim]), [x, nd]),    (* x - nd *)
                assign(n, DstOp.RealToInt dim, [nd])        (* real to int *)
              ]
        in
          ([n, f], code)
        end

  (* replaceH (kvar, place, args): args with the element at position place
   * replaced by kvar (raises Subscript if place is out of range).
   *)
  fun replaceH (kvar, place, args) = let
        val l1 = List.take(args, place)
        val l2 = List.drop(args, place+1)
        in
          l1 @ [kvar] @ l2
        end

  (* getArgs (hid, hArg, V, imgArg, args, lift): check that hArg is defined by a
   * Kernel operator and imgArg by a LoadImage operator, and create the MidIL
   * kernel ("KNL") and image ("IMG") variables for them.  When lift = 0 the
   * new variables replace positions hid and V of args; otherwise the argument
   * list is just [imgvar, hvar].
   * Returns (kernel support, image info, assignments, new argument list).
   *)
  fun getArgs (hid, hArg, V, imgArg, args, lift) = (case (getRHS hArg, getRHS imgArg)
         of ((O2(SrcOp.Kernel(h, i)), _), (O2(SrcOp.LoadImage img), _)) => let
              val hvar = DstV.new ("KNL", DstTy.KernelTy)
              val imgvar = DstV.new ("IMG", DstTy.ImageTy img)
              val argsVK = (case lift
                     of 0 => replaceH(imgvar, V, replaceH(hvar, hid, args))
                      | _ => [imgvar, hvar]
                    (* end case *))
              val assignments = [
                      assign(hvar, DstOp.Kernel(h, i), []),
                      assign(imgvar, DstOp.LoadImage img, [])
                    ]
              in
                (Kernel.support h, img, assignments, argsVK)
              end
          | ((O2(SrcOp.Kernel _), _), _) => raise Fail "Not an img Argument"
          | _ => raise Fail "Not a kernel argument"
        (* end case *))

  (* handleArgs (probe, (params, args), origargs, lift): gather everything
   * needed to expand Probe(Conv(V, _, h, _), Tensor(t, _)): the image
   * dimension (params[V] must be E.IMG), the kernel/image argument variables
   * and the image-space position code.
   * Returns (dim, kernel/image args @ [n, f], code, kernel support).
   *)
  fun handleArgs (E.Probe(E.Conv(V, _, h, _), E.Tensor(t, _)), (params, args), origargs, lift) = let
        val E.IMG(dim) = List.nth(params, V)
        val kArg = List.nth(origargs, h)
        val imgArg = List.nth(origargs, V)
        val newposArg = List.nth(args, t)
        val (s, img, argcode, argsVH) = getArgs(h, kArg, V, imgArg, args, lift)
        val (argsT, code') = transformToImgSpace(dim, img, newposArg)
        in
          (dim, argsVH @ argsT, argcode @ code', s)
        end
    | handleArgs _ = raise Fail "Expression is wrong for handleArgs"

  (* createDels (deltas, dim): pair every derivative index in deltas with the
   * constant index E.C dim (the Kronecker deltas for one kernel factor).
   *)
  fun createDels (deltas, dim) = List.map (fn d => (E.C dim, d)) deltas

  (* createBody (dim, s, sx, shape, deltas, V, h, nid, fid): build the expanded
   * probe body
   *
   *   Sum([(E.V sx, 1-s, s), ..., (E.V (sx+dim-1), 1-s, s)],
   *       Prod[Img(V, shape, [T_fid[k] + i_k]_k), Krn(h, dels_k, T_nid[k] - i_k) ...])
   *
   * where i_k = E.V(sx+k) is a fresh summation index for axis k, V and h are
   * the image and kernel parameter ids, and fid/nid are the parameter ids of
   * the position vectors used for the image offset and the kernel argument.
   *)
  fun createBody (dim, s, sx, shape, deltas, V, h, nid, fid) = let
        (* summation indices, one per axis, in increasing order *)
        fun sumIndex 0 = []
          | sumIndex k = sumIndex(k-1) @ [(E.V(k+sx-1), 1-s, s)]
        (* build the image term and kernel factors from the last axis down *)
        fun createKRN (0, imgpos, rest) = E.Prod([E.Img(V, shape, imgpos)] @ rest)
          | createKRN (k, imgpos, rest) = let
              val k' = k-1
              val idx = sx+k'
              val dels = createDels(deltas, k')
              val pos = E.Add[E.Tensor(fid, [E.C k']), E.Value idx]
              val krn = E.Krn(h, dels, E.Sub(E.Tensor(nid, [E.C k']), E.Value idx))
              in
                createKRN(k', pos :: imgpos, krn :: rest)
              end
        in
          E.Sum(sumIndex dim, createKRN(dim, [], []))
        end

  (* ShapeConv (ixs, n): the variable indices E.V v of ixs with v < n (the free
   * indices of the enclosing EIN); constant indices are dropped.
   *)
  fun ShapeConv ([], _) = []
    | ShapeConv (E.C _ :: es, n) = ShapeConv(es, n)
    | ShapeConv (E.V v :: es, n) =
        if (n > v) then E.V v :: ShapeConv(es, n) else ShapeConv(es, n)

  (* mapIndex (ixs, index): the sizes List.nth(index, v) of the variable indices
   * E.V v in ixs; constant indices are dropped.
   *)
  fun mapIndex ([], _) = []
    | mapIndex (E.V v :: es, index) = List.nth(index, v) :: mapIndex(es, index)
    | mapIndex (E.C _ :: es, index) = mapIndex(es, index)

  (* liftProbe (b, (params, args), index, sumIndex, origargs): lift the probe b
   * (optionally under the summation sumIndex) into a separate EIN assignment
   *   PC = EINAPP(lifted, a')
   * and replace it in the original body by E.Tensor(newId, shape), where newId
   * is the index of the new parameter TEN(1, shape') appended to params.
   * Returns (replacement, (params', args'), code for the lifted computation).
   *)
  fun liftProbe (b, (params, args), index, sumIndex, origargs) = let
        val E.Probe(E.Conv(_, alpha, _, dx), _) = b
        val newId = length params
        val n = length index
        (* new tensor replacement: only the free indices of the probe survive *)
        val shape = ShapeConv(alpha @ dx, n)
        val newB = E.Tensor(newId, shape)
        (* new parameter and argument *)
        val shape' = mapIndex(shape, index)
        val newP = E.TEN(1, shape')
        val newArg = DstV.new ("PC", DstTy.tensorTy shape')
        (* expand the probe; with lift = 1 the arguments are [IMG, KNL, n, f] *)
        val ns = length sumIndex
        val (dim, args', code, s) = handleArgs(b, (params, args), origargs, 1)
        (* params' slot 2 holds n (image offset), slot 3 holds f (kernel argument) *)
        val body' = (case ns
               of 0 => createBody(dim, s, n, alpha, dx, 0, 1, 3, 2)
                | _ => let
                    val (E.V v, _, _) = List.nth(sumIndex, ns-1)
                    in
                      E.Sum(sumIndex, createBody(dim, s, v+1, alpha, dx, 0, 1, 3, 2))
                    end
              (* end case *))
        val params' = [E.IMG(dim), E.KRN, E.TEN(3, [dim]), E.TEN(1, [dim])]
        val (p', i', b', a') = shift.clean(params', index, body', args')
        val newbie' = Ein.EIN{params=p', index=i', body=b'}
        val data = assignEin(newArg, newbie', a')
        val _ = (case testing
               of 0 => 1
                | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'), "\n"]); 1)
              (* end case *))
        in
          (newB, (params @ [newP], args @ [newArg]), code @ [data])
        end

  (* replaceProbe (b, (params, args), index, sumIndex, origargs): expand the
   * probe b in place.  The kernel/image arguments replace the original ones,
   * and two position parameters are appended: slot length(params) receives n
   * (image offset) and the next slot receives f (kernel argument).  The fresh
   * summation indices start after the last index of sumIndex (or after the
   * free indices when sumIndex is empty).
   * Returns (new body, (params', args'), code).
   *)
  fun replaceProbe (b, (params, args), index, sumIndex, origargs) = let
        val E.Probe(E.Conv(V, alpha, h, dx), _) = b
        val fid = length params
        val nid = fid+1
        val n = length index
        val ns = length sumIndex
        val (dim, args', code, s) = handleArgs(b, (params, args), origargs, 0)
        val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
        val firstIx = (case ns
               of 0 => n
                | _ => let val (E.V v, _, _) = List.nth(sumIndex, ns-1) in v+1 end
              (* end case *))
        val body' = createBody(dim, s, firstIx, alpha, dx, V, h, nid, fid)
        val _ = (case testing
               of 0 => 1
                | _ => let
                    val subexp = Ein.EIN{params=params', index=index, body=body'}
                    val _ = print(String.concat["\n Don't replace probe \n $$$ new sub-expression $$$ \n", P.printerE(subexp), "\n"])
                    in
                      1
                    end
              (* end case *))
        in
          (body', (params', args'), code)
        end

  (* flatten xss: the concatenation of the lists in xss *)
  fun flatten xss = List.concat xss

  (* expandEinOp (ein, origargs, args): expand every Probe(Conv, Tensor) in the
   * body of ein.  A probe that is not nested inside any enclosing summation is
   * lifted out into its own EIN assignment (liftProbe); a probe nested inside a
   * summation is expanded in place (replaceProbe).
   * Returns ((new EIN, new args), new assignments).
   *)
  fun expandEinOp (Ein.EIN{params, index, body}, origargs, args) = let
        val dummy = E.Const 0
        (* stack of the summation-index lists enclosing the current term,
         * innermost first
         *)
        val sumIndex = ref []
        (* rewriteBody (b, info): b is the current body, info the current
         * (params, args); returns (new body, new info, new assignments)
         *)
        fun rewriteBody (b, info) = let
              (* rewrite the body of Sum(c1, body) with c1 pushed on the stack *)
              fun callfn (c1, body) = let
                    val () = sumIndex := c1 :: !sumIndex
                    val (bodyK, infoK, dataK) = rewriteBody(body, info)
                    val s = !sumIndex
                    val z = hd s
                    val e' = (case bodyK
                           of E.Const _ => bodyK
                            | _ => E.Sum(z, bodyK)
                          (* end case *))
                    in
                      sumIndex := tl s;
                      (e', infoK, dataK)
                    end
              (* rewrite two operands left-to-right and rebuild with mk *)
              fun rewritePair (mk, e1, e2) = let
                    val (body1, info1, data1) = rewriteBody(e1, info)
                    val (body2, info2, data2) = rewriteBody(e2, info1)
                    in
                      (mk(body1, body2), info2, data1 @ data2)
                    end
              (* rewrite a list of operands left-to-right and rebuild with mk *)
              fun rewriteSeq (mk, es) = let
                    fun loop ([], done, info', data) = let
                          val (_, e) = mk done
                          in
                            (e, info', data)
                          end
                      | loop (e::rest, done, info', data) = let
                          val (body', info'', data') = rewriteBody(e, info')
                          in
                            loop(rest, done @ [body'], info'', data @ data')
                          end
                    in
                      loop(es, [], info, [])
                    end
              in
                (case b
                 of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) => let
                      val probe = E.Probe(E.Conv v, E.Tensor t)
                      in
                        (case !sumIndex
                         of [] => liftProbe(probe, info, index, c, origargs)
                          | sx => let
                              val (body', info', data') =
                                    replaceProbe(probe, info, index, flatten sx @ c, origargs)
                              in
                                (* keep the summation over c: its indices occur in
                                 * the probe's image/derivative indices
                                 *)
                                (E.Sum(c, body'), info', data')
                              end
                        (* end case *))
                      end
                  | E.Probe(E.Conv _, E.Tensor _) => (case !sumIndex
                       of [] => liftProbe(b, info, index, [], origargs)
                        | sx => replaceProbe(b, info, index, flatten sx, origargs)
                      (* end case *))
                  (* NOTE(review): other field forms are silently replaced by the
                   * constant 0; confirm they cannot reach this pass.
                   *)
                  | E.Probe _ => (dummy, info, [])
                  | E.Conv _ => (dummy, info, [])
                  | E.Lift _ => (dummy, info, [])
                  | E.Field _ => (dummy, info, [])
                  | E.Apply _ => (dummy, info, [])
                  | E.Neg e => let
                      val (body', info', data') = rewriteBody(e, info)
                      in
                        (E.Neg body', info', data')
                      end
                  | E.Sum(c, e) => callfn(c, e)
                  | E.Sub(e1, e2) => rewritePair(E.Sub, e1, e2)
                  | E.Div(e1, e2) => rewritePair(E.Div, e1, e2)
                  | E.Add es => rewriteSeq(F.mkAdd, es)
                  | E.Prod es => rewriteSeq(F.mkProd, es)
                  | _ => (b, info, [])
                (* end case *))
              end
        val _ = (case testing
               of 0 => 1
                | _ => (print "\n ************************** \n Starting Expand"; 1)
              (* end case *))
        val (body', (params', args'), newbies) = rewriteBody(body, (params, args))
        val e' = Ein.EIN{params=params', index=index, body=body'}
        val _ = (case testing
               of 0 => 1
                | _ => (print(String.concat[P.printerE(e'), "\n DONE expand ************************** \n "]); 1)
              (* end case *))
        in
          ((e', args'), newbies)
        end

  end (* local *)

end (* ProbeEin *)
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |