SCM Repository
View of /branches/charisee/src/compiler/mid-to-low/step3.sml
Parent Directory
|
Revision Log
Revision 2646 -
(download)
(annotate)
Thu May 29 15:52:23 2014 UTC (6 years, 7 months ago) by cchiw
File size: 6943 byte(s)
Thu May 29 15:52:23 2014 UTC (6 years, 7 months ago) by cchiw
File size: 6943 byte(s)
opr to clang
(* Helper functions for gen-ein (mid-to-low translation, step 3).
 *
 * Most generators below emit LowIL assignments and return a pair
 * (resultVar, code), where `code` is the list of LowIL assignments that
 * defines `resultVar`.  Index maps (`mapp`) are functions from index-variable
 * ids to `int option`; they are built with `insert` and queried with `lookup`.
 *)
structure step3 = struct
  local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure SrcIL = MidIL
    structure SrcOp = MidOps
    structure E = Ein
    structure tS = toStringEin
  in

    (* Debug switch: when nonzero, aaV and mkInt print every assignment they build. *)
    val testing = 1
    val bV = ref 0

    fun err str = raise Fail(str)

    (* LowIL scalar type: a tensor with empty shape. *)
    val Sca = DstTy.TensorTy []

    (* Apply index map d to key k. *)
    fun lookup k d = d k

    fun q e1 = Int.toString e1

    (* Extend index map d so that `key` maps to SOME value.  Note: the binding is
     * printed (without a trailing newline) each time `key` is looked up, not
     * when insert is called.
     *)
    fun insert (key, value) d = fn s =>
          if s = key
            then (print(String.concat[Int.toString(key),"=>",Int.toString(value)]); SOME value)
            else d s

    (*Get kernel and Image bindings*)

    (* Return the kernel h when MidIL variable x is bound to a Kernel op;
     * otherwise raise Fail describing the actual binding.
     *)
    fun getKernel x = (case SrcIL.Var.binding x
           of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _), _)) => h
            | vb => (err (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x,
                  " found ", SrcIL.vbToString vb, "\n"]))
          (* end case *))

    (* Return the image when MidIL variable x is bound to a LoadImage op;
     * otherwise raise Fail describing the actual binding.
     *)
    fun getImage x = (case SrcIL.Var.binding x
           of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage(img), _)) => img
            | vb => (err (String.concat["\n -- Not an image, ", SrcIL.Var.toString x,
                  " found ", SrcIL.vbToString vb, "\n"]))
          (* end case *))

    (* Make assignment: create a fresh LowIL variable (name prefix `pre`, type
     * `ty`) bound to `opss` applied to `args`.  Returns (newVar, [assignment]).
     *)
    fun aaV (opss, args, pre, ty) = let
          val a = DstIL.Var.new(pre, ty)
          val code = DstIL.ASSGN(a, DstIL.OP(opss, args))
          val _ = (case testing
                 of 0 => 1
                  | _ => (print(tS.toStringAll(ty, code)); 1)
                (* end case *))
          in
            (a, [code])
          end

    (* Bind a fresh scalar variable to the integer literal n.  Returns (var, [assignment]). *)
    fun mkInt n = let
          val a = DstIL.Var.new("Int", Sca)
          val code = DstIL.ASSGN(a, DstIL.LIT(Literal.Int(IntInf.fromInt n)))
          val _ = (case testing
                 of 0 => 1
                  | _ => (print(tS.toStringAll(Sca, code)); 1)
                (* end case *))
          in
            (a, [code])
          end

    (* Fold the binary operator `rator` left-to-right over the variables in
     * list1, emitting one assignment of type `ty` per application.  A
     * one-element list is returned as-is with no code; an empty list raises Fail.
     *)
    fun mkMultiple (list1, rator, ty) = let
          fun add (e, code) = (case e
                 of [] => err "no element in mkMultiple"
                    (* return the accumulated code rather than dropping it *)
                  | [e1] => (e1, code)
                  | [e1, e2] => let
                      val (vA, A) = aaV(rator, [e1, e2], "MO", ty)
                      in
                        (vA, code @ A)
                      end
                  | (e1 :: e2 :: es) => let
                      val (vA, A) = aaV(rator, [e1, e2], "MO", ty)
                      in
                        add(vA :: es, code @ A)
                      end
                (* end case *))
          in
            add(list1, [])
          end

    (* Resolve an Ein index to an int: a variable index is looked up in mapp
     * (Fail "Outside Bound:" if unbound), a constant index is returned as-is.
     *)
    fun mapIndex (e1, mapp) = (case e1
           of E.V e => (case lookup e mapp
                 of NONE => err("Outside Bound:" ^ Int.toString(e))
                  | SOME s => s
                (* end case *))
            | E.C c => c
          (* end case *))

    (* Integer, or Generic Tensor: LowIL type of parameter `id`.
     * NOTE(review): E.TEN(3, [shape]) is mapped to an integer vector type;
     * the meaning of the tag 3 is not visible here -- see FIX HERE.
     *)
    fun getTensorTy (params, id) = (case List.nth(params, id)
           of E.TEN(3, [shape]) => DstTy.iVecTy(shape) (*FIX HERE*)
            | E.TEN(_, shape) => DstTy.TensorTy shape
            | _ => err "NONE Tensor Param"
          (* end case *))

    (* Scalar access to argument `id` at index list ix (resolved via mapp).
     * With no indices the argument variable itself is returned with no code;
     * otherwise an IndexTensor op producing a scalar is emitted.
     *)
    fun mkSca (mapp, (id, [], (params, args))) = let
          val nU = List.nth(args, id)
          in
            (nU, [])
          end
      | mkSca (mapp, (id, ix, (params, args))) = let
          val nU = List.nth(args, id)
          val ixx = List.map (fn e1 => mapIndex(e1, mapp)) ix
          val ix' = DstTy.indexTy ixx
          val argTy = getTensorTy(params, id)
          val opp = DstOp.IndexTensor(id, Sca, ix', argTy)
          in
            aaV(opp, [nU], "S" ^ Int.toString(id), Sca)
          end

    (* Evaluate the Levi-Civita symbol for index variables a, b, c (resolved via
     * mapp) and emit it as an integer literal: 0 if any two indices are equal,
     * 1 if (i, j, k) is an even permutation of its sorted order, ~1 if odd.
     *)
    fun evalEps (mapp, a, b, c) = let
          val i = mapIndex(E.V a, mapp)
          val j = mapIndex(E.V b, mapp)
          val k = mapIndex(E.V c, mapp)
          in
            if (i = j orelse j = k orelse i = k)
              then mkInt 0
            else if (j > i)
              then if (j > k andalso k > i) then mkInt ~1 else mkInt 1
            else if (i > k andalso k > j)
              then mkInt 1
              else mkInt ~1
          end

    (* Evaluate the Kronecker delta of Ein indices a and b, emitted as an int literal. *)
    fun evalDelta2 (mapp, a, b) = let
          val i = mapIndex(a, mapp)
          val j = mapIndex(b, mapp)
          in
            if (i = j) then mkInt 1 else mkInt 0
          end

    (* For each (i, j) pair in dels, compute the Kronecker delta of the resolved
     * indices, and return the SUM of those values (i.e. the number of pairs
     * whose indices coincide).  No code is emitted.
     * NOTE(review): a product of deltas would be the usual tensor meaning;
     * confirm callers rely on a sum.
     *)
    fun evalDels (mapp, dels) = let
          fun m (a, b) = if (a = b) then 1 else 0
          fun ij (i, j) = (case (i, j)
                 of (E.V _, E.V _) => m(mapIndex(i, mapp), mapIndex(j, mapp))
                  | (E.C a, E.V _) => m(a, mapIndex(j, mapp))
                  | (E.V _, E.C b) => m(mapIndex(i, mapp), b)
                    (* compare the constant ints, consistent with the other arms *)
                  | (E.C a, E.C b) => m(a, b)
                (* end case *))
          val dels' = List.map ij dels
          in
            List.foldl (fn (x, y) => x + y) 0 dels'
          end

    (*--------------------Vectorization Helper Functions--------------------*)

    (* Vector access to argument `id` at index prefix ix (resolved via mapp),
     * producing a value of type TensorTy [vecIX].  With no indices the argument
     * variable itself is returned with no code.
     *)
    fun mkVec (mapp, (id, [], vecIX, (params, args))) = let
          val nU = List.nth(args, id)
          in
            (nU, [])
          end
      | mkVec (mapp, (id, ix, vecIX, (params, args))) = let
          val nU = List.nth(args, id)
          val ix' = DstTy.indexTy(List.map (fn e1 => mapIndex(e1, mapp)) ix)
          val argTy = getTensorTy(params, id)
          val vecTy = DstTy.TensorTy [vecIX]
          val opp = DstOp.IndexTensor(id, vecTy, ix', argTy)
          in
            aaV(opp, [nU], "V" ^ Int.toString(id), vecTy)
          end

    (* Product of -1 and 1 projection: scale the vector read from (id, ix) by
     * the existing scalar variable vA using prodScaV.
     *)
    fun mkNegV (mapp, ((vA, id, ix), vecIX, info)) = let
          val (vB, B) = mkVec(mapp, (id, ix, vecIX, info))
          val (vD, D) = aaV(DstOp.prodScaV vecIX, [vA, vB], "prodScaV", DstTy.TensorTy [vecIX])
          in
            (vD, B @ D)
          end

    (* Vector subtraction: vec(id1, ix1) - vec(id2, ix2). *)
    fun mksubVec (mapp, (id1, ix1, id2, ix2, vecIX, info)) = let
          val (vA, A) = mkVec(mapp, (id1, ix1, vecIX, info))
          val (vB, B) = mkVec(mapp, (id2, ix2, vecIX, info))
          val (vD, D) = aaV(DstOp.subVec vecIX, [vA, vB], "subVec", DstTy.TensorTy [vecIX])
          in
            (vD, A @ B @ D)
          end

    (* Vector addition: load every (id, ix) term in es as a vector, then sum
     * them pairwise with addVec.
     *)
    fun handleAddVec (mapp, (es, vecIX, info)) = let
          fun add ([], rest, code) = (rest, code)
            | add ((id1, ix1) :: es, rest, code) = let
                val (vA, A) = mkVec(mapp, (id1, ix1, vecIX, info))
                in
                  add(es, rest @ [vA], code @ A)
                end
          val (rest, code) = add(es, [], [])
          val (vA, A) = mkMultiple(rest, DstOp.addVec(vecIX), DstTy.TensorTy([vecIX]))
          in
            (vA, code @ A)
          end

    (* Vector scaling: scalar(id1, ix1) * vec(id2, ix2). *)
    fun mkprodScaV (mapp, (id1, ix1, id2, ix2, vecIX, info)) = let
          val (vA, A) = mkSca(mapp, (id1, ix1, info))
          val (vB, B) = mkVec(mapp, (id2, ix2, vecIX, info))
          val (vD, D) = aaV(DstOp.prodScaV(vecIX), [vA, vB], "prodScaV", DstTy.TensorTy([vecIX]))
          in
            (vD, A @ B @ D)
          end

    (* Vector product: elementwise vec(id1, ix1) * vec(id2, ix2) via prodVec. *)
    fun mkprodVec (mapp, (id1, ix1, id2, ix2, vecIX, info)) = let
          val (vA, A) = mkVec(mapp, (id1, ix1, vecIX, info))
          val (vB, B) = mkVec(mapp, (id2, ix2, vecIX, info))
          val (vD, D) = aaV(DstOp.prodVec(vecIX), [vA, vB], "prodV", DstTy.TensorTy([vecIX]))
          in
            (vD, A @ B @ D)
          end

    (* Sum of vector product: prodVec followed by sumVec, yielding a realTy value. *)
    fun mkprodSumVec (mapp, (id1, ix1, id2, ix2, vecIX, info)) = let
          val (vD, dCode) = mkprodVec(mapp, (id1, ix1, id2, ix2, vecIX, info))
          val (vS, sCode) = aaV(DstOp.sumVec vecIX, [vD], "sumVec", DstTy.realTy)
          in
            (vS, dCode @ sCode)
          end

    (* Dot-product-like summation: for every value of index variable v in
     * lb..ub (inclusive), evaluate mkprodSumVec with v bound to that value,
     * then add all the partial results with addSca.  Raises Fail if the
     * summation index is not a variable or if ub < lb.
     *)
    fun sumDot (mapp, ((E.V v, lb, ub), t)) = let
          (* sx counts down from (ub - lb) to 0; the index value for step sx is sx + lb. *)
          fun sumI (a, 0, rest, code) = let
                (* last step: v = lb (previously bound to 0, wrong when lb <> 0) *)
                val mapp = insert(v, lb) a
                val (vE, eCode) = mkprodSumVec(mapp, t)
                val rest' = vE :: rest
                val (vF, fCode) = mkMultiple(rest', DstOp.addSca, Sca)
                in
                  (vF, eCode @ code @ fCode)
                end
            | sumI (a, sx, rest', code') = let
                val mapp = insert(v, sx + lb) a
                val (vE, eCode) = mkprodSumVec(mapp, t)
                in
                  sumI(a, sx - 1, vE :: rest', eCode @ code')
                end
          in
            (* guard: with ub < lb, sx would start negative and never reach 0 *)
            if ub < lb
              then err "Empty summation range in sumDot"
              else sumI(mapp, ub - lb, [], [])
          end
      | sumDot _ = raise Fail "Non-variable index in summation"

  end
end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |