Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/mid-to-low/sca-to-low-set.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/mid-to-low/sca-to-low-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3395 - (download) (annotate)
Tue Nov 10 18:23:07 2015 UTC (4 years, 10 months ago) by cchiw
File size: 9395 byte(s)
val-num in evalkrn
 (* ScaToLowSet
  * Convert an EIN expression whose tensors are indexed down to scalars into
  * LowIL code built from scalar operators (addSca, subSca, prodSca, divSca, ...).
  * The generators below thread an opaque state value (named set/setA/...)
  * through every HelperSet call and return a triple (set, var, code), where
  * var names the result and code is the list of LowIL assignments computing it.
  *)
structure ScaToLowSet = struct
    local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure LowToS= LowToString
    structure Var = LowIL.Var
    structure E = Ein
    structure EtoFld= FieldToLowSet
    structure H=HelperSet

    in

    (* Thin wrappers re-exporting the field lowering and the HelperSet generators. *)
    fun evalField e= EtoFld.evalField e
    fun insert e=H.insert e
    fun mapIndex e=H.mapIndex e
    fun  intToReal e =H.intToReal e
    fun indexTensor e=H.indexTensor e
    fun mkSubSca e= H.mkSubSca e
    fun mkProdSca e=H.mkProdSca e
    fun mkDivSca e= H.mkDivSca e
    fun mkMultiple e=H.mkMultiple e
    fun evalDelta e=H.evalDelta e
    fun evalEps2 e=H.evalEps2 e
    fun evalEps3 e=H.evalEps3 e
    fun mkSqrt e =H.mkSqrt e
    fun mkCosine e =H.mkCosine e
    fun mkArcCosine e =H.mkArcCosine e
    fun mkSine e =H.mkSine e
    fun mkArcSine e =H.mkArcSine e
    fun mkPowInt e =H.mkPowInt e
    fun mkPowRat e =H.mkPowRat e

    (* errField: raised for field-level EIN terms, which gen below refuses to lower *)
    fun errField ()=raise Fail("Invalid Field Here")
    (* type of every scalar result generated here: a rank-0 tensor *)
    val realTy=DstTy.TensorTy([])

    (* skeleton : DstIL.assign list -> int
     * Recognize code that is exactly one assignment of the integer literal
     * 0, 1 or ~1 and return that constant.  Any other code yields the
     * sentinel 9, meaning "not a known constant".
     *)
    fun skeleton A=(case A
        of [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 0))]=>0
        | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 1))]=> 1
        | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int ~1))]=> ~1
        | _ => 9
        (*end case*))

    (* negCheck0 : set * var * code -> set * var * code
     * Generate code for -A.  When A is the literal 0, 1 or ~1 the negated
     * constant is emitted directly and A's code is dropped; otherwise the
     * result is (~1) * A.
     *)
    fun negCheck0(setA,vA,A)=
        (case skeleton A
            of 0 =>  intToReal(setA, 0)
            | ~1 =>  intToReal(setA, 1)
            | 1  =>  intToReal(setA, ~1)
            |  _=> let
                val (setB,vB,B)= intToReal(setA, ~1)
                val (setD,vD,D)=mkProdSca (setB,[vB,vA])
                in (setD,vD,A@B@D) end
        (*end case*))

    (* subCheck0 : (set * var * code) * (set * var * code) -> set * var * code
     * Generate code for A - B, simplifying literal-0 operands:
     *   0 - 0 => 0,   0 - B => (~1) * B,   A - 0 => A.
     * Only the first triple's set is threaded forward, so callers pass the
     * most recent set in both positions.
     *)
    fun subCheck0((setA,vA,A),(_,vB,B))=
        (case((skeleton A),(skeleton B))
            of (0,0)=>  intToReal(setA, 0)
            |(0,_)=> let
                val (setC,vC,C)=  intToReal(setA, ~1)
                val (setD,vD,D)= mkProdSca(setC, [vC,vB])
                in (setD,vD,B@C@D) end
            | (_,0)=> (setA,vA,A)
            | _ => let
                val (setD,vD,D)= mkSubSca(setA,[vA,vB])
                in (setD,vD, A@B@D) end
        (*end case*))

    (* divCheck0 : (set * var * code) * (set * var * code) -> set * var * code
     * Generate code for A / B.  A literal-0 numerator yields 0 without
     * emitting a division.  As in subCheck0, the first triple's set is the
     * one threaded forward.
     *)
    fun divCheck0((setA,vA,A),(_,vB,B))=(case (skeleton A)
        of 0=>  intToReal(setA, 0)
        | _ => let
            val (setC,vC,C)= mkDivSca(setA, [vA,vB])
            in (setC,vC, A@B@C) end
        (*end case*))

    (* generalfn : set * dict * (string * Ein.ein * LowIL.var list) -> set * var * code
     * Lower the body of the EIN operator e to scalar LowIL code.
     * dict is the starting index environment; it maps index variables to
     * integer values (see insert/mapIndex).  lhs, the parameters of e and
     * args are forwarded to indexTensor and evalField.  Sums and products
     * drop literal 0/1 operands where possible (AddcheckO/ProdcheckO).
     *)
    fun generalfn(setOrig,dict,(lhs:string,e:Ein.ein,args:LowIL.var list))=let
        (* Current index environment.  A summation rebinds its indices here
         * while its body is generated, then restores the enclosing value. *)
        val mapp=ref dict
        val info=(lhs,e,args)
        val params=Ein.params e

        (* gen : set * Ein expression -> set * var * code
         * Recursively generate scalar code for one EIN expression under the
         * index environment currently held in mapp.
         *)
        fun gen(setG,body)= let

            (* AddcheckO: generate each summand, dropping those that are the
             * literal 0.  If nothing remains, the sum is the literal 0. *)
            fun AddcheckO (opset,[],[],[])=let
                    val (opset',vA,A)= intToReal(opset, 0)
                    in (opset',[vA],A) end
              | AddcheckO(opset,[],ids,code)=(opset,ids,code)
              | AddcheckO(opset,e1::es,ids,code)=let
                    val (setA,a,b)=gen(opset,e1)
                    in (case (skeleton b)
                        of 0 => AddcheckO(setA,es,ids,code)
                        |  _ => AddcheckO(setA,es,ids@[a],code@b)
                        (*end case*))
                    end

            (* ProdcheckO: generate each factor, dropping literal 1s and
             * short-circuiting to the factor's code on a literal 0.  If no
             * factor remains, the product is the literal 1. *)
            fun ProdcheckO (opset,[],[],[])=let
                    val (setA,vA,A)= intToReal(opset, 1)
                    in (setA,[vA],A) end
              | ProdcheckO(opset,[],ids,code)=(opset,ids,code)
              | ProdcheckO(opset,e1::es,ids,code)=let
                    val (setA,a,b)=gen(opset,e1)
                    in (case (skeleton b)
                        of 0 => (setA,[a],b)
                        | 1 => ProdcheckO(setA,es,ids,code)
                        | _ => ProdcheckO(setA,es,ids@[a],code@b)
                        (*end case*))
                    end

            (*********sum expression ********)
            (* tb n = [0, 1, ..., n-1] *)
            fun tb n= List.tabulate(n,fn i=>i)

            (* Sumcheck: unroll the summation over every index in sumx
             * (outermost first), generating the body once per index
             * assignment.  It returns the non-zero terms' vars and the
             * concatenated code, for iterList to combine with addSca.
             *)
            fun Sumcheck(setSx,sumx,e)=let
                (* generate one term of the sum under index environment mapsum *)
                fun sumloop(opsetLoop,mapsum)=let
                    val _ = mapp:=mapsum
                    val(setA,vA,A)=gen(opsetLoop, e)
                    (* a literal-0 term contributes no var to the sum *)
                    in (case (skeleton A)
                        of 0 => (setA,[],A)
                        |  _ => (setA,[vA],A)
                        (*end case*))
                    end

                (* sumI1(set, env, (v, offsets, lb), remaining indices, vars, code):
                 * bind v to lb+offset for each offset, recursing into the
                 * remaining (inner) summation indices before emitting a term. *)
                fun sumI1(setS,left,(v,[i],lb1),[],rest,code)=let
                        val dict=insert(v, lb1+i) left
                        val (setD,vD,pre)= sumloop(setS, dict)
                        in (setD,rest@vD,code@pre) end
                  | sumI1(setS,left,(v,i::es,lb1),[],rest,code)=let
                        val dict=insert(v, (i+lb1)) left
                        val (setD,vD,pre)=sumloop (setS,dict)
                        in sumI1(setD,dict,(v,es,lb1),[],rest@vD,code@pre) end
                  | sumI1(setS,left,(v,[i],lb1),(E.V a,lb2,ub2)::sx,rest,code)=let
                        val dict=insert(v, lb1+i) left
                        val xx=tb(ub2-lb2+1)
                        in sumI1(setS,dict,(a,xx,lb2),sx,rest,code) end
                  | sumI1(setS,left,(v,s::es,lb1),(E.V a,lb2,ub2)::sx,rest,code)=let
                        val dict=insert(v, (s+lb1)) left
                        val xx=tb(ub2-lb2+1)
                        val (setT,rest',code')=sumI1(setS,dict,(a,xx,lb2),sx,rest,code)
                        in sumI1(setT,dict,(v,es,lb1),(E.V a,lb2,ub2)::sx,rest',code') end
                  | sumI1 _ =raise Fail"None Variable-index in summation"

                (* enclosing environment; restored once the sum has been generated
                 * so the summation's index bindings do not leak into siblings *)
                val outer = !mapp
                val result = (case sumx
                    of (E.V v,lb,ub)::inner => sumI1(setSx,outer,(v,tb(ub-lb+1),lb),inner,[],[])
                    | [] => raise Fail"empty index list in summation"
                    | _ => raise Fail"None Variable-index in summation"
                    (*end case*))
                in
                    mapp := outer;
                    result
                end
            (*********sum expression ********)

            (* iterList: fold the vars with rator.  An empty addSca list is
             * the literal 0; a single var is returned as-is. *)
            fun iterList ((opset,[],_),DstOp.addSca) = let val (setA,vA,A)= intToReal(opset, 0) in (setA,vA,A) end
              | iterList ((opset,[id1],code),_) = (opset,id1,code)
              | iterList ((opset,ids,code),rator)   = let
                    val (setB,vB,B)= mkMultiple(opset,ids,rator,realTy)
                    in (setB,vB,code@B) end

            in (case body
                of  E.Field _           => errField()
                | E.Partial _           => errField()
                | E.Apply _             => errField()
                | E.Probe _             => errField()
                | E.Conv _              => errField()
                | E.Krn _               => errField()
                | E.Img _               => errField()
                | E.Lift _              => errField()
                | E.Value v             =>  intToReal(setG,(mapIndex(E.V v,!mapp)))
                | E.Const c             =>  intToReal(setG, c)
                (*| E.ConstR rat        => (R.toReal rat)*)
                | E.Epsilon(i,j,k)      => evalEps3(setG,!mapp,i,j,k)
                | E.Eps2(i,j)           => evalEps2(setG,!mapp,i,j)
                | E.Delta(i,j)          => evalDelta(setG,!mapp,i,j)
                | E.Tensor(id,ix)       => indexTensor(setG,!mapp,(lhs,params,args,id,ix,realTy))
                | E.Sqrt e1             => mkSqrt(gen(setG, e1))
                | E.Cosine e1           => mkCosine(gen(setG, e1))
                | E.ArcCosine e1        => mkArcCosine(gen(setG, e1))
                | E.Sine e1             => mkSine(gen(setG, e1))
                | E.ArcSine e1          => mkArcSine(gen(setG, e1))
                | E.Neg e1              => negCheck0(gen(setG, e1))
                | E.PowInt(e1,n)        => mkPowInt(gen(setG,e1),n)
                | E.PowReal(e1,n)       => mkPowRat(gen(setG,e1), n)
                | E.Sub (e1,e2)         => let
                    val (setA,vA,A)=gen(setG,e1)
                    val (setB,vB,B)=gen(setA,e2)
                    in subCheck0((setB,vA,A),(setB,vB,B)) end
                | E.Add es              => iterList(AddcheckO(setG,es,[],[]),DstOp.addSca)
                | E.Prod es             => iterList(ProdcheckO(setG,es,[],[]),DstOp.prodSca)
                (* vector / scalar is rewritten as (1/scalar) * vector *)
                | E.Div(e1 as E.Tensor(_,[_]),e2 as E.Tensor(_,[]))=>
                        gen (setG,E.Prod[E.Div(E.Const 1, e2),e1])
                | E.Div(e1,e2)          => let
                    val (setA,vA,A)=gen(setG,e1)
                    val (setB,vB,B)=gen(setA,e2)
                    in divCheck0((setB,vA,A),(setB,vB,B)) end
                (* a sum over image * kernel products is handed to the field lowering *)
                | E.Sum(_,E.Prod(E.Img _::E.Krn _::_))
                                        => evalField(setG,!mapp,(body,info))
                | E.Sum(sumx, e1)       => iterList(Sumcheck(setG,sumx,e1),DstOp.addSca)
                | _                     => raise Fail"unsupported ein-exp "
                (*end case*))
            end
        in
            gen(setOrig,E.body e)
        end

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0