Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/gen-helpers.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/gen-helpers.sml

Parent Directory Parent Directory | Revision Log Revision Log

Revision 2525 - (download) (annotate)
Tue Jan 21 19:14:22 2014 UTC (5 years, 7 months ago) by cchiw
File size: 8994 byte(s)
(* Helper functions that generate LowIL code for an Ein function after substitution. *)
structure genHelper = struct
    structure E = Ein
   (* structure genKrn=genKrn*)


structure DstIL = LowIL
structure DstTy = LowILTypes
   structure DstOp = LowOps
  structure Var = LowIL.Var

structure SrcIL = MidIL
structure SrcOp = MidOps
structure SrcSV = SrcIL.StateVar
structure SrcTy = MidILTypes
structure VTbl = SrcIL.Var.Tbl


(* findDup (list1, list2): returns SOME v for the first element v of list1
 * (scanning list1 from the front) that also occurs in list2, or NONE when the
 * two lists share no element.  Elements must be of an equality type.
 * (The original version was missing the `end`s of both of its `let`s.) *)
fun findDup (list1, list2) =
    List.find (fn v => List.exists (fn x => x = v) list2) list1

(* Global mutable int counter, initialised to 0.
 * NOTE(review): it is not read or written anywhere in the visible part of this
 * file -- confirm whether other modules still use genHelper.bV before removing. *)
val bV= ref 0

(* printgetRHS x: renders the binding of the LowIL variable x as the string
 * "\n Found <binding>\n".  Despite its name it does not print; it only builds
 * the string (printX prints it). *)
fun printgetRHS x = let
    val binding = DstIL.Var.binding x
    in
        String.concat ["\n Found ", DstIL.vbToString binding, "\n"]
    end

(* getKernel x: returns the kernel h when the MidIL variable x is bound to an
 * application of the SrcOp.Kernel operator; otherwise raises Fail with a
 * message that includes x and its actual binding. *)
fun getKernel x = let
    fun notKernel vb = raise Fail (String.concat [
            "\n -- Not a kernel, ", SrcIL.Var.toString x, " found ", SrcIL.vbToString vb, "\n"])
    in
        case SrcIL.Var.binding x
         of SrcIL.VB_RHS (SrcIL.OP (SrcOp.Kernel (h, _), _)) => h
          | other => notKernel other
        (* end case *)
    end

(* printX stmt: debug-prints one LowIL assignment to stdout.
 *   x = op(args)    prints "\n<x>==<op> : <arg1>,<arg2>,..."
 *   x = literal     prints "\n :<x>==...Lit"
 *   x = cons(vars)  prints "\n<x>==<v1>,<v2>,..."
 *   any other RHS   prints "\n<x>==CONS" followed by printgetRHS x
 * Only ASSGN statements are matched; any other statement raises Match.
 * NOTE(review): the last case prints the label "CONS" although CONS right-hand
 * sides are handled by the case before it -- the label looks stale. *)
fun printX stmt = (case stmt
    of DstIL.ASSGN (x, DstIL.OP (rator, operands)) => (
        print (String.concat ["\n", Var.toString x, "==", DstOp.toString rator, " : "]);
        print (String.concatWith "," (List.map Var.toString operands)))
     | DstIL.ASSGN (x, DstIL.LIT _) =>
        print (String.concat ["\n :", Var.toString x, "==...Lit"])
     | DstIL.ASSGN (x, DstIL.CONS (_, vars)) =>
        print (String.concat ["\n", Var.toString x, "==", String.concatWith "," (List.map Var.toString vars)])
     | DstIL.ASSGN (x, _) =>
        print (String.concat ["\n", Var.toString x, "==", "CONS", printgetRHS x])
    (* end case *))

(* printTy ty: renders a LowIL tensor type as "Argument:[d1,d2,...]", listing
 * its dimensions.  Only the DstTy.TensorTy form is matched; any other form of
 * the type (if DstTy has one) raises Match. *)
fun printTy ty = (case ty
    of DstTy.TensorTy dims => let
        val dimStrs = List.map Int.toString dims
        in
            "Argument:[" ^ String.concatWith "," dimStrs ^ "]"
        end
    (* end case *))

(* aaV (opss, args, pre, ty): creates a fresh LowIL variable (name prefix pre,
 * type ty) and the single assignment statement  var = opss(args).
 * Returns (var, [stmt]); the one-element statement list lets callers append it
 * to the code they are accumulating (they do `code @ A`).  Also prints debug
 * output: the variable's type (printTy) and the statement (printX).
 * (The original version was missing its `in ... end`.) *)
fun aaV (opss, args, pre, ty) = let
    (* original author's note: "problem here forces variable binding" *)
    val a = DstIL.Var.new (pre, ty)
    val () = print (String.concat ["\n -----------------------  \n Created Var", printTy ty])
    val code = DstIL.ASSGN (a, DstIL.OP (opss, args))
    val () = printX code
    in
        (a, [code])
    end

(* mkMultiple (list1, rator, ty): combines the variables in list1 left to right
 * with the binary operator rator, i.e. ((v1 rator v2) rator v3) ..., creating
 * one intermediate variable of type ty (prefix "PV") per step via aaV.
 * Returns (result variable, generated statements).  A one-element list returns
 * that element and no code; an empty list raises Fail.
 * The single-element case now returns the accumulated code rather than [], so
 * the general clause alone handles the final pair (the old dedicated
 * two-element clause was redundant).  The missing `end`s are restored. *)
fun mkMultiple (list1, rator, ty) = let
    fun add ([], _) = raise Fail "no element in addM"
      | add ([e1], code) = (e1, code)
      | add (e1::e2::es, code) = let
            val (vA, A) = aaV (rator, [e1, e2], "PV", ty)
            in
                add (vA::es, code @ A)
            end
    in
        add (list1, [])
    end

(* mapIndex (e1, mapp): resolves an Ein index to an integer.  A variable index
 * E.V e is looked up as entry e (0-based) of mapp, raising Subscript if mapp is
 * too short; a constant index E.C c is returned unchanged. *)
fun mapIndex (E.V e, mapp) = List.nth (mapp, e)
  | mapIndex (E.C c, _) = c

(* getShape (params, id): returns the LowIL type of the id-th (0-based) Ein
 * parameter, which must be a tensor parameter E.TEN.  A parameter whose first
 * component is 3 is mapped to DstTy.iVecTy 2 regardless of its shape (marked
 * "FIX HERE" by the original author); any other tensor parameter maps to
 * DstTy.TensorTy shape.  Non-tensor parameters raise Fail. *)
fun getShape (params, id) = let
    val param = List.nth (params, id)
    in
        case param
         of E.TEN (3, _) => DstTy.iVecTy 2 (* FIX HERE *)
          | E.TEN (_, shape) => DstTy.TensorTy shape
          | _ => raise Fail "NONE Tensor Param"
        (* end case *)
    end

(* mkSca (mapp, (id, ix1, (args, params))): builds a DstOp.S operation on the
 * id-th argument variable.  The Ein indices ix1 are resolved against mapp
 * (mapIndex) into a DstTy.indexTy, and the argument's type comes from
 * getShape.  The new variable (prefix "S<id>") has type DstTy.TensorTy [].
 * Returns (variable, code).  (The missing `end` is restored.) *)
fun mkSca (mapp, (id, ix1, (args, params))) = let
    val ix1' = List.map (fn e1 => mapIndex (e1, mapp)) ix1
    val nU = List.nth (args, id)
    val i = DstTy.indexTy ix1'
    val a = getShape (params, id)
    in
        aaV (DstOp.S (id, i, a), [nU], "S" ^ Int.toString id, DstTy.TensorTy [])
    end

(* mkVec (mapp, (id, ix1, last, (args, params))): builds a DstOp.V operation of
 * length `last` on the id-th argument variable.  The Ein indices ix1 are
 * resolved against mapp (mapIndex) into a DstTy.indexTy, and the argument's
 * type comes from getShape.  The new variable (prefix "V<id>") has type
 * DstTy.TensorTy [last].  Returns (variable, code). *)
fun mkVec (mapp, (id, ix1, last, (args, params))) = let
    val resolved = List.map (fn ix => mapIndex (ix, mapp)) ix1
    val argVar = List.nth (args, id)
    val idxTy = DstTy.indexTy resolved
    val argTy = getShape (params, id)
    val rator = DstOp.V (id, last, idxTy, argTy)
    in
        aaV (rator, [argVar], "V" ^ Int.toString id, DstTy.TensorTy [last])
    end

(*Helper functions for addition *)
(* handleAddVec (mapp, (es, index, last, args)): generates code for a sum of
 * vector terms.  Each (id, ix) in es becomes a length-`last` vector via mkVec,
 * and the resulting variables are combined left to right with
 * DstOp.addVec last (mkMultiple), which raises Fail when es is empty.
 * Returns (result variable, code).  `index` is only referenced by the
 * commented-out clause below.
 * Fixes: the commented-out clause was never terminated (its `(*` swallowed the
 * rest of the file), and the `end`s of both `let`s were missing. *)
fun handleAddVec (mapp, (es, index, last, args)) = let
    val () = print "made it to handleAdd vec"
    fun add ([], rest, code) = (rest, code)
      (* | add((id1,[])::es,rest,code)=let
            val (vA,A)= mkVec(mapp,(id1,index,args))
            in add(es,rest@[vA],code@A) end *)
      | add ((id1, ix1)::es, rest, code) = let
            val (vA, A) = mkVec (mapp, (id1, ix1, last, args))
            in
                add (es, rest @ [vA], code @ A)
            end
    val (rest, code) = add (es, [], [])
    val (vA, A) = mkMultiple (rest, DstOp.addVec last, DstTy.TensorTy [last])
    in
        (vA, code @ A)
    end

(* Subtract scalars *)
(* mksubSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)): scalar subtraction
 * t1 - t2 of two indexed scalar terms built with mkSca, using DstOp.subSca.
 * Only a two-term list with an empty second component matches; other shapes
 * raise Match.  Returns (result variable, code). *)
fun mksubSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)) = let
    val (lhs, lhsCode) = mkSca (mapp, (id1, ix1, args))
    val (rhs, rhsCode) = mkSca (mapp, (id2, ix2, args))
    val (diff, diffCode) = aaV (DstOp.subSca, [lhs, rhs], "SubSca", DstTy.TensorTy [])
    in
        (diff, List.concat [lhsCode, rhsCode, diffCode])
    end

(*subtract Vectors*)
(* mksubVec (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)): vector
 * subtraction of two length-`last` vector terms built with mkVec, using
 * DstOp.subVec last.  Only a two-term list with an empty second component
 * matches; other shapes raise Match.  Returns (result variable, code). *)
fun mksubVec (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)) = let
    val (lhs, lhsCode) = mkVec (mapp, (id1, ix1, last, args))
    val (rhs, rhsCode) = mkVec (mapp, (id2, ix2, last, args))
    val (diff, diffCode) =
        aaV (DstOp.subVec last, [lhs, rhs], "subVec", DstTy.TensorTy [last])
    in
        (diff, List.concat [lhsCode, rhsCode, diffCode])
    end

(*Product functions*)
(*product of 2 scalars*)
(* mkprodSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)): product of two
 * indexed scalar terms built with mkSca, using DstOp.prodSca.  Only a two-term
 * list with an empty second component matches; other shapes raise Match.
 * Returns (result variable, code). *)
fun mkprodSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)) = let
    val (lhs, lhsCode) = mkSca (mapp, (id1, ix1, args))
    val (rhs, rhsCode) = mkSca (mapp, (id2, ix2, args))
    val (prod, prodCode) = aaV (DstOp.prodSca, [lhs, rhs], "prodSca", DstTy.TensorTy [])
    in
        (prod, List.concat [lhsCode, rhsCode, prodCode])
    end

(* Product of one scalar and one projection *)
(* mkprodScaV (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)): scales the
 * length-`last` vector term (id2, ix2) by the scalar term (id1, ix1), using
 * DstOp.prodScaV last.  Prints a debug line before emitting the product.
 * Returns (result variable, code). *)
fun mkprodScaV (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)) = let
    val (scaVar, scaCode) = mkSca (mapp, (id1, ix1, args))
    val (vecVar, vecCode) = mkVec (mapp, (id2, ix2, last, args))
    val () = print (String.concat ["Puppy-In prodScaV", Int.toString last])
    val (prodVar, prodCode) =
        aaV (DstOp.prodScaV last, [scaVar, vecVar], "prodScaV", DstTy.TensorTy [last])
    in
        (prodVar, List.concat [scaCode, vecCode, prodCode])
    end

(*product of 2 projections*)
(* mkprodVec (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)): product of two
 * length-`last` vector terms built with mkVec, using DstOp.prodVec last.
 * Only a two-term list with an empty second component matches; other shapes
 * raise Match.  Returns (result variable, code).
 * (The missing `end` is restored.) *)
fun mkprodVec (mapp, ([(id1, ix1), (id2, ix2)], [], last, args)) = let
    val (vA, A) = mkVec (mapp, (id1, ix1, last, args))
    val (vB, B) = mkVec (mapp, (id2, ix2, last, args))
    val (vD, D) = aaV (DstOp.prodVec last, [vA, vB], "prodV", DstTy.TensorTy [last])
    in
        (vD, A @ B @ D)
    end
(* Summation over the product of two projections.
 * mkprodSumVec (mapp, (m, [], i, args)): builds the product of the two vector
 * terms in m (mkprodVec) and sums its components with DstOp.sumVec i.
 * Returns (result variable, code).
 * Fix: the call previously passed ([], m, ...) to mkprodVec, which can never
 * match its ([t1, t2], [], ...) pattern; the term list now goes first.
 * NOTE(review): the sum is typed DstTy.TensorTy [i]; if sumVec yields a scalar
 * this may need to be DstTy.TensorTy [] -- confirm. *)
fun mkprodSumVec (mapp, (m, [], i, args)) = let
    val (vD, prodCode) = mkprodVec (mapp, (m, [], i, args))
    val (vE, sumCode) = aaV (DstOp.sumVec i, [vD], "sumVec", DstTy.TensorTy [i])
    in
        (vE, prodCode @ sumCode)
    end

(*product of -1 and 1 projection*)
(* mkNegV (mapp, ((vA, id, ix), [], last, args)): multiplies the existing
 * variable vA by the length-`last` vector term (id, ix) using
 * DstOp.prodScaV last.  Per the original comment vA presumably holds -1
 * (negation) -- not enforced here.  Prints debug markers around mkVec.
 * Returns (result variable, code); vA's own code is not included. *)
fun mkNegV (mapp, ((vA, id, ix), [], last, args)) = let
    val () = print "\n pre mkVec"
    val (vecVar, vecCode) = mkVec (mapp, (id, ix, last, args))
    val () = print "\n post mkVec"
    val (prodVar, prodCode) =
        aaV (DstOp.prodScaV last, [vA, vecVar], "prodScaV", DstTy.TensorTy [last])
    in
        (prodVar, vecCode @ prodCode)
    end

(*Dot Product like summation
Does Vec x Vec *)
(* sumDot (a, (m, sx, last, args)): dot-product-like summation.  sx must hold
 * exactly one summation range (_, lb, ub), otherwise Bind is raised.  For every
 * index value v in lb..ub it builds the product of the two vector terms in m
 * (mkprodVec, with mapp = a @ [v]) and applies DstOp.sumVec last to it; the
 * per-index results are then added with DstOp.addSca (mkMultiple).
 * Returns (result variable, code).
 * Fixes: mkprodVec was called with ([], m, ...), which can never match its
 * ([t1, t2], [], ...) pattern; an empty range (ub < lb) previously recursed
 * forever and now raises Fail. *)
fun sumDot (a, (m, sx, last, args)) = let
    val [(_, lb, ub)] = sx
    (* code for one index value v: sumVec (prodVec of the two terms in m) *)
    fun term v = let
            val (vD, pre) = mkprodVec (a @ [v], (m, [], last, args))
            val (vE, sumCode) = aaV (DstOp.sumVec last, [vD], "SumVec", DstTy.TensorTy [last])
            in
                (vE, pre @ sumCode)
            end
    (* counts s down from ub - lb to 0, i.e. index values ub, ..., lb *)
    fun sumI (0, rest, code) = let
            val (vE, termCode) = term lb
            val (vF, F) = mkMultiple (vE :: rest, DstOp.addSca, DstTy.TensorTy [])
            in
                (vF, termCode @ code @ F)
            end
      | sumI (s, rest, code) = let
            val (vE, termCode) = term (s + lb)
            in
                sumI (s - 1, vE :: rest, termCode @ code)
            end
    in
        if ub < lb
            then raise Fail "sumDot: empty summation range"
            else sumI (ub - lb, [], [])
    end

(*Can do multiple summations *)
(* sum (a, (m, sx, args)): multiple nested summations over the product of the
 * two scalar terms in m.  sx lists the summation ranges (_, lb, ub), outermost
 * first (hd sx raises Empty if sx is empty).  For every combination of index
 * values it builds mkprodSca with mapp = a @ [v1, v2, ...], then adds all
 * products with DstOp.addSca (mkMultiple).  Returns (result variable, code).
 * Fixes: restored the missing `end`s and the body of the (0, _)-with-inner-ranges
 * clause (it recurses into the inner range with the outer index fixed at lb1);
 * mkprodSca was called with ([], m, ...), which can never match its
 * ([t1, t2], [], ...) pattern; empty ranges (ub < lb) previously recursed
 * forever and now raise Fail. *)
fun sum (a, (m, sx, args)) = let
    (* steps from the top of a range down to its lower bound *)
    fun span (lb, ub) =
        if ub < lb then raise Fail "sum: empty summation range" else ub - lb
    (* sumI1 (left, (s, lb1), inner, rest, code): left holds the index values
     * already fixed by enclosing ranges; s counts the current range down to 0
     * (index value s + lb1); inner are the ranges nested inside it. *)
    fun sumI1 (left, (0, lb1), [], rest, code) = let
            val (vD, pre) = mkprodSca (a @ left @ [lb1], (m, [], args))
            in
                (vD :: rest, pre @ code)
            end
      | sumI1 (left, (s, lb1), [], rest, code) = let
            val (vD, pre) = mkprodSca (a @ left @ [s + lb1], (m, [], args))
            in
                sumI1 (left, (s - 1, lb1), [], vD :: rest, pre @ code)
            end
      | sumI1 (left, (0, lb1), (_, lb2, ub) :: sx', rest, code) =
            sumI1 (left @ [lb1], (span (lb2, ub), lb2), sx', rest, code)
      | sumI1 (left, (s, lb1), inner as ((_, lb2, ub) :: sx'), rest, code) = let
            val (rest', code') =
                sumI1 (left @ [s + lb1], (span (lb2, ub), lb2), sx', rest, code)
            in
                sumI1 (left, (s - 1, lb1), inner, rest', code')
            end
    val (_, lb, ub) = hd sx
    val (li, code) = sumI1 ([], (span (lb, ub), lb), tl sx, [], [])
    val (vF, F) = mkMultiple (li, DstOp.addSca, DstTy.TensorTy [])
    in
        (vF, code @ F)
    end

(* mkC n: emits the constant operator DstOp.C n (no arguments) into a fresh
 * variable of type DstTy.TensorTy [] (prefix "Const"), then prints a debug
 * marker.  Returns (variable, code). *)
fun mkC n = let
    val result = aaV (DstOp.C n, [], "Const", DstTy.TensorTy [])
    val () = print "postmkC"
    in
        result
    end

(* evalDelta2 (a, b, mapp): Kronecker delta of two Ein indices.  Both indices
 * are resolved against mapp (mapIndex); the result is the code for the
 * constant 1 when they are equal and 0 otherwise (via mkC).
 * Returns (variable, code).  (The missing `end` is restored.) *)
fun evalDelta2 (a, b, mapp) = let
    val i = mapIndex (a, mapp)
    val j = mapIndex (b, mapp)
    in
        mkC (if i = j then 1 else 0)
    end


(* evalDelta (dels, mapp): evaluates each index pair (i, j) in dels to 1 when
 * the indices resolve (mapIndex) to the same value and 0 otherwise, and
 * returns the SUM of those values as an int.
 * NOTE(review): a product of Kronecker deltas would normally multiply these
 * values rather than add them -- confirm that summing is intended.
 * Fix: the original was missing `in`/`end`; the per-constructor case analysis
 * is replaced by mapIndex, which resolves E.V / E.C identically. *)
fun evalDelta (dels, mapp) = let
    fun delta (i, j) = if mapIndex (i, mapp) = mapIndex (j, mapp) then 1 else 0
    in
        List.foldl (op +) 0 (List.map delta dels)
    end

(* evalEps (a, b, c, mapp): Levi-Civita symbol for the variable indices a, b, c
 * (each looked up in mapp as E.V).  Returns 0 when any two resolved values
 * coincide, 1 when (i, j, k) is an even permutation of their ascending order,
 * and ~1 when it is an odd permutation.
 * Fix: the original was missing `in`, the `else` after `then 0`, and `end`. *)
fun evalEps (a, b, c, mapp) = let
    val i = mapIndex (E.V a, mapp)
    val j = mapIndex (E.V b, mapp)
    val k = mapIndex (E.V c, mapp)
    in
        if i = j orelse j = k orelse i = k then 0
        else if j > i then
            (if j > k andalso k > i then ~1 else 1)
        else if i > k andalso k > j then 1
        else ~1
    end

(* skeleton code: when code is exactly one assignment of the constant operator
 * DstOp.C 0, C 1 or C ~1 (the shape mkC produces), returns that constant; for
 * any other statement list returns the sentinel 9. *)
fun skeleton [DstIL.ASSGN (_, DstIL.OP (DstOp.C c, _))] =
        if c = 0 orelse c = 1 orelse c = ~1 then c else 9
  | skeleton _ = 9


ViewVC Help
Powered by ViewVC 1.0.0