SCM Repository
[diderot] / branches / charisee_dev / src / compiler / high-to-mid / ProbeEin.sml |
View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
Parent Directory
|
Revision Log
Revision 3311 -
(download)
(annotate)
Fri Oct 16 20:09:14 2015 UTC (6 years, 8 months ago) by cchiw
File size: 13814 byte(s)
Fri Oct 16 20:09:14 2015 UTC (6 years, 8 months ago) by cchiw
File size: 13814 byte(s)
no lift
(* Expands probe ein
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure ProbeEin = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P = Printer
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes

    in

    (* This file expands probed fields.
     * Take a look at the ProbeEin tex file for examples.
     *
     * Note that the original field is an EIN operator in the form
     *     <V_alpha * H^(deltas)>(midIL.var list)
     * Param_ids are used to note the placement of the argument in the midIL.var list.
     * Index_ids keep track of the shape of an Image or differentiation.
     * Mu binds an Index_id.
     *
     * Generally, we will refer to the following:
     *   dim    : dimension of field V
     *   s      : support of kernel H
     *   alpha  : the alpha in <V_alpha * H^(deltas)>
     *   deltas : the deltas in <V_alpha * H^(deltas)>
     *   Vid    : param_id for V
     *   hid    : param_id for H
     *   nid    : integer position param_id
     *   fid    : fractional position param_id
     *   img    : image info about V
     *)

    (* Debug switch: any value other than 0 makes testp print its messages. *)
    val testing=0
    (* Not referenced anywhere in this file. *)
    val testlift=0
    (* When true, expandEinOp routes probes of the form V_[c,i] (constant first
     * index) through liftFieldMat; currently disabled. *)
    val detflag =false
    (* Not referenced anywhere in this file. *)
    val cnt = ref 0

    (* Thin local aliases for helpers from other modules. *)
    fun transformToIndexSpace e=T.transformToIndexSpace e
    fun transformToImgSpace  e=T.transformToImgSpace  e
    fun toStringBind e=(MidToString.toStringBind e)
    fun mkEin e=Ein.mkEin e
    fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)

    (* testp: string list -> int
     * Prints the concatenation of n when testing <> 0; always returns 1.
     * Note that the argument list is evaluated regardless of the switch.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (* end case *))

    (* getRHSDst: MidIL.var -> rator * MidIL.var list
     * Follows variable-to-variable bindings (VB_RHS(VAR x')) until it reaches
     * an operator binding and returns that operator with its arguments.
     * Raises Fail if the binding is anything else.
     *)
    fun getRHSDst x = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        (* the separating space before "but found" was missing, which glued the
         * variable name onto the rest of the diagnostic *)
        | vb => raise Fail(concat[
            "expected rhs operator for ", DstIL.Var.toString x,
            " but found ", DstIL.vbToString vb])
        (* end case *))

    (* getArgsDst: MidIL.var * MidIL.var * MidIL.var list -> int * ImageInfo * int
     * Given the Mid-IL variables for the kernel (hArg) and the image (imgArg),
     * looks up their defining operators and returns
     *     (support of the kernel, image info, dimension of the image).
     * The args parameter is not used.
     * Raises Fail unless hArg is bound to a Kernel op and imgArg to a LoadImage op.
     *)
    fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
        of ((DstOp.Kernel(h, _), _ ), (DstOp.LoadImage(_, _, img), _ )) =>
            (Kernel.support h, img, ImageInfo.dim img)
        | _ => raise Fail "Expected Image and kernel arguments"
        (* end case *))

    (* handleArgs: int * int * int * MidIL.var list
     *     -> int * MidIL.var list * code * int * MidIL.var
     * Uses the Param_ids for the image (Vid), kernel (hid) and position tensor
     * (tid) to select the corresponding Mid-IL variables from args.
     * It then calls transformToImgSpace to transform the position to index space.
     * Returns (dim, args @ new args, code, kernel support s, P).
     * Per the original notes, P is the Mid-IL var for the transposed
     * transformation matrix.
     *)
    fun handleArgs(Vid,hid,tid,args)=let
        val imgArg=List.nth(args,Vid)
        val hArg=List.nth(args,hid)
        val newposArg=List.nth(args,tid)
        val (s,img,dim) =getArgsDst(hArg,imgArg,args)
        val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
        in
            (dim,args@argsT,code, s,P)
        end

    (* createBody: int * int * int * mu list * mu list * param_id * param_id
     *     * param_id * param_id -> ein_exp
     * Expands the body of the probed field into an explicit convolution.
     * Summation indices E.V sx .. E.V (sx+dim-1) each range over 1-s .. s.
     * The body multiplies the image sample, taken at the fractional position
     * fid plus the summation index in each axis, by one kernel H per axis.
     * Each kernel is applied to the integer position nid minus that summation
     * index. Every kernel carries the derivative indices in deltas, paired
     * with its axis constant.
     * For dim = 1 the position tensors are indexed as scalars.
     *)
    fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
        (* 1-d fields: position tensors have no index *)
        fun createKRND1 ()=let
            val sum=sx
            val dels=List.map (fn e=>(E.C 0,e)) deltas
            val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
            val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
            in
                E.Prod [E.Img(Vid,alpha,pos),rest]
            end
        (* createKRN: builds the image term and one kernel term per axis.
         * It counts dim down to 0 and prepends each axis, so both lists
         * end up ordered from axis 0 to axis dim-1. *)
        fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
        | createKRN(dim,imgpos,rest)=let
            val dim'=dim-1
            val sum=sx+dim'
            val dels=List.map (fn e=>(E.C dim',e)) deltas
            val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
            val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
            in
                createKRN(dim',pos@imgpos,[rest']@rest)
            end
        val exp=(case dim
            of 1 => createKRND1()
            | _=> createKRN(dim, [],[])
            (* end case *))
        (* summation indices for the body, one per axis, each over [1-s, s] *)
        val slb=1-s
        val esum=List.tabulate(dim, (fn i=>(E.V (i+sx),slb,s)))
        in
            E.Sum(esum, exp)
        end

    (* getsumshift: sum_index list * int list -> int
     * Returns an index_id that is not yet in use:
     *   - length(index) when there are no summation indices,
     *   - otherwise one past the id of the LAST summation index in sx.
     * NOTE(review): this assumes the last entry of sx carries the largest id
     * and that every entry is an E.V; an E.C entry raises Bind/Match.
     * Prints the summation ids when testing is enabled.
     *)
    fun getsumshift(sx,index) =let
        val nsumshift= (case sx
            of []=> length(index)
            | _=>let
                val (E.V v,_,_)=List.hd(List.rev sx)
                in
                    v+1
                end
            (* end case *))
        val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
        val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
            "\nThink nshift is ", Int.toString nsumshift]
        in
            nsumshift
        end

    (* formBody: ein_exp -> ein_exp
     * Quick rewrite: removes summations with no indices and unwraps a
     * single-factor product. It only recurses through E.Sum.
     *)
    fun formBody(E.Sum([],e))=formBody e
    | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
    | formBody(E.Prod [e])=e
    | formBody e=e

    (* multiPs: ein_exp list * sum_index list * ein_exp -> ein_exp
     * Multiplies body by the transformation terms Ps under the summation sx.
     * With exactly three Ps the product is ordered P0*P1*P2*body, to match the
     * vis branch WorldtoSpace functions. Otherwise it is body*Ps.
     *)
    fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
    | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))

    (* multiMergePs: like multiPs, but for two Ps and two summation indices it
     * nests each summation next to the factor that uses it:
     *     Sum([sx0], P0 * Sum([sx1], P1 * body))
     * Not called in this file; the caller in createEinApp is disabled.
     *)
    fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
    | multiMergePs e=multiPs e

    (* replaceProbe: int * (MidIL.var * MidIL.rhs) * ein_exp * sum_index list -> code
     * Replaces the probe p, which appears in the EIN application bound to y,
     * with its expanded convolution in a single EIN application:
     *   - transforms the position to image/index space (handleArgs),
     *   - transforms the derivative indices back to index space
     *     (transformToIndexSpace),
     *   - rewrites the body with createBody and the transformation terms Ps,
     *   - re-wraps the original outer summation and eps factor, if any.
     * Three new params are appended. Their ids are fid, nid and Pid, in that
     * order. The corresponding arguments are those returned by handleArgs
     * followed by the matrix var.
     * Returns the code produced by the position transform, followed by the
     * new binding for y. testN is not used.
     *)
    fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx) =let
        val originalb=Ein.body e
        val params=Ein.params e
        val index=Ein.index e
        val _ = testp["\n***************** \n Replace ************ \n"]
        val _= toStringBind (y, DstIL.EINAPP(e,args))
        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
        (* ids of the params appended below *)
        val fid=length(params)
        val nid=fid+1
        val Pid=nid+1
        (* measured on the ORIGINAL dx, before it is rebound below *)
        val nshift=length(dx)
        val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,index)
        val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
        val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
        val body' = multiPs(Ps,newsx1,body')
        (* keep the outer summation and eps factor of the original body *)
        val body'=(case originalb
            of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
            | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
            | _                                  => body'
            (* end case *))
        val args'=argsA@[PArg]
        val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
        in
            code@[einapp]
        end

    (* Read by createEinApp only in a disabled branch; see the note there. *)
    val tsplitvar=true

    (* createEinApp: ein_exp * mu list * int list * int * int * mu list
     *     * sum_index list -> bool * ein * int list * mu list * mu list
     * Builds the outer EIN application used by liftProbe. It multiplies the
     * lifted tensor (param 1) by the transformation terms Ps, which are built
     * from param 0. The summation indices are cleaned (cleanIndex) to get the
     * shape sizes of the lifted tensor.
     * Returns (splitvar, ein, sizes, cleaned dx, cleaned alpha).
     * splitvar is false when the original outer summation is kept, and true
     * otherwise.
     * Assumes the body is already clean.
     *)
    fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
        val Pid=0
        val tid=1
        val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
        (* need to rewrite dx: the ids 9 and 7 are placeholders for the
         * field and kernel inside the Conv handed to cleanIndex *)
        val (_,sizes,E.Conv(_,alpha',_,dx))=(case sx@newsx
            of []=> ([],index,E.Conv(9,alpha,7,newdx))
            | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
            (* end case *))
        val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
        (* drop constant indices; the remaining ones shape the lifted tensor *)
        fun filterAlpha []=[]
        | filterAlpha(E.C _::es)= filterAlpha es
        | filterAlpha(e1::es)=[e1]@(filterAlpha es)
        val tshape=filterAlpha(alpha')@newdx
        val t=E.Tensor(tid,tshape)
        val (splitvar,body)=(case originalb
            of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
            | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
            (* The alternative that pushes summations in place,
             *     tsplitvar = true  => (true, multiMergePs(Ps,newsx,t))
             * is disabled, so this branch always uses multiPs. *)
            | _                                  => (true,multiPs(Ps,newsx,t))
            (* end case *))
        val ein0=mkEin(params,index,body)
        in
            (splitvar,ein0,sizes,dx,alpha')
        end

    (* liftProbe: int * (MidIL.var * MidIL.rhs) * ein_exp * sum_index list -> code
     * Like replaceProbe, but splits the work into two EIN applications:
     *   1. F = the probe expanded in index space (createBody), whose shape is
     *      given by sizes;
     *   2. y = the transformation terms Ps applied to F (createEinApp). This
     *      binding is passed through Split.splitEinApp when splitvar is true.
     * Returns the code produced by the position transform, then the binding
     * for F, then the binding(s) for y. printStrings is not used.
     *)
    fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
        val _=testp["\n******* Lift ******** \n"]
        val originalb=Ein.body e
        val params=Ein.params e
        val index=Ein.index e
        val _= toStringBind (y, DstIL.EINAPP(e,args))
        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
        val fid=length(params)
        val nid=fid+1
        val nshift=length(dx)
        val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,index)
        (* transform T*P*P..Ps *)
        val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
        val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
        val einApp0=mkEinApp(ein0,[PArg,FArg])
        val rtn0=(case splitvar
            of false => [(y,einApp0)]
            | _      => Split.splitEinApp (y,einApp0)
            (* end case *))
        (* lifted probe *)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
        val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
        val ein1=mkEin(params',sizes,body')
        val einApp1=mkEinApp(ein1,args')
        val rtn1=(FArg,einApp1)
        val rtn=code@[rtn1]@rtn0
        val _= List.map toStringBind ([rtn1]@rtn0)
        in
            rtn
        end

    (* expandEinOp: (MidIL.var * MidIL.rhs) * fieldset
     *     -> code * fieldset * int * int
     * At this point we only have simple ein ops.
     * First asks einSet.rtnVar whether an equivalent EIN application is
     * already known:
     *   - SOME v : the binding becomes y = v, and the last component is 1;
     *   - NONE   : the body is scanned for a probe, which is expanded by
     *              replaceProbe or liftProbe; the last component is 0.
     * The third component is 1 when the body has one of the probe shapes
     * recognized by matchField, and 0 otherwise.
     * Outer eps factors are kept, so only pieces that are used are generated.
     *)
    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
        (* Lift unless the derivative indices contain a constant, in which
         * case the probe is replaced in place. *)
        fun checkConst ([],a) = liftProbe a
        | checkConst ((E.C _::_),a) = replaceProbe a
        | checkConst ((_ ::es),a)= checkConst(es,a)

        (* Handles a probe of V_[c1, i] (constant first index), reached only
         * when detflag is true. The whole field is probed into a new var L
         * with an extra axis of size 3; component c1 is then selected into y. *)
        fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))= let
            val _= toStringBind e
            val index0=Ein.index ein
            val index1 = index0@[3]
            val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
            (* clean to get body indices in order *)
            val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
            val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]

            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
            val ein1 = mkEin(Ein.params ein,index1,body1)
            val code1= (lhs1,mkEinApp(ein1,args))
            val codeAll= (case dx
                of []=> replaceProbe(1,code1,body1,[])
                | _ =>liftProbe(1,code1,body1,[])
                (* end case *))

            (* Probe that tensor at a constant position E.C c1 *)
            val param0 = [E.TEN(1,index1)]
            val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
            val body0 =  E.Tensor(0,[E.C c1]@nx)
            val ein0 = mkEin(param0,index0,body0)
            val einApp0 = mkEinApp(ein0,[lhs1])
            val code0 = (y,einApp0)
            val _= toStringBind code0
            in
                codeAll@[code0]
            end

        fun rewriteBody b=(case (detflag,b)
            of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))            => liftFieldMat (1,b)
            | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))        => liftFieldMat (2,b)
            | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos)) => liftFieldMat (3,b)
            (* no derivatives *)
            | (_,E.Probe(E.Conv(_,_,_,[]),_))
                => replaceProbe(0,e,b,[])
            (* scans dx for a constant *)
            | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
                => checkConst(dx,(0,e,b,[]))
            (* no dx *)
            | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
                => replaceProbe(0,e,p, sx)
            (* scalar field *)
            | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
                => checkConst(dx,(0,e,p,sx))
            | (_,E.Sum(sx,E.Probe p))
                => replaceProbe(0,e,E.Probe p, sx)
            | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
                => replaceProbe(0,e,E.Probe p,sx)
            | (_,_) => [e]
            (* end case *))

        val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))

        fun matchField b=(case b
            of E.Probe _ => 1
            | E.Sum (_, E.Probe _)=>1
            | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
            | _ =>0
            (* end case *))

        (* NOTE(review): the leading strings in both branches are built and then
         * discarded; they look like leftover debug output. *)
        in  (case var
            of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");
                (rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
            | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");
                ( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
            (* end case *))
        end

    end; (* local *)

end (* ProbeEin *)
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |