SCM Repository
[diderot] / branches / charisee / src / compiler / high-to-mid / split-einHtM.sml |
View of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml
Parent Directory
|
Revision Log
Revision 2555 -
(download)
(annotate)
Mon Mar 3 19:14:57 2014 UTC (7 years, 1 month ago) by cchiw
File size: 10018 byte(s)
Mon Mar 3 19:14:57 2014 UTC (7 years, 1 month ago) by cchiw
File size: 10018 byte(s)
Code Clean up
(* Split EIN operators before the code generation process.
 *
 * An EIN operator whose body nests arithmetic operators (Neg, Add, Sub, Prod,
 * Div, Sum) is rewritten so that each nested operator is lifted out into its
 * own EIN operator bound to a fresh variable; the outer body then refers to
 * the lifted piece through a new tensor parameter.
 *
 * Conventions used by the handlers below:
 *   - A "newbie" is a lifted sub-expression paired with the fresh variable it
 *     will be bound to.
 *   - Every handleXXX function returns
 *         (bindings, (params', body', args'))
 *     where bindings is the list of (var, EIN, args) triples created for the
 *     lifted sub-expressions and the second component describes the rewritten
 *     outer expression after shift.cleanParams has been applied to it.
 *)
structure splitHtM = struct

    local
      structure E = Ein
      structure DstIL = MidIL
      structure DstTy = MidILTypes
      structure shift = shiftHtM
      structure P = Printer
      structure Var = MidIL.Var
      structure HVar = HighIL.Var
    in

  (* Render "id ==<ein><args>" for an assignment whose arguments are MidIL variables. *)
    fun printA (id, e, arg) = let
          val a = String.concatWith " , " (List.map Var.toString arg)
          in
            String.concat [Var.toString id, " ==", P.printerE e, a]
          end

  (* Same as printA, but the arguments are HighIL variables. *)
    fun printAA (id, e, arg) = let
          val a = String.concatWith " , " (List.map HVar.toString arg)
          in
            String.concat [Var.toString id, " ==", P.printerE e, a]
          end

  (* Build an EIN operator from its three components. *)
    fun createEin (params, index, body) =
          Ein.EIN{params = params, index = index, body = body}

  (* Flatten a list of lists, preserving order. *)
    fun flat xs = List.concat xs

  (* Counter used to number the fresh variables created by this pass. *)
    val counter = ref 0

  (* Create a new MidIL variable named "Q<k>" of type ty, where k is the
   * incremented counter value.  The counter is updated after the variable
   * has been created.
   *)
    fun fresh ty = let
          val m = !counter + 1
          val x = DstIL.Var.new ("Q" ^ Int.toString m, ty)
          in
            (counter := m; x)
          end

  (* Turn the newbie (id, e) into a binding (id, EIN, args).  The parameters and
   * arguments are first passed through shift.cleanParams together with e; the
   * resulting EIN uses the supplied index.
   *)
    fun createnewb (params, index, args, (id, e)) = let
          val (p', b', args) = shift.cleanParams (e, params, args)
          val a = createEin (p', index, b')
          in
            (id, a, args)
          end

  (* Like createnewb, but the newbie carries its own index ix. *)
    fun createnewP (params, args, (id, e, ix)) = let
          val (p', b', args) = shift.cleanParams (e, params, args)
          val a = createEin (p', ix, b')
          in
            (id, a, args)
          end

  (* Returns 1 when e is headed by an arithmetic operator that must be lifted
   * out (Neg, Add, Sub, Prod, Div, Sum) and 0 otherwise.
   *)
    fun findOp e = (case e
           of E.Neg _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | _ => 0
          (* end case *))

  (* Outside operator is Neg.
   * If the operand e1 is itself an operator, it is lifted into a newbie and
   * replaced by a reference to a new tensor parameter appended after the
   * existing params.
   *)
    fun handleNeg (params, index, e1, args) = let
        (* position of the next parameter to be appended *)
          val id = ref (length params)
          val n = length index
          val ix = List.tabulate (n, fn v => E.V v)
          fun mkTensor _ = let
                val ref idx = id
                in
                  (id := idx + 1; E.Tensor (idx, ix))
                end
        (* returns (replacement expression, newbies, new params, new args) *)
          fun lift e = (case findOp e
                 of 0 => (e, [], [], [])
                  | _ => let
                      val q = fresh (DstTy.TensorTy index)
                      in
                        (mkTensor 0, [(q, e)], [E.TEN (1, index)], [q])
                      end
                (* end case *))
          val (lft1, newbies1, params1, args1) = lift e1
          val (p', b', args') =
                shift.cleanParams (E.Neg lft1, params @ params1, args @ args1)
          val z1 = List.map (fn e => createnewb (params, index, args, e)) newbies1
          in
            (z1, (p', b', args'))
          end

  (* Outside operator is Add.
   * Nested Add terms are flattened into the outer sum; every other operator
   * term is lifted into a newbie; remaining terms are kept as they are.
   *)
    fun handleAdd (params, index, list1, args) = let
        (* trace output *)
          val _ = print "ADDXX"
          val id = ref (length params)
          val n = length index
          val ix = List.tabulate (n, fn v => E.V v)
          fun mkTensor _ = let
                val ref idx = id
                in
                  (id := idx + 1; [E.Tensor (idx, ix)])
                end
        (* lift e; the accumulator is (terms kept, newbies, new params, new args) *)
          fun foundOp (e, es, (lft, newbies, ps, qs)) = let
                val q = fresh (DstTy.TensorTy index)
                in
                  (es, (lft @ mkTensor 0, newbies @ [(q, e)],
                        ps @ [E.TEN (1, index)], qs @ [q]))
                end
          fun sort ([], m) = m
            | sort (e :: es, m) = (case e
                 of E.Add p => sort (p @ es, m)
                  | E.Sub _ => sort (foundOp (e, es, m))
                  | E.Prod _ => sort (foundOp (e, es, m))
                  | E.Div _ => sort (foundOp (e, es, m))
                  | E.Neg _ => sort (foundOp (e, es, m))
                  | E.Sum _ => sort (foundOp (e, es, m))
                  | _ => let
                      val (l, nb, ps, qs) = m
                      in
                        sort (es, (l @ [e], nb, ps, qs))
                      end
                (* end case *))
          val (lft, newbies, newParams, newArgs) = sort (list1, ([], [], [], []))
          val (p', b', args') =
                shift.cleanParams (E.Add lft, params @ newParams, args @ newArgs)
          val z = List.map (fn e => createnewb (params, index, args, e)) newbies
          in
            (z, (p', b', args'))
          end

  (* Outside operator is Sub.
   * Each operand that is itself an operator is lifted into a newbie.
   *)
    fun handleSub (params, index, e1, e2, args) = let
        (* trace output *)
          val _ = print "SUBXX"
          val id = ref (length params)
          val n = length index
          val ix = List.tabulate (n, fn v => E.V v)
          fun mkTensor _ = let
                val ref idx = id
                in
                  (id := idx + 1; E.Tensor (idx, ix))
                end
        (* returns (replacement expression, newbies, new params, new args) *)
          fun subsort e = (case findOp e
                 of 0 => (e, [], [], [])
                  | _ => let
                      val q = fresh (DstTy.TensorTy index)
                      in
                        (mkTensor 0, [(q, e)], [E.TEN (1, index)], [q])
                      end
                (* end case *))
          val (lft1, newbies1, params1, args1) = subsort e1
          val (lft2, newbies2, params2, args2) = subsort e2
          val (p', b', args') = shift.cleanParams (
                E.Sub (lft1, lft2), params @ params1 @ params2, args @ args1 @ args2)
          val newbies = newbies1 @ newbies2
          val z = List.map (fn e => createnewb (params, index, args, e)) newbies
          in
            (z, (p', b', args'))
          end

  (* Outside operator is Div.
   * The numerator is lifted as a tensor with the outer index; the denominator
   * is lifted as a scalar (empty index).
   *)
    fun handleDiv (params, index, e1, e2, args) = let
          val id = ref (length params)
          val n = length index
          val ix = List.tabulate (n, fn v => E.V v)
        (* reference to, and parameter kind of, a new tensor-shaped parameter *)
          fun mkTensor _ = let
                val ref idx = id
                val _ = id := idx + 1
                in
                  (E.Tensor (idx, ix), E.TEN (1, index))
                end
        (* reference to, and parameter kind of, a new scalar parameter *)
          fun mkSca _ = let
                val ref idx = id
                val _ = id := idx + 1
                in
                  (E.Tensor (idx, []), E.TEN (1, []))
                end
        (* returns (replacement expression, new params, newbies, new args).
         * shape is the shape of the lifted operand; the fresh variable must be
         * typed with it.  Previously the variable was always typed with the
         * outer index, which gave the scalar denominator a tensor type even
         * though its parameter kind and its EIN both use the empty index.
         *)
          fun divsort (e, nextfn, shape) = (case findOp e
                 of 0 => (e, [], [], [])
                  | _ => let
                      val q = fresh (DstTy.TensorTy shape)
                      val (a, b) = nextfn 0
                      in
                        (a, [b], [(q, e)], [q])
                      end
                (* end case *))
          val (lft1, params1, newbies1, args1) = divsort (e1, mkTensor, index)
          val (lft2, params2, newbies2, args2) = divsort (e2, mkSca, [])
          val (p', b', args') = shift.cleanParams (
                E.Div (lft1, lft2), params @ params1 @ params2, args @ args1 @ args2)
          val z1 = List.map (fn e => createnewb (params, index, args, e)) newbies1
        (* the denominator's EIN has an empty index *)
          val z2 = List.map (fn e => createnewb (params, [], args, e)) newbies2
          in
            (z1 @ z2, (p', b', args'))
          end

  (* Shared worker for products.
   * Nested Prod factors are flattened; every other operator factor is lifted
   * into a newbie whose index and body come from shift.cleanIndex.
   * Returns (factors kept, newbies, new params, new args), where each newbie
   * is (var, expression, index).
   * Raises Fail if a Probe is encountered.
   *)
    fun hProd (params, index, list1, args) = let
          val id = ref (length params)
          val n = length index
          fun foundOp (e, es, (lft, newbies, ps, qs)) = let
                val ref idx = id
              (* tix indexes the reference left in the product; shape is the
               * index of the lifted EIN and of its variable and parameter.
               *)
                val (tix, shape, e') = shift.cleanIndex (e, n, index)
                val p = [E.Tensor (idx, tix)]
                val q = fresh (DstTy.TensorTy shape)
                in
                  (id := idx + 1;
                   (es, (lft @ p, newbies @ [(q, e', shape)],
                         ps @ [E.TEN (1, shape)], qs @ [q])))
                end
          fun sort ([], m) = m
            | sort (e :: es, m) = (case e
                 of E.Add _ => sort (foundOp (e, es, m))
                  | E.Sub _ => sort (foundOp (e, es, m))
                  | E.Prod p => sort (p @ es, m)
                  | E.Div _ => sort (foundOp (e, es, m))
                  | E.Neg _ => sort (foundOp (e, es, m))
                  | E.Sum _ => sort (foundOp (e, es, m))
                  | E.Probe _ => raise Fail "Probe- Should have been expanded"
                  | _ => let
                      val (l, nb, ps, qs) = m
                      in
                        sort (es, (l @ [e], nb, ps, qs))
                      end
                (* end case *))
          in
            sort (list1, ([], [], [], []))
          end

  (* Outside operator is Prod. *)
    fun handleProd (params, index, list1, args) = let
          val (lft, newbies, newParams, newArgs) = hProd (params, index, list1, args)
          val (p', b', args') =
                shift.cleanParams (E.Prod lft, params @ newParams, args @ newArgs)
          val z = List.map (fn e => createnewP (params, args, e)) newbies
          in
            (z, (p', b', args'))
          end

  (* Outside operator is a Sum over a Prod.
   * The expression is first normalised with shift.clean.  The summation
   * bindings sx are then checked: they must be variables numbered
   * consecutively from (length indO), each with lower bound 0.  If they are,
   * the product is split with the outer index extended by (ub+1) for each
   * summation variable; otherwise the cleaned expression is returned unsplit.
   *
   * NOTE(review): the val binding on the result of shift.clean is refutable;
   * it raises Bind if clean does not return a Sum over a Prod.
   *)
    fun handleSumProd (paramsO, indO, sxO, list1O, argsO) = let
          val n = length indO
        (* trace output *)
          val _ = print (String.concat ["\n In Sum Prod", "n", Int.toString n])
          val (params, ind, E.Sum (sx, E.Prod list1), args) =
                shift.clean (paramsO, indO, E.Sum (sxO, E.Prod list1O), argsO)
        (* returns (1, extents) when the bindings are splittable, else (0, []) *)
          fun g (lft, [], _) = (1, lft)
            | g (lft, (E.V s, 0, ub) :: es, n') =
                if s = n'
                  then g (lft @ [ub + 1], es, n' + 1)
                  else (0, [])
            | g _ = (0, [])   (* Can't be split, weird bound *)
          val (c, index') = g ([], sx, n)
          in
            case c
             of 0 => ([], (params, E.Sum (sx, E.Prod list1), args))
              | _ => let
                  val index = ind @ index'
                  val (lft, newbies, newParams, newArgs) =
                        hProd (params, index, list1, args)
                  val (p', b', args') = shift.cleanParams (
                        E.Sum (sx, E.Prod lft), params @ newParams, args @ newArgs)
                  val z = List.map (fn e => createnewP (params, args, e)) newbies
                  in
                    (z, (p', b', args'))
                  end
          end

  (* Split one assignment (id, EIN, args) a single level.
   * Returns (newbie bindings, rewritten assignment).  Bodies that need no
   * splitting come back unchanged with an empty list of bindings.
   * Raises Fail on Field, Partial, Apply and Probe bodies.
   * The order of the cases matters: the nested Neg/Sub/Div patterns must be
   * tried before the general ones.
   *)
    fun genfn (id, Ein.EIN{params, index, body}, args) = let
        (* trace output *)
          val _ = print (P.printbody body)
          val notDone = ([], (params, body, args))
          fun gen body = (case body
                 of E.Field _ => raise Fail (concat ["Invalid Field here "])
                  | E.Partial _ => raise Fail (concat ["Invalid Partial here "])
                  | E.Apply _ => raise Fail (concat ["Invalid Apply here "])
                  | E.Probe _ => raise Fail "Probe- Should have been expanded"
                  | E.Conv _ => notDone
                  | E.Krn _ => notDone
                  | E.Img _ => notDone
                  | E.Const _ => notDone
                  | E.Tensor (_, []) => notDone
                  | E.Prod (E.Img _ :: _) => notDone
                  (* double negation cancels *)
                  | E.Neg (E.Neg e) => gen e
                  | E.Neg e => handleNeg (params, index, e, args)
                  | E.Add a => handleAdd (params, index, a, args)
                  (* (a-b)-(c-d) = (a+d)-(b+c) *)
                  | E.Sub (E.Sub (a, b), E.Sub (c, d)) =>
                      gen (E.Sub (E.Add [a, d], E.Add [b, c]))
                  (* (a-b)-e2 = a-(b+e2) *)
                  | E.Sub (E.Sub (a, b), e2) => gen (E.Sub (a, E.Add [b, e2]))
                  (* e1-(c-d) = (e1-c)+d *)
                  | E.Sub (e1, E.Sub (c, d)) => gen (E.Add [E.Sub (e1, c), d])
                  | E.Sub (e1, e2) => handleSub (params, index, e1, e2, args)
                  (* (a/b)/(c/d) = (a*d)/(b*c) *)
                  | E.Div (E.Div (a, b), E.Div (c, d)) =>
                      gen (E.Div (E.Prod [a, d], E.Prod [b, c]))
                  (* (a/b)/c = a/(b*c) *)
                  | E.Div (E.Div (a, b), c) => gen (E.Div (a, E.Prod [b, c]))
                  (* a/(b/c) = (a*c)/b *)
                  | E.Div (a, E.Div (b, c)) => gen (E.Div (E.Prod [a, c], b))
                  | E.Div (e1, e2) => handleDiv (params, index, e1, e2, args)
                  | E.Prod e => handleProd (params, index, e, args)
                  | E.Sum (_, E.Prod (E.Img _ :: _)) => notDone
                  | E.Sum (sx, E.Prod e) => handleSumProd (params, index, sx, e, args)
                  | _ => notDone
                (* end case *))
          val (newbie, (p, b, arg)) = gen body
          val e' = createEin (p, index, b)
          val f = (id, e', arg)
          in
            (newbie, f)
          end

  (* Split an assignment repeatedly until nothing more can be lifted.
   * Returns (flag, assignments): the assignments list the lifted bindings
   * first and the rewritten assignment last; flag is 1 if anything was
   * lifted and the incoming change value otherwise.
   *)
    fun splitIt (change, e) = let
          val (newbie, e') = genfn e
          in
            case newbie
             of [] => (change, [e'])
              | _ => let
                  val a = List.map (fn e1 => splitIt (1, e1)) newbie
                  val newbie' = flat (List.map (fn (_, es) => es) a)
                  in
                    (1, newbie' @ [e'])
                  end
            (* end case *)
          end

  (* Entry point: normalise the assignment with shiftHtM.clean, print the
   * before and after forms, and split the result.
   *)
    fun splitein (id, E.EIN{params, index, body}, arg) = let
          val _ = print (printA (id, E.EIN{params = params, index = index, body = body}, arg))
          val _ = print "\n \t changed to =>\n \t"
          val (p', i', b', args') = shiftHtM.clean (params, index, body, arg)
          val einn' = createEin (p', i', b')
          val _ = print (printA (id, einn', args'))
          in
            splitIt (0, (id, einn', args'))
          end

    end (* local *)

end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |