Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/genH.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/genH.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2584 - (download) (annotate)
Tue Apr 15 03:22:58 2014 UTC (5 years, 5 months ago) by cchiw
File size: 10610 byte(s)
Multiply Fields
(* Hashes Ein functions after substitution.
 *
 * gHelper: helper functions used when lowering from MidIL (SrcIL) to
 * LowIL (DstIL).  Visible contents: a small function-based index
 * environment (insert/lookup/empty), lookups of kernel/image bindings,
 * builders that emit LowIL assignments for tensor accesses, products,
 * differences and summations, and evaluators for delta / epsilon terms.
 * NOTE(review): the original header line ("hashs Ein Function after
 * substitution") does not obviously describe these contents -- confirm.
 *)
structure gHelper = struct
    local
    (* Module aliases; being inside local ... in, they are visible only in this structure. *)
    structure E = Ein
    
   (* structure genKrn=genKrn*)

    

(* Target (LowIL) side *)
structure DstIL = LowIL
structure DstTy = LowILTypes
   structure DstOp = LowOps
  structure Var = LowIL.Var

(* Source (MidIL) side *)
structure SrcIL = MidIL
structure SrcOp = MidOps
structure SrcSV = SrcIL.StateVar
structure SrcTy = MidILTypes
structure VTbl = SrcIL.Var.Tbl

    in



(* Index environments are plain functions from a key to an optional value.
 *   insert (key, value) prior - an environment that answers SOME value for key
 *                               and defers to prior for every other key
 *   lookup key env            - queries env
 *   empty                     - the environment with no bindings
 *)
fun insert (key, value) prior =
    fn probe => (case probe = key
        of true => SOME value
         | false => prior probe
        (* end case *))

fun lookup key env = env key

val empty = fn _ => NONE
(* Returns SOME of the first element of list1 (scanning list1 in order)
 * that also occurs in list2, or NONE when the lists share no element. *)
fun findDup (list1, list2) =
    List.find (fn candidate => List.exists (fn other => other = candidate) list2) list1



(* Mutable integer counter, initially 0.
 * NOTE(review): not referenced elsewhere in this file; it may be used by
 * other modules through gHelper.bV -- confirm before removing. *)
val bV : int ref = ref (0)

(* Debug helper: despite the name, this prints nothing.  It returns the
 * string "\n Found <binding>\n" describing the LowIL binding of x. *)
fun printgetRHS x = let
    val binding = DstIL.Var.binding x
    in
      String.concat ["\n Found ", DstIL.vbToString binding, "\n"]
    end




(* Returns the kernel h when x is bound (in MidIL) to SrcOp.Kernel (h, _);
 * raises Fail describing x and its actual binding otherwise. *)
fun getKernel x = let
    val binding = SrcIL.Var.binding x
    fun notKernel () = raise Fail (String.concat [
            "\n -- Not a kernel, ", SrcIL.Var.toString x,
            " found ", SrcIL.vbToString binding, "\n"])
    in
      case binding
       of SrcIL.VB_RHS (SrcIL.OP (SrcOp.Kernel (h, _), _)) => h
        | _ => notKernel ()
    end


(* Returns the image img when x is bound (in MidIL) to SrcOp.LoadImage img;
 * raises Fail describing x and its actual binding otherwise. *)
fun getImage x = let
    val binding = SrcIL.Var.binding x
    fun notImage () = raise Fail (String.concat [
            "\n -- Not an image, ", SrcIL.Var.toString x,
            " found ", SrcIL.vbToString binding, "\n"])
    in
      case binding
       of SrcIL.VB_RHS (SrcIL.OP (SrcOp.LoadImage img, _)) => img
        | _ => notImage ()
    end





(* Debug: prints a LowIL assignment "x==..." with no trailing newline.
 *   OP   -> "x==<op> : a1,a2,..."
 *   LIT  -> "x==...Lit"
 *   CONS -> "x==v1,v2,..."
 *   any other rhs -> "x==CONS" followed by printgetRHS x
 *   NOTE(review): the last case is labelled "CONS" even though CONS is
 *   handled above -- probably a stale label; kept as-is.
 *)
fun printX (DstIL.ASSGN (x, rhs)) = let
    val lhs = Var.toString x
    in
      case rhs
       of DstIL.OP (opss, args) => (
            print (String.concat [lhs, "==", DstOp.toString opss, " : "]);
            print (String.concatWith "," (List.map Var.toString args)))
        | DstIL.LIT _ => print (String.concat [lhs, "==...Lit"])
        | DstIL.CONS (_, varl) =>
            print (String.concat [lhs, "==", String.concatWith "," (List.map Var.toString varl)])
        | _ => print (String.concat [lhs, "==", "CONS", printgetRHS x])
    end



(* Debug: renders a LowIL type as a short string: "int " for IntTy, "Real "
 * for the rank-0 tensor, "TensorTy[d1,...,dn] " for other tensors.  Other
 * DstTy constructors (if any) are not matched. *)
fun printTy ty = (case ty
    of DstTy.IntTy => "int "
     | DstTy.TensorTy [] => "Real "
     | DstTy.TensorTy dims =>
         "TensorTy[" ^ String.concatWith "," (List.map Int.toString dims) ^ "] "
    (* end case *))


(* Creates a fresh LowIL variable (name prefix pre, type ty) and the single
 * assignment "fresh = opss(args)".  Returns (fresh, [assignment]).
 * NOTE(review): an earlier comment here read "problem here forces variable
 * binding"; its meaning is not clear from this file. *)
fun aaV (opss, args, pre, ty) = let
    val lhs = DstIL.Var.new (pre, ty)
    val stmt = DstIL.ASSGN (lhs, DstIL.OP (opss, args))
    in
      (lhs, [stmt])
    end




(* Combines a non-empty list of variables left-to-right with the binary
 * operator rator: ((e1 rator e2) rator e3) ...; each intermediate result is
 * a fresh "MO" variable of type ty.  Returns (final variable, statements in
 * emission order).  A singleton returns its element with no statements.
 * Raises Fail "no element in addM" on an empty list. *)
fun mkMultiple (list1, rator, ty) = (case list1
    of [] => raise Fail "no element in addM"
     | first :: others => let
         fun combine (e, (acc, revCode)) = let
             val (v, stmts) = aaV (rator, [acc, e], "MO", ty)
             in (v, List.revAppend (stmts, revCode)) end
         val (result, revCode) = List.foldl combine (first, []) others
         in
           (result, List.rev revCode)
         end
    (* end case *))



(* Resolves an Ein index to a concrete integer.  A constant index E.C c
 * yields c.  An index variable E.V e is looked up in mapp; an unbound
 * variable raises Fail "Outside Bound". *)
fun mapIndex (e1, mapp) = (case e1
    of E.C c => c
     | E.V e => (case lookup e mapp
          of SOME s => s
           | NONE => raise Fail "Outside Bound"
          (* end case *))
    (* end case *))

(* Debug: prints "k==>v" for keys n, n+1, ... while they are bound in mapp
 * (entries are not separated), then prints "-\n" at the first unbound key. *)
fun printIndexXX (n, mapp) = (case lookup n mapp
    of SOME s => (
        print (String.concat [Int.toString n, "==>", Int.toString s]);
        printIndexXX (n + 1, mapp))
     | NONE => print "-\n"
    (* end case *))


(* Returns the LowIL type of parameter number id (0-based; List.nth raises
 * Subscript when out of range).  A rank-1 E.TEN whose tag is 3 becomes
 * DstTy.iVecTy; any other E.TEN becomes DstTy.TensorTy of its shape.
 * Non-tensor params raise Fail "NONE Tensor Param".
 * NOTE(review): the original marks the tag-3 case "FIX HERE"; what tag 3
 * denotes is not visible in this file. *)
fun getShape (params, id) = let
    val param = List.nth (params, id)
    in
      case param
       of E.TEN (3, [shape]) => DstTy.iVecTy shape
        | E.TEN (_, shape) => DstTy.TensorTy shape
        | _ => raise Fail "NONE Tensor Param"
    end



(* Emits a scalar access S(id, index, shape) of argument number id, with
 * the index expressions ix1 resolved through mapp.  The result is a fresh
 * variable prefixed "S<id>" of type TensorTy [].  Returns (var, statements). *)
fun mkSca (mapp, (id, ix1, (args, params))) = let
    val concrete = List.map (fn e => mapIndex (e, mapp)) ix1
    val src = List.nth (args, id)
    val rator = DstOp.S (id, DstTy.indexTy concrete, getShape (params, id))
    in
      aaV (rator, [src], "S" ^ Int.toString id, DstTy.TensorTy [])
    end


(* Emits a vector projection V(id, last, index, shape) of argument number
 * id (length last), with the index expressions ix1 resolved through mapp.
 * The result is a fresh variable prefixed "V<id>" of type TensorTy [last].
 * Prints the current index environment first (debug output).
 * Returns (var, statements). *)
fun mkVec (mapp, (id, ix1, last, (args, params))) = let
    val () = printIndexXX (0, mapp)
    val concrete = List.map (fn e => mapIndex (e, mapp)) ix1
    val src = List.nth (args, id)
    val rator = DstOp.V (id, last, DstTy.indexTy concrete, getShape (params, id))
    in
      aaV (rator, [src], "V" ^ Int.toString id, DstTy.TensorTy [last])
    end

(* Helper functions for addition.
 *
 * handleAddVec: emits one vector projection (mkVec) per (id, ix) operand in
 * es, in order, then sums them with addVec(last) via mkMultiple.  The index
 * component is not used.  Prints a debug message first.  An empty es makes
 * mkMultiple raise Fail.  Returns (var, statements).
 *)
fun handleAddVec (mapp, (es, index, last, args)) = let
    val _ = print "made it to handleAdd vec"
    fun loadOne ((id1, ix1), (vars, stmts)) = let
        val (vA, A) = mkVec (mapp, (id1, ix1, last, args))
        in (vars @ [vA], stmts @ A) end
    val (rest, code) = List.foldl loadOne ([], []) es
    val (vSum, sumCode) = mkMultiple (rest, DstOp.addVec last, DstTy.TensorTy [last])
    in
      (vSum, code @ sumCode)
    end



(* Subtract scalars: emits scalar accesses (mkSca) for the two operands
 * (id1, ix1) and (id2, ix2), then a subSca of the two results.
 * Returns (result var, statements).
 * Any other argument shape (not exactly two operands, or a non-empty
 * second component) raises Fail with a descriptive message instead of a
 * bare Match.  This matches mkprodSca's explicit fallback.
 *)
fun mksubSca(mapp,([(id1,ix1),(id2,ix2)],[],args))= let
    val (vA,A)=mkSca(mapp,(id1,ix1,args))
    val (vB, B)=mkSca(mapp,(id2, ix2,args))
    val (vD, D)=aaV(DstOp.subSca,[vA, vB],"SubSca",DstTy.TensorTy([]))
    in (vD, A@B@D) end
  | mksubSca _ = raise Fail "mksubSca: unsupported arguments (expected two operands and an empty list)"


(* Subtract vectors: emits vector projections (mkVec, length last) for the
 * two operands, then a subVec(last) of the results.  Prints the index
 * environment first (debug output).  Returns (result var, statements).
 * Any other argument shape raises Fail with a descriptive message instead
 * of a bare Match.
 *)
fun mksubVec(mapp,([(id1,ix1),(id2,ix2)],[],last,args))= let
    val () = printIndexXX(0,mapp)
    val (vA,A)= mkVec(mapp,(id1,ix1,last,args))
    val (vB, B)= mkVec(mapp,(id2,ix2,last,args))
    val (vD, D)=aaV(DstOp.subVec(last),[vA, vB],"subVec",DstTy.TensorTy([last]))
    in (vD, A@B@D) end
  | mksubVec _ = raise Fail "mksubVec: unsupported arguments (expected two operands and an empty list)"



(* Product functions *)

(* Product of two scalars: emits scalar accesses (mkSca) for both operands
 * and a prodSca of the results.  Returns (result var, statements).
 * Any other argument shape raises Fail "Prod----d---". *)
fun mkprodSca (mapp, ([(id1, ix1), (id2, ix2)], [], args)) = let
    val (lhsVar, lhsCode) = mkSca (mapp, (id1, ix1, args))
    val (rhsVar, rhsCode) = mkSca (mapp, (id2, ix2, args))
    val (prodVar, prodCode) =
          aaV (DstOp.prodSca, [lhsVar, rhsVar], "prodSca", DstTy.TensorTy [])
    in
      (prodVar, List.concat [lhsCode, rhsCode, prodCode])
    end
  | mkprodSca _ = raise Fail "Prod----d---"

(*
(*product of 2 scalars*)
fun mkprodScaR(_,([(id1,ix1),(id2,ix2)],[],args))= let
val (vA,A)=mkSca(mapp,(id1,ix1,args))
val (vB, B)=mkSca(mapp,(id2, ix2,args))
        aaV(DstOp.S(id, i,a),[nU],"S"^Int.toString(id),DstTy.TensorTy([]))

val (vD, D)=aaV(DstOp.prodSca,[vA, vB],"prodSca",DstTy.TensorTy([]))
in (vD, A@B@D)end
| mkprodSca _= raise Fail "Prod----d---"
*)

(* Product of one scalar and one vector projection: emits a scalar access
 * (mkSca) for the first operand and a vector projection (mkVec, length
 * last) for the second, then prodScaV(last).  Prints debug output (the
 * index environment and a "Puppy-In prodScaV" marker).
 * Returns (result var, statements).
 * Any other argument shape raises Fail with a descriptive message instead
 * of a bare Match.
 *)
fun mkprodScaV(mapp,([(id1,ix1),(id2,ix2)],[],last,args))=let
    val () = printIndexXX(0,mapp)
    val (vA,A)=mkSca(mapp,(id1,ix1,args))
    val (vB, B)= mkVec(mapp,(id2,ix2,last,args))
    val () = print(String.concat["Puppy-In prodScaV",Int.toString(last)])
    val (vD, D)=aaV(DstOp.prodScaV(last),[vA, vB],"prodScaV",DstTy.TensorTy([last]))
    in (vD,A@B@D) end
  | mkprodScaV _ = raise Fail "mkprodScaV: unsupported arguments (expected two operands and an empty list)"

(* Product of two vector projections: emits mkVec (length last) for both
 * operands, then prodVec(last).  Prints a debug marker first.
 * Returns (result var, statements).
 * Any other argument shape raises Fail with a descriptive message instead
 * of a bare Match.
 *)
fun mkprodVec(mapp,([(id1,ix1),(id2,ix2)],[],last,args))= let
    val () = print "\n mkprodVec"
    val (vA,A)= mkVec(mapp,(id1,ix1,last,args))
    val (vB, B)= mkVec(mapp,(id2,ix2,last,args))
    val (vD, D)=aaV(DstOp.prodVec(last),[vA, vB],"prodV",DstTy.TensorTy([last]))
    in (vD, A@B@D)
    end
  | mkprodVec _ = raise Fail "mkprodVec: unsupported arguments (expected two operands and an empty list)"
(* Summation over a product of two projections: builds mkprodVec with
 * vector length i+1, then reduces it with sumVec(i+1) into a fresh
 * variable of type DstTy.realTy.  Prints a debug marker first.
 * Returns (result var, statements).
 * NOTE(review): an earlier comment here read "error here" (it is unclear
 * whether it referred to this function or mkprodVec).  Also, sumDot types
 * its sumVec result as TensorTy [] while this uses DstTy.realTy -- confirm
 * these agree.
 *)
fun mkprodSumVec (mapp, (m, [], i, args)) = let
    val _ = print "\n In prod sum vec"
    val len = i + 1
    val (prodVar, prodCode) = mkprodVec (mapp, (m, [], len, args))
    val (sumVar, sumCode) = aaV (DstOp.sumVec len, [prodVar], "sumVec", DstTy.realTy)
    in
      (sumVar, prodCode @ sumCode)
    end

(* Product of an existing scalar variable vA (presumably the constant -1,
 * per the original comment -- not checked here) and a vector projection
 * (mkVec of (id, ix), length last), via prodScaV(last).  Prints debug
 * markers around the projection.  Only the projection and product
 * statements are returned; vA's own code is the caller's.
 *)
fun mkNegV (mapp, ((vA, id, ix), [], last, args)) = let
    val _ = print "\n pre mkVec"
    val (vecVar, vecCode) = mkVec (mapp, (id, ix, last, args))
    val _ = print "\n post mkVec"
    val (resVar, resCode) =
          aaV (DstOp.prodScaV last, [vA, vecVar], "prodScaV", DstTy.TensorTy [last])
    in
      (resVar, vecCode @ resCode)
    end



(* Dot-product-like summation: lowers a single summation over the product of
 * two vector projections (Vec x Vec).
 *
 *   a    - index environment (insert/lookup style), extended with the summation index
 *   m    - operand list, passed through to mkprodVec
 *   sx   - must be a one-element list [(E.V v, lb, ub)]; other shapes raise Bind
 *   last - vector length passed to mkprodVec and DstOp.sumVec
 *   args - passed through to mkprodVec
 *
 * For every v in lb..ub this emits mkprodVec followed by a sumVec
 * reduction, then adds all partial results with addSca (via mkMultiple).
 * Returns (result variable, statements).
 * Raises Fail when ub < lb (previously the recursion never terminated).
 *)
fun sumDot(a, (m, sx, last, args)) = let
    val [(E.V v, lb, ub)] = sx
    (* Emit the product and its sumVec reduction with v bound to idx. *)
    fun term idx = let
        val mapp = insert (v, idx) a
        val (vD, pre) = mkprodVec (mapp, (m, [], last, args))
        val (vE, sumStmts) = aaV (DstOp.sumVec last, [vD], "SumVec", DstTy.TensorTy [])
        in (vE, pre @ sumStmts) end
    (* s counts down from ub-lb to 0; iteration s binds v to lb+s. *)
    fun sumI (0, rest, code) = let
          (* Final iteration binds v to lb.  The previous version bound it
           * to 0, which is wrong whenever lb <> 0. *)
          val (vE, stmts) = term lb
          val (vF, addStmts) = mkMultiple (vE :: rest, DstOp.addSca, DstTy.TensorTy [])
          in (vF, stmts @ code @ addStmts) end
      | sumI (s, rest, code) = let
          val (vE, stmts) = term (s + lb)
          in sumI (s - 1, vE :: rest, stmts @ code) end
    in
      if ub < lb
        then raise Fail "sumDot: empty summation range"
        else sumI (ub - lb, [], [])
    end


(*Can do multiple summations *)
(* sum: lowers a (possibly nested) summation of a scalar product.
 * For every combination of summation-index values, mkprodSca emits the
 * product with those indices bound; all partial products are then added
 * with addSca via mkMultiple.  Prints "\n IN SUM" first (debug output).
 *
 *   sx - list of (E.V v, lb, ub) bindings; hd sx is the outermost loop.
 *        An empty sx raises Empty (hd); a non-E.V head raises Bind.
 *   Returns (result variable, statements).
 *
 * NOTE(review): parameter a is never used; the recursion starts from the
 * empty environment, so indices bound outside the summation are not
 * visible to mkprodSca (sumDot, by contrast, extends its argument a).  The
 * commented-out "a@left@[lb1]" suggests a was once included -- confirm.
 *)
fun sum(a, ( m,sx,args))=let
    val mss=print "\n IN SUM"
    (* sumI1 (left, (v, s, lb1), inner, rest, code):
     *   left  - environment holding the enclosing indices
     *   v     - current index, bound to lb1+s, with s counting down to 0
     *   inner - remaining (inner) summation bindings
     *   rest/code - accumulated partial-product variables / statements
     *)
    (* innermost loop, last value: v = lb1 *)
    fun sumI1(left,(v,0,lb1),[],rest,code)=let
            (*val mapp=a@left@[lb1]*)
            val mapp =insert(v, lb1) left
            val (vD,pre)=mkprodSca(mapp,(m,[],args))
            in ([vD]@rest,pre@code)
            end
        (* innermost loop, v = i+lb1, then continue with i-1 *)
        |  sumI1(left,(v,i,lb1),[],rest,code)=let
            val mapp =insert(v, i+lb1) left
            val (vD,pre)=mkprodSca(mapp,(m,[],args))
            in sumI1(left,(v,i-1,lb1),[],[vD]@rest,pre@code)
            end
        (* outer loop, last value: bind v = lb1 and run the next inner loop *)
        | sumI1(left,(v,0,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
            val mapp =insert(v, lb1) left
            in sumI1(mapp,(a,ub-lb2,lb2),sx,rest,code) end 
        (* outer loop: bind v = s+lb1, run the inner loop, then continue with s-1 *)
        | sumI1(left,(v,s,lb1),(E.V v',lb2,ub)::sx,rest,code)=let
                val mapp =insert(v, s+lb1) left
                val (rest',code')=sumI1(mapp,(v',ub-lb2,lb2),sx,rest,code)
            in sumI1(left,(v,s-1,lb1),(E.V v',lb2,ub)::sx,rest',code') end

    val (E.V v,lb,ub)=hd(sx)
    val(li, code)=sumI1(empty,(v,ub-lb,lb),tl(sx),[],[])
    (* add all partial products together *)
    val (vF, F)=mkMultiple(li,DstOp.addSca,DstTy.TensorTy([]))
    in (vF,code@F) end


(* Emits the constant C(n) into a fresh "Const" variable of type
 * TensorTy [], then prints "postmkC" (debug output).
 * Returns (var, statements). *)
fun mkC n = let
    val result = aaV (DstOp.C n, [], "Const", DstTy.TensorTy [])
    in
      print "postmkC";
      result
    end


(* Kronecker delta of two Ein indices under mapp: emits the constant 1 when
 * both indices resolve to the same integer and 0 otherwise (via mkC).
 * Returns (var, statements). *)
fun evalDelta2 (a, b, mapp) = let
    val i = mapIndex (a, mapp)
    val j = mapIndex (b, mapp)
    in
      mkC (if i = j then 1 else 0)
    end

(* Field/Kern *)

(* Evaluates a list of delta index pairs under mapp.  Each pair contributes
 * 1 when both sides resolve to the same integer (E.C is taken literally;
 * E.V is looked up via mapIndex) and 0 otherwise.  Returns the SUM of the
 * contributions.
 * NOTE(review): a product of Kronecker deltas would normally multiply
 * rather than add -- confirm that callers expect a count here. *)
fun evalDelta (dels, mapp) = let
    fun delta (i, j) = let
        val (x, y) = (case (i, j)
            of (E.V _, E.V _) => (mapIndex (i, mapp), mapIndex (j, mapp))
             | (E.C a, E.V _) => (a, mapIndex (j, mapp))
             | (E.V _, E.C b) => (mapIndex (i, mapp), b)
             | (E.C a, E.C b) => (a, b)
            (* end case *))
        in if x = y then 1 else 0 end
    in
      List.foldl (fn (pair, total) => total + delta pair) 0 dels
    end




(* Levi-Civita symbol of three index variables a, b, c resolved through
 * mapp.  The result is 0 when any two indices are equal.  Otherwise it is
 * 1 when their relative order is an even permutation and ~1 when it is odd.
 * Parity is counted as the number of inverted pairs. *)
fun evalEps (a, b, c, mapp) = let
    val i = mapIndex (E.V a, mapp)
    val j = mapIndex (E.V b, mapp)
    val k = mapIndex (E.V c, mapp)
    fun inv (x, y) = if x > y then 1 else 0
    in
      if i = j orelse j = k orelse i = k then 0
      else if (inv (i, j) + inv (i, k) + inv (j, k)) mod 2 = 0 then 1
      else ~1
    end



(* Classifies a statement list.  A single assignment of the constant C 0,
 * C 1 or C ~1 yields 0, 1 or ~1 respectively; anything else yields 9. *)
fun skeleton A = (case A
    of [DstIL.ASSGN (_, DstIL.OP (DstOp.C c, _))] =>
        if c = 0 then 0
        else if c = 1 then 1
        else if c = ~1 then ~1
        else 9
     | _ => 9
    (* end case *))
    end




end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0