SCM Repository
View of /branches/charisee/src/compiler/mid-to-low/step2.sml
Parent Directory
|
Revision Log
Revision 2669 -
(download)
(annotate)
Fri Jun 13 02:08:31 2014 UTC (6 years, 7 months ago) by cchiw
File size: 8196 byte(s)
Fri Jun 13 02:08:31 2014 UTC (6 years, 7 months ago) by cchiw
File size: 8196 byte(s)
represent tensors with arrays
(* general function for scalars
 *
 * Generates LowIL code for scalar-valued Ein expressions.  Code generators in
 * this structure return a pair (v, code) where `code` is a list of LowIL
 * assignments and `v` is the LowIL variable holding the result.
 *)
structure step2 = struct
  local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure E = Ein
    structure S3 = step3
    structure genKrn = genKrn
    structure tS = toStringEin
  in

  (* Environments are functions from a key to an optional value.
   * `insert (k, v) d` returns a new environment that maps k to v and
   * defers every other key to d; `lookup k d` queries d; `empty` maps
   * every key to NONE.
   *)
    fun insert (key, value) d = fn s => if s = key then SOME value else d s
    fun lookup k d = d k
    val empty = fn key => NONE

  (* Error helpers: `err` ignores its argument and raises a fixed Fail;
   * `errS` raises Fail with the supplied message.
   *)
    fun err _ = raise Fail("Invalid Field Here")
    fun errS str = raise Fail(str)

  (*Helpers for scalars*)

  (* mkCons (shape, rest): allocate a fresh variable of type TensorTy shape
   * and assign it the CONS of the variables in `rest`.
   * Returns (newVar, [assignment]).
   * (A leftover debug `print` of every generated Cons has been removed.)
   *)
    fun mkCons (shape, rest) = let
          val ty = DstTy.TensorTy shape
          val a = DstIL.Var.new("Cons", ty)
          val code = DstIL.ASSGN(a, DstIL.CONS(ty, rest))
          in
            (a, [code])
          end

  (* The scalar type: a tensor with an empty shape. *)
    val Sca = DstTy.TensorTy([])

  (* Thin wrappers that pass a scalar operator and the scalar type to the
   * step3 code generators.  NOTE(review): the exact semantics of S3.aaV and
   * S3.mkMultiple live in step3 and are assumed here, not verified.
   *)
    fun mkProdSca rest = S3.aaV(DstOp.prodSca, rest, "prodSca", Sca)
    fun mkSubSca rest = S3.aaV(DstOp.subSca, rest, "subSca", Sca)
    fun mkDivSca rest = S3.aaV(DstOp.divSca, rest, "divSca", Sca)
    fun mkMultipleSca (ids, rator) = S3.mkMultiple(ids, rator, Sca)
    fun mkInt n = S3.mkInt n

  (* prodIter (origIndex, index, nextfn, args)
   *
   * Builds a tensor value by enumerating every index position.
   *   - `index` lists the size of each dimension (each position k counts
   *     from size-1 down to 0).
   *   - At nesting depth n the environment maps n to the current position
   *     in that dimension; `nextfn (env, args)` generates the code for one
   *     scalar component.
   *   - Components are gathered with CONS, innermost dimension first, using
   *     `origIndex` (and its suffixes) as the shape of each CONS.  Because
   *     positions are counted down and prepended, components end up in
   *     ascending index order.
   * An empty `index` yields a single call to nextfn with the empty environment.
   * Raises Fail if a dimension size is not positive (such sizes previously
   * made the iteration recurse forever).
   *)
    fun prodIter (origIndex, index, nextfn, args) = let
          val index' = List.map (fn e => e - 1) index
        (* bind depth n to position m and generate that component *)
          fun get (n, m, mapp) = let
                val mapp = insert(n, m) mapp
                in
                  nextfn(mapp, args)
                end
        (* Iter (env, remaining positions, collected vars, code, shape, depth) *)
          fun Iter (mapp, [], rest, code, shape, _) = let
                val (vF, code') = nextfn(mapp, args)
                in
                  (vF, code' @ code)
                end
            | Iter (mapp, [0], rest, code, shape, n) = let
              (* last component of the innermost dimension: emit the CONS *)
                val (vF, code') = get(n, 0, mapp)
                val (vE, E) = mkCons(shape, [vF] @ rest)
                in
                  (vE, code' @ code @ E)
                end
            | Iter (mapp, [c], rest, code, shape, n) = let
                val (vE, E) = get(n, c, mapp)
                in
                  Iter(mapp, [c-1], [vE] @ rest, E @ code, shape, n)
                end
            | Iter (mapp, b::c, rest, ccode, s::shape, n) = let
              (* outer dimension: build one sub-tensor per position i = b..0 *)
                val n' = n + 1
                fun S (0, rest, code) = let
                      val mapp = insert(n, 0) mapp
                      val (v', code') = Iter(mapp, c, [], [], shape, n')
                      val (vA, A) = mkCons(s::shape, [v'] @ rest)
                      in
                        (vA, code' @ code @ A)
                      end
                  | S (i, rest, code) = let
                      val mapp = insert(n, i) mapp
                      val (v', code') = Iter(mapp, c, [], [], shape, n')
                      in
                        S(i-1, [v'] @ rest, code' @ code)
                      end
                val (vA, code') = S(b, [], [])
                in
                  (vA, code' @ ccode)
                end
            | Iter _ = raise Fail "index' is larger than origIndex"
          in
            if List.exists (fn d => d < 1) index
              then raise Fail "prodIter: dimension sizes must be positive"
              else Iter(empty, index', [], [], origIndex, 0)
          end

  (*Get constant *)
  (* skeleton code: classify a code list that consists of exactly one
   * assignment of the integer literal 0, 1 or ~1, returning that integer.
   * Any other code list yields the sentinel 9 ("not a known constant").
   *)
    fun skeleton A = (case A
           of [DstIL.ASSGN(_, DstIL.LIT(Literal.Int 0))] => 0
            | [DstIL.ASSGN(_, DstIL.LIT(Literal.Int 1))] => 1
            | [DstIL.ASSGN(_, DstIL.LIT(Literal.Int ~1))] => ~1
            | _ => 9
          (*end case*))

  (*Helper Functions for General functions*)

  (* findIX (v, env): the value bound to index v, or Fail "Outside Bound:<v>". *)
    fun findIX (v, mapp) = (case lookup v mapp
           of NONE => errS("Outside Bound:" ^ Int.toString(v))
            | SOME s => s
          (*end case*))

  (* NegCheckO (vA, A): negate a generated scalar, folding the literal
   * cases 0 -> 0, ~1 -> 1, 1 -> ~1; otherwise emit (~1) * vA.
   *)
    fun NegCheckO (vA, A) = (case skeleton A
           of 0 => mkInt 0
            | ~1 => mkInt 1
            | 1 => mkInt ~1
            | _ => let
                val (vB, B) = mkInt ~1
                val (vD, D) = mkProdSca [vB, vA]
                in
                  (vD, A @ B @ D)
                end
          (*end case*))

  (* SubcheckO ((vA, A), (vB, B)): A - B with zero folding:
   * 0 - 0 = 0, 0 - B = (~1) * B, A - 0 = A; otherwise emit a subtraction.
   *)
    fun SubcheckO ((vA, A), (vB, B)) = (case (skeleton A, skeleton B)
           of (0, 0) => mkInt 0
            | (0, _) => let
                val (vD, D) = mkInt ~1
                val (vE, E) = mkProdSca [vD, vB]
                in
                  (vE, B @ D @ E)
                end
            | (_, 0) => (vA, A)
            | _ => let
                val (vD, D) = mkSubSca [vA, vB]
                in
                  (vD, A @ B @ D)
                end
          (*end case*))

  (* val info=(params,args)*)
  (* general expressions-removes zeros
   *
   * generalfn (dict, (body, origargs, info)): generate code for the scalar
   * Ein expression `body` under the index environment `dict`.
   *   - Field-level constructors (Field, Partial, Apply, Probe, Conv, Krn,
   *     Img, Lift) raise Fail("Invalid Field Here").
   *   - Add/Prod/Sub/Neg/Div drop or fold terms that are literal 0 or 1
   *     (as recognised by `skeleton`).
   *   - A Sum whose body is Prod(Img :: Krn :: _) is delegated to
   *     genKrn.evalField; other Sums are fully unrolled over each summation
   *     index from its lower to its upper bound.
   * NOTE(review): summation rebinds the shared `mapp` ref and does not
   * restore it afterwards, so expressions generated later can still see the
   * last summation bindings.
   *)
    fun generalfn (dict, (body, origargs, info)) = let
          val mapp = ref dict
          fun gen body = let
              (* collect the non-zero terms of a sum; all-zero => literal 0 *)
                fun AddcheckO ([], [], []) = let
                      val (vA, A) = mkInt 0
                      in
                        ([vA], A)
                      end
                  | AddcheckO ([], ids, code) = (ids, code)
                  | AddcheckO (e1::es, ids, code) = let
                      val (a, b) = gen e1
                      in
                        (case skeleton b
                          of 0 => AddcheckO(es, ids, code)
                           | _ => AddcheckO(es, ids @ [a], code @ b)
                        (*end case*))
                      end
              (* collect the non-one factors of a product; a zero factor
               * short-circuits to that zero, all-one => literal 1
               *)
                fun ProdcheckO ([], [], []) = let
                      val (vA, A) = mkInt 1
                      in
                        ([vA], A)
                      end
                  | ProdcheckO ([], ids, code) = (ids, code)
                  | ProdcheckO (e1::es, ids, code) = let
                      val (a, b) = gen e1
                      in
                        (case skeleton b
                          of 0 => ([a], b)
                           | 1 => ProdcheckO(es, ids, code)
                           | _ => ProdcheckO(es, ids @ [a], code @ b)
                        (*end case*))
                      end
              (* unroll a summation; returns (term vars, code) with zero terms'
               * vars omitted
               *)
                fun Sumcheck (sumx, e) = let
                    (* generate one term under the environment `mapsum` *)
                      fun sumloop mapsum = let
                            val _ = mapp := mapsum
                            val (vA, A) = gen e
                            in
                              (case skeleton A
                                of 0 => ([], A)
                                 | _ => ([vA], A)
                              (*end case*))
                            end
                    (* sumI1 (env, (index, offset, lowerBound), remaining
                     * summation indices, collected vars, collected code);
                     * offsets count down to 0 and results are prepended.
                     *)
                      fun sumI1 (left, (v, 0, lb1), [], rest, code) = let
                            val dict = insert(v, lb1) left
                            val (vD, pre) = sumloop dict
                            in
                              (vD @ rest, pre @ code)
                            end
                        | sumI1 (left, (v, i, lb1), [], rest, code) = let
                            val dict = insert(v, i + lb1) left
                            val (vD, pre) = sumloop dict
                            in
                              sumI1(dict, (v, i-1, lb1), [], vD @ rest, pre @ code)
                            end
                        | sumI1 (left, (v, 0, lb1), (E.V a, lb2, ub)::sx, rest, code) = let
                            val dict = insert(v, lb1) left
                            in
                              sumI1(dict, (a, ub-lb2, lb2), sx, rest, code)
                            end
                        | sumI1 (left, (v, s, lb1), (E.V a, lb2, ub)::sx, rest, code) = let
                            val dict = insert(v, s + lb1) left
                            val (rest', code') = sumI1(dict, (a, ub-lb2, lb2), sx, rest, code)
                            in
                              sumI1(dict, (v, s-1, lb1), (E.V a, lb2, ub)::sx, rest', code')
                            end
                        | sumI1 _ = raise Fail "None Variable-index in summation"
                      val (E.V v, lb, ub) = hd sumx
                      in
                        sumI1(!mapp, (v, ub-lb, lb), tl sumx, [], [])
                      end
              (* combine collected term vars with `rator`; for addSca an empty
               * list yields literal 0 and a singleton is returned unchanged
               *)
                fun iterList (e, DstOp.addSca) = (case e
                       of ([], code) => let
                            val (vA, A) = mkInt 0
                            in
                              (vA, A)
                            end
                        | ([id1], code) => (id1, code)
                        | (ids, code) => let
                            val (vB, B) = mkMultipleSca(ids, DstOp.addSca)
                            in
                              (vB, code @ B)
                            end
                      (*end case*))
                  | iterList (e, rator) = (case e
                       of ([id1], code) => (id1, code)
                        | (ids, code) => let
                            val (vB, B) = mkMultipleSca(ids, rator)
                            in
                              (vB, code @ B)
                            end
                      (*end case*))
                in
                  (case body
                    of E.Field _ => err 1
                     | E.Partial _ => err 1
                     | E.Apply _ => err 1
                     | E.Probe _ => err 1
                     | E.Conv _ => err 1
                     | E.Krn _ => err 1
                     | E.Img _ => err 1
                     | E.Lift _ => err 1
                     | E.Value v => mkInt(findIX(v, !mapp))
                     | E.Const c => mkInt c
                     | E.Epsilon(i, j, k) => S3.evalEps(!mapp, i, j, k)
                     | E.Delta(i, j) => S3.evalDelta2(!mapp, i, j)
                     | E.Tensor(id, ix) => S3.mkSca(!mapp, (id, ix, info))
                     | E.Neg e => NegCheckO(gen e)
                     | E.Sub(e1, e2) => SubcheckO(gen e1, gen e2)
                     | E.Div(e1, e2) => let
                       (* a zero numerator folds to 0 without generating e2 *)
                         val (vA, A) = gen e1
                         in
                           (case skeleton A
                             of 0 => mkInt 0
                              | _ => let
                                  val (vB, B) = gen e2
                                  val (vD, D) = mkDivSca [vA, vB]
                                  in
                                    (vD, A @ B @ D)
                                  end
                           (*end case*))
                         end
                     | E.Add e => iterList(AddcheckO(e, [], []), DstOp.addSca)
                     | E.Prod e => iterList(ProdcheckO(e, [], []), DstOp.prodSca)
                     | E.Sum(sx, E.Prod(E.Img(Vid, _, _)::E.Krn(Hid, del, pos)::es)) => let
                       (* image-times-kernel sum: hand the whole sum to genKrn *)
                         val harg = List.nth(origargs, Hid)
                         val imgarg = List.nth(origargs, Vid)
                         val h = S3.getKernel harg
                         val v = S3.getImage imgarg
                         in
                           genKrn.evalField(!mapp, (body, v, h, info))
                         end
                     | E.Sum(sumx, e) => iterList(Sumcheck(sumx, e), DstOp.addSca)
                  (*end case*))
                end
          in
            gen body
          end

  end (* local *)
end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |