SCM Repository
View of /branches/charisee/src/compiler/mid-to-low/genH.sml
Parent Directory
|
Revision Log
Revision 2584 -
(download)
(annotate)
Tue Apr 15 03:22:58 2014 UTC (6 years, 10 months ago) by cchiw
File size: 10610 byte(s)
Tue Apr 15 03:22:58 2014 UTC (6 years, 10 months ago) by cchiw
File size: 10610 byte(s)
Multiply Fields
(* hashs Ein Function after substitution *)
(* gHelper: helper functions used while translating Ein-expression bodies
 * into LowIL assignments (DstIL = LowIL).  Most mk* functions return a pair
 * (result variable, list of LowIL assignments that compute it); callers
 * concatenate the assignment lists in order.
 *)
structure gHelper = struct
  local
    structure E = Ein
    (* structure genKrn=genKrn *)
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure SrcIL = MidIL
    structure SrcOp = MidOps
    structure SrcSV = SrcIL.StateVar
    structure SrcTy = MidILTypes
    structure VTbl = SrcIL.Var.Tbl

    (* Tracing switch.  The trace messages below used to be printed
     * unconditionally to stdout during code generation; set this to true
     * to get them back while debugging. *)
    val debug = false
    fun dbg msg = if debug then print msg else ()
  in

    (* ---- Functional finite maps ---- *)

    (* A map is a function from keys to `value option`.  `insert (key, value) d`
     * answers `SOME value` for key and defers to d for every other key, so it
     * shadows any earlier binding of key. *)
    fun insert (key, value) d = fn s => if s = key then SOME value else d s

    (* `lookup k d` queries map d at key k. *)
    fun lookup k d = d k

    (* The map with no bindings. *)
    val empty = fn key => NONE

    (* Returns `SOME x` for the first element of list1 that also occurs in
     * list2 (x is the matching element as found in list2), or NONE when the
     * lists share no element. *)
    fun findDup (list1, list2) = let
          fun current [] = NONE
            | current (v :: vs) =
                (case List.find (fn x => x = v) list2
                  of NONE => current vs
                   | found => found
                (* end case *))
          in
            current list1
          end

    (* Counter cell; not used by any function in this structure. *)
    val bV = ref 0

    (* Debug helper: renders the binding of LowIL variable x as a string. *)
    fun printgetRHS x =
          String.concat ["\n Found ", DstIL.vbToString (DstIL.Var.binding x), "\n"]

    (* Returns the kernel h of a MidIL variable bound to `Kernel (h, _)`;
     * raises Fail for any other binding. *)
    fun getKernel x = (case SrcIL.Var.binding x
           of SrcIL.VB_RHS (SrcIL.OP (SrcOp.Kernel (h, _), _)) => h
            | vb => raise Fail (String.concat [
                  "\n -- Not a kernel, ", SrcIL.Var.toString x,
                  " found ", SrcIL.vbToString vb, "\n"
                ])
          (* end case *))

    (* Returns the image of a MidIL variable bound to `LoadImage img`;
     * raises Fail for any other binding. *)
    fun getImage x = (case SrcIL.Var.binding x
           of SrcIL.VB_RHS (SrcIL.OP (SrcOp.LoadImage img, _)) => img
            | vb => raise Fail (String.concat [
                  "\n -- Not an image, ", SrcIL.Var.toString x,
                  " found ", SrcIL.vbToString vb, "\n"
                ])
          (* end case *))

    (* Debug helper: prints a LowIL assignment.  The final catch-all clause
     * labels every remaining right-hand-side form as "CONS". *)
    fun printX (DstIL.ASSGN (x, DstIL.OP (opss, args))) = (
          print (String.concat [Var.toString x, "==", DstOp.toString opss, " : "]);
          print (String.concatWith "," (List.map Var.toString args)))
      | printX (DstIL.ASSGN (x, DstIL.LIT _)) =
          print (String.concat [Var.toString x, "==...Lit"])
      | printX (DstIL.ASSGN (x, DstIL.CONS (_, varl))) =
          print (String.concat [
              Var.toString x, "==", String.concatWith "," (List.map Var.toString varl)
            ])
      | printX (DstIL.ASSGN (x, _)) =
          print (String.concat [Var.toString x, "==", "CONS", printgetRHS x])

    (* Debug helper: short textual rendering of a LowIL type. *)
    fun printTy (DstTy.IntTy) = "int "
      | printTy (DstTy.TensorTy []) = "Real "
      | printTy (DstTy.TensorTy dd) =
          String.concat ["TensorTy[", String.concatWith "," (List.map Int.toString dd), "] "]

    (* Creates a fresh LowIL variable with DstIL.Var.new (pre, ty) and the
     * single assignment `var = opss(args)`.  Returns (var, [assignment]). *)
    fun aaV (opss, args, pre, ty) = let
          val a = DstIL.Var.new (pre, ty)
          in
            (a, [DstIL.ASSGN (a, DstIL.OP (opss, args))])
          end

    (* Combines a nonempty list of variables with the binary operator rator,
     * left to right: [x1, x2, x3] becomes (x1 rator x2) rator x3.  Each
     * intermediate result has type ty and name prefix "MO".  A singleton list
     * yields its element and no code; an empty list raises Fail. *)
    fun mkMultiple (list1, rator, ty) = let
          fun add ([], _) = raise Fail "no element in addM"
            | add ([e], code) = (e, code)
            | add (e1 :: e2 :: es, code) = let
                val (vA, A) = aaV (rator, [e1, e2], "MO", ty)
                in
                  add (vA :: es, code @ A)
                end
          in
            add (list1, [])
          end

    (* Resolves an Ein index to an integer: a constant index E.C c is c itself;
     * a variable index E.V e is looked up in mapp.  Raises Fail "Outside Bound"
     * when e is unbound. *)
    fun mapIndex (e1, mapp) = (case e1
           of E.V e => (case lookup e mapp
                 of NONE => raise Fail "Outside Bound"
                  | SOME s => s
                (* end case *))
            | E.C c => c
          (* end case *))

    (* Debug helper: prints the bindings of indices n, n+1, ... in mapp until
     * the first unbound index, then prints "-". *)
    fun printIndexXX (n, mapp) = (case lookup n mapp
           of NONE => print "-\n"
            | SOME s => (
                print (String.concat [Int.toString n, "==>", Int.toString s]);
                printIndexXX (n + 1, mapp))
          (* end case *))

    (* LowIL type of tensor parameter number id in params.  A parameter of the
     * form E.TEN (3, [shape]) maps to DstTy.iVecTy shape; any other E.TEN maps
     * to DstTy.TensorTy shape.  Raises Fail for a non-tensor parameter.
     * NOTE(review): the special case on the literal 3 was marked "FIX HERE"
     * by the original author; its meaning is not evident from this file. *)
    fun getShape (params, id) = (case List.nth (params, id)
           of E.TEN (3, [shape]) => DstTy.iVecTy shape
            | E.TEN (_, shape) => DstTy.TensorTy shape
            | _ => raise Fail "NONE Tensor Param"
          (* end case *))

    (* Emits a scalar projection S of argument id at the index ix1, with each
     * index resolved through mapp.  args and params are the argument variables
     * and their Ein parameter descriptions, both indexed by id. *)
    fun mkSca (mapp, (id, ix1, (args, params))) = let
          val ix1' = List.map (fn e1 => mapIndex (e1, mapp)) ix1
          val nU = List.nth (args, id)
          val i = DstTy.indexTy ix1'
          val a = getShape (params, id)
          in
            aaV (DstOp.S (id, i, a), [nU], "S" ^ Int.toString id, DstTy.TensorTy [])
          end

    (* Emits a vector projection V of argument id of length last at the index
     * ix1 (resolved through mapp); the result has type TensorTy [last]. *)
    fun mkVec (mapp, (id, ix1, last, (args, params))) = let
          val () = if debug then printIndexXX (0, mapp) else ()
          val ix1' = List.map (fn e1 => mapIndex (e1, mapp)) ix1
          val nU = List.nth (args, id)
          val i = DstTy.indexTy ix1'
          val a = getShape (params, id)
          in
            aaV (DstOp.V (id, last, i, a), [nU], "V" ^ Int.toString id, DstTy.TensorTy [last])
          end

    (* Helper functions for addition *)

    (* Sums vector projections: builds one projection per (id, ix) in es, in
     * order, and adds them together with addVec(last).  The second component
     * of the tuple argument is not used. *)
    fun handleAddVec (mapp, (es, _, last, args)) = let
          val () = dbg "made it to handleAdd vec"
          fun add ([], rest, code) = (rest, code)
            | add ((id1, ix1) :: es', rest, code) = let
                val (vA, A) = mkVec (mapp, (id1, ix1, last, args))
                in
                  add (es', rest @ [vA], code @ A)
                end
          val (rest, code) = add (es, [], [])
          val (vA, A) = mkMultiple (rest, DstOp.addVec last, DstTy.TensorTy [last])
          in
            (vA, code @ A)
          end

    (* Subtract scalars: S(id1, ix1) - S(id2, ix2). *)
    fun mksubSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)) = let
          val (vA, A) = mkSca (mapp, (id1, ix1, args))
          val (vB, B) = mkSca (mapp, (id2, ix2, args))
          val (vD, D) = aaV (DstOp.subSca, [vA, vB], "SubSca", DstTy.TensorTy [])
          in
            (vD, A @ B @ D)
          end

    (* Subtract vectors: V(id1, ix1) - V(id2, ix2), each of length last. *)
    fun mksubVec (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)) = let
          val () = if debug then printIndexXX (0, mapp) else ()
          val (vA, A) = mkVec (mapp, (id1, ix1, last, args))
          val (vB, B) = mkVec (mapp, (id2, ix2, last, args))
          val (vD, D) = aaV (DstOp.subVec last, [vA, vB], "subVec", DstTy.TensorTy [last])
          in
            (vD, A @ B @ D)
          end

    (* Product functions *)

    (* Product of 2 scalars; raises Fail for any other argument shape. *)
    fun mkprodSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)) = let
          val (vA, A) = mkSca (mapp, (id1, ix1, args))
          val (vB, B) = mkSca (mapp, (id2, ix2, args))
          val (vD, D) = aaV (DstOp.prodSca, [vA, vB], "prodSca", DstTy.TensorTy [])
          in
            (vD, A @ B @ D)
          end
      | mkprodSca _ = raise Fail "Prod----d---"

    (* Product of 1 scalar projection (id1) and 1 vector projection (id2). *)
    fun mkprodScaV (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)) = let
          val () = if debug then printIndexXX (0, mapp) else ()
          val (vA, A) = mkSca (mapp, (id1, ix1, args))
          val (vB, B) = mkVec (mapp, (id2, ix2, last, args))
          val () = dbg (String.concat ["Puppy-In prodScaV", Int.toString last])
          val (vD, D) = aaV (DstOp.prodScaV last, [vA, vB], "prodScaV", DstTy.TensorTy [last])
          in
            (vD, A @ B @ D)
          end

    (* Product of 2 vector projections, each of length last. *)
    fun mkprodVec (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)) = let
          val () = dbg "\n mkprodVec"
          val (vA, A) = mkVec (mapp, (id1, ix1, last, args))
          val (vB, B) = mkVec (mapp, (id2, ix2, last, args))
          val (vD, D) = aaV (DstOp.prodVec last, [vA, vB], "prodV", DstTy.TensorTy [last])
          in
            (vD, A @ B @ D)
          end

    (* Summation over the product of 2 vector projections: forms the product
     * with vector length i+1 and reduces it with sumVec(i+1).
     * NOTE(review): the original author flagged this function with
     * "error here"; the intended use of i+1 as the length should be confirmed. *)
    fun mkprodSumVec (mapp, (m, [], i, args)) = let
          val () = dbg "\n In prod sum vec"
          val i' = i + 1
          val (vD, D) = mkprodVec (mapp, (m, [], i', args))
          val (vE, sumCode) = aaV (DstOp.sumVec i', [vD], "sumVec", DstTy.realTy)
          in
            (vE, D @ sumCode)
          end

    (* Product of -1 and 1 projection: multiplies the existing variable vA
     * (presumably holding -1 -- confirm with callers) by the vector projection
     * V(id, ix) using prodScaV(last). *)
    fun mkNegV (mapp, ((vA, id, ix), [], last, args)) = let
          val () = dbg "\n pre mkVec"
          val (vB, B) = mkVec (mapp, (id, ix, last, args))
          val () = dbg "\n post mkVec"
          val (vD, D) = aaV (DstOp.prodScaV last, [vA, vB], "prodScaV", DstTy.TensorTy [last])
          in
            (vD, B @ D)
          end

    (* Dot-product-like summation (Vec x Vec).  sx must hold exactly one
     * summation range (E.V v, lb, ub).  For each value of v in lb..ub the
     * index map a is extended with v, the product of two vector projections
     * is formed (mkprodVec) and reduced with sumVec(last); the per-index
     * results are then added with addSca. *)
    fun sumDot (a, (m, sx, last, args)) = let
          val [(E.V v, lb, ub)] = sx
          (* One term of the sum, with v bound to value. *)
          fun term value = let
                val mapp = insert (v, value) a
                val (vD, pre) = mkprodVec (mapp, (m, [], last, args))
                val (vE, sumCode) = aaV (DstOp.sumVec last, [vD], "SumVec", DstTy.TensorTy [])
                in
                  (vE, pre @ sumCode)
                end
          (* k counts down from ub-lb to 0; the index value for step k is lb+k. *)
          fun sumI (0, rest, code) = let
                (* The final term is v = lb.  (Previously this bound v to 0,
                 * which produced the wrong term whenever lb <> 0.) *)
                val (vE, termCode) = term lb
                val (vF, F) = mkMultiple (vE :: rest, DstOp.addSca, DstTy.TensorTy [])
                in
                  (vF, termCode @ code @ F)
                end
            | sumI (k, rest, code) = let
                val (vE, termCode) = term (k + lb)
                in
                  sumI (k - 1, vE :: rest, termCode @ code)
                end
          in
            sumI (ub - lb, [], [])
          end

    (* Can do multiple summations: sx is a list of ranges (E.V v, lb, ub).
     * Enumerates every combination of index values in those ranges, emits a
     * scalar product (mkprodSca) for each, and adds them all with addSca.
     * NOTE(review): the parameter a is ignored -- enumeration starts from the
     * empty map, so indices bound by the caller are not visible here (sumDot,
     * by contrast, extends a).  Confirm that this is intended. *)
    fun sum (a, (m, sx, args)) = let
          val () = dbg "\n IN SUM"
          fun sumI1 (left, (v, 0, lb1), [], rest, code) = let
                val mapp = insert (v, lb1) left
                val (vD, pre) = mkprodSca (mapp, (m, [], args))
                in
                  (vD :: rest, pre @ code)
                end
            | sumI1 (left, (v, i, lb1), [], rest, code) = let
                val mapp = insert (v, i + lb1) left
                val (vD, pre) = mkprodSca (mapp, (m, [], args))
                in
                  sumI1 (left, (v, i - 1, lb1), [], vD :: rest, pre @ code)
                end
            | sumI1 (left, (v, 0, lb1), (E.V w, lb2, ub) :: sx', rest, code) = let
                val mapp = insert (v, lb1) left
                in
                  sumI1 (mapp, (w, ub - lb2, lb2), sx', rest, code)
                end
            | sumI1 (left, (v, s, lb1), (E.V w, lb2, ub) :: sx', rest, code) = let
                val mapp = insert (v, s + lb1) left
                val (rest', code') = sumI1 (mapp, (w, ub - lb2, lb2), sx', rest, code)
                in
                  sumI1 (left, (v, s - 1, lb1), (E.V w, lb2, ub) :: sx', rest', code')
                end
          val (E.V v, lb, ub) = hd sx
          val (li, code) = sumI1 (empty, (v, ub - lb, lb), tl sx, [], [])
          val (vF, F) = mkMultiple (li, DstOp.addSca, DstTy.TensorTy [])
          in
            (vF, code @ F)
          end

    (* Emits the scalar constant C(n); returns (var, [assignment]). *)
    fun mkC n = let
          val (vB, B) = aaV (DstOp.C n, [], "Const", DstTy.TensorTy [])
          val () = dbg "postmkC"
          in
            (vB, B)
          end

    (* Kronecker delta of indices a and b under mapp, emitted as mkC 1 or mkC 0. *)
    fun evalDelta2 (a, b, mapp) =
          if mapIndex (a, mapp) = mapIndex (b, mapp) then mkC 1 else mkC 0

    (* Field/Kern *)
    (* Evaluates each (i, j) pair in dels as a Kronecker delta (1 when the
     * resolved indices are equal, 0 otherwise) and returns the SUM of the
     * results.
     * NOTE(review): a product of deltas would normally multiply; confirm that
     * summing is intended. *)
    fun evalDelta (dels, mapp) = let
          fun delta (i, j) = if mapIndex (i, mapp) = mapIndex (j, mapp) then 1 else 0
          in
            List.foldl (fn (x, acc) => x + acc) 0 (List.map delta dels)
          end

    (* Levi-Civita symbol for index variables a, b, c resolved through mapp:
     * 0 if any two indices coincide, 1 when (i, j, k) is an even permutation
     * of its sorted order, ~1 when it is an odd one. *)
    fun evalEps (a, b, c, mapp) = let
          val i = mapIndex (E.V a, mapp)
          val j = mapIndex (E.V b, mapp)
          val k = mapIndex (E.V c, mapp)
          in
            if i = j orelse j = k orelse i = k then 0
            else if j > i
              then (if j > k andalso k > i then ~1 else 1)
              else (if i > k andalso k > j then 1 else ~1)
          end

    (* Classifies a code list consisting of one constant assignment: returns
     * 0, 1 or ~1 for C 0, C 1 or C ~1, and 9 for anything else (presumably
     * a "not a simple constant" sentinel). *)
    fun skeleton A = (case A
           of [DstIL.ASSGN (_, DstIL.OP (DstOp.C 0, _))] => 0
            | [DstIL.ASSGN (_, DstIL.OP (DstOp.C 1, _))] => 1
            | [DstIL.ASSGN (_, DstIL.OP (DstOp.C ~1, _))] => ~1
            | _ => 9
          (* end case *))

  end (* local *)
end (* gHelper *)
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |