Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/sca-to-low.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/sca-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2867 - (download) (annotate)
Tue Feb 10 06:52:58 2015 UTC (4 years, 5 months ago) by cchiw
File size: 7151 byte(s)
moved split around, added norm to typechecker, added sqrt to ein
(* Convert EIN expressions whose tensor operands are indexed down to scalars
 * into LowIL code, using scalar operators such as addSca and subSca.
 *)
structure ScaToLow = struct
    local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure E = Ein
    structure EtoFld = FieldToLow
    structure H = Helper

    in

    (* Re-exported helpers: each forwards its argument unchanged to the
     * like-named function in FieldToLow (evalField) or Helper (all others).
     *)
    fun evalField e = EtoFld.evalField e
    fun insert e = H.insert e
    fun mapIndex e = H.mapIndex e
    fun mkReal n = H.mkReal n
    fun indexTensor e = H.indexTensor e
    fun mkSubSca e = H.mkSubSca e
    fun mkProdSca e = H.mkProdSca e
    fun mkDivSca e = H.mkDivSca e
    fun mkMultiple e = H.mkMultiple e
    fun evalDelta e = H.evalDelta e
    fun evalEps2 e = H.evalEps2 e
    fun evalEps3 e = H.evalEps3 e
    fun mkSqrt e = H.mkSqrt e

    (* errField: called when a field-level EIN term (Field, Partial, Apply,
     * Probe, Conv, Krn, Img, Lift) reaches this scalar translator, which
     * does not handle them.
     *)
    fun errField () = raise Fail("Invalid Field Here")

    (* realTy: LowIL type given to the scalar results built here
     * (a tensor type with an empty shape).
     *)
    val realTy = DstTy.TensorTy([])

    (* skeleton : (list of LowIL assignments) -> int
     * Recognize code that is exactly one assignment of the integer literal
     * 0, 1 or ~1 and return that integer.  Any other code yields 9, which
     * serves as a "not a recognized constant" sentinel.
     *)
    fun skeleton A = (case A
        of [DstIL.ASSGN(_, DstIL.LIT(Literal.Int 0))] => 0
         | [DstIL.ASSGN(_, DstIL.LIT(Literal.Int 1))] => 1
         | [DstIL.ASSGN(_, DstIL.LIT(Literal.Int ~1))] => ~1
         | _ => 9
        (* end case *))

    (* negCheck0 : string * (var * assignments) -> var * assignments
     * Build -A.  When A is the literal 0, 1 or ~1, the negated literal is
     * emitted directly and A's code is dropped; otherwise A is multiplied by
     * a fresh ~1 literal.
     *)
    fun negCheck0 (lhs, (vA, A)) = (case skeleton A
        of 0 => mkReal 0
         | ~1 => mkReal 1
         | 1 => mkReal ~1
         | _ => let
             val (vB, B) = mkReal ~1
             val (vD, D) = mkProdSca (lhs, [vB, vA])
             in (vD, A@B@D) end
        (* end case *))

    (* subCheck0 : string * (var * assignments) * (var * assignments)
     *               -> var * assignments
     * Build A - B with the rewrites 0-0 => 0, 0-B => ~1*B and A-0 => A;
     * otherwise emit a scalar subtraction.
     *)
    fun subCheck0 (lhs, (vA, A), (vB, B)) = (case (skeleton A, skeleton B)
        of (0, 0) => mkReal 0
         | (0, _) => let
             val (vC, C) = mkReal ~1
             val (vD, D) = mkProdSca (lhs, [vC, vB])
             in (vD, B@C@D) end
         | (_, 0) => (vA, A)
         | _ => let
             val (vD, D) = mkSubSca (lhs, [vA, vB])
             in (vD, A@B@D) end
        (* end case *))

    (* divCheck0 : string * (var * assignments) * (var * assignments)
     *               -> var * assignments
     * Build A / B.  When A is the literal 0 the result is the literal 0 and
     * no division (and none of B's code) is emitted; otherwise emit a scalar
     * division.
     *)
    fun divCheck0 (lhs, (vA, A), (vB, B)) = (case skeleton A
        of 0 => mkReal 0
         | _ => let
             val (vC, C) = mkDivSca (lhs, [vA, vB])
             in (vC, A@B@C) end
        (* end case *))

    (* generalfn : dict * (string * E.ein * var list) -> var * assignments
     * Translate the body of an EIN expression into LowIL scalar code.
     * dict maps index variables to the integer values they currently take
     * (built with insert, read with mapIndex); it is kept in a ref so that
     * summations can rebind their own indices while translating their body.
     * Simplifications performed:
     *   - literal-zero summands are dropped from Add and Sum,
     *   - literal-one factors are dropped from Prod,
     *   - a literal-zero factor collapses the whole Prod to that zero.
     *)
    fun generalfn (dict, info) = let
        val mapp = ref dict
        val (lhs, e, args) = info
        val params = E.params e
        fun gen body = let
            (* addCheck0: translate each summand, dropping literal zeros.
             * Yields the ids and concatenated code of the surviving summands,
             * or a single literal 0 when none survive.
             *)
            fun addCheck0 ([], [], []) = let val (vA, A) = mkReal 0 in ([vA], A) end
              | addCheck0 ([], ids, code) = (ids, code)
              | addCheck0 (e1::es, ids, code) = let
                  val (a, b) = gen e1
                  in (case skeleton b
                      of 0 => addCheck0 (es, ids, code)
                       | _ => addCheck0 (es, ids@[a], code@b)
                      (* end case *))
                  end

            (* prodCheck0: translate each factor, dropping literal ones.
             * Stops at the first literal zero and returns only that zero;
             * yields a single literal 1 when no factor survives.
             *)
            fun prodCheck0 ([], [], []) = let val (vA, A) = mkReal 1 in ([vA], A) end
              | prodCheck0 ([], ids, code) = (ids, code)
              | prodCheck0 (e1::es, ids, code) = let
                  val (a, b) = gen e1
                  in (case skeleton b
                      of 0 => ([a], b)
                       | 1 => prodCheck0 (es, ids, code)
                       | _ => prodCheck0 (es, ids@[a], code@b)
                      (* end case *))
                  end

            (* sumCheck: expand the summation of sumBody over the index list
             * sumx by translating sumBody once per index binding.  Bindings
             * whose result is the literal 0 contribute their code but no id.
             * The index map is restored before returning, so summation
             * indices never leak into sibling expressions.
             *)
            fun sumCheck (sumx, sumBody) = let
                val saved = !mapp
                (* translate sumBody under the index map mapsum *)
                fun sumloop mapsum = let
                    val _ = mapp := mapsum
                    val (vA, A) = gen sumBody
                    in (case skeleton A
                        of 0 => ([], A)
                         | _ => ([vA], A)
                        (* end case *))
                    end
                (* sumI1 (env, (v, i, lb), inner, ids, code): bind v to
                 * lb+i, lb+i-1, ..., lb and, for each value, either translate
                 * the body (no inner indices left) or recurse on the inner
                 * indices.  Results are prepended, so they come out in
                 * ascending index order.
                 *)
                fun sumI1 (left, (v, 0, lb1), [], rest, code) = let
                      val env = insert (v, lb1) left
                      val (vD, pre) = sumloop env
                      in (vD@rest, pre@code) end
                  | sumI1 (left, (v, i, lb1), [], rest, code) = let
                      val env = insert (v, i+lb1) left
                      val (vD, pre) = sumloop env
                      in sumI1 (env, (v, i-1, lb1), [], vD@rest, pre@code) end
                  | sumI1 (left, (v, 0, lb1), (E.V a, lb2, ub)::sx, rest, code) = let
                      val env = insert (v, lb1) left
                      in sumI1 (env, (a, ub-lb2, lb2), sx, rest, code) end
                  | sumI1 (left, (v, s, lb1), (E.V a, lb2, ub)::sx, rest, code) = let
                      val env = insert (v, s+lb1) left
                      val (rest', code') = sumI1 (env, (a, ub-lb2, lb2), sx, rest, code)
                      in sumI1 (env, (v, s-1, lb1), (E.V a, lb2, ub)::sx, rest', code') end
                  | sumI1 _ = raise Fail "None Variable-index in summation"
                (* a range with ub < lb would make sumI1 count down forever *)
                fun checkRange (_, lb, ub) =
                      if ub < lb then raise Fail "Invalid summation range" else ()
                val _ = List.app checkRange sumx
                val result = (case sumx
                    of [] => sumloop saved
                     | (E.V v, lb, ub)::sx => sumI1 (saved, (v, ub-lb, lb), sx, [], [])
                     | _ => raise Fail "None Variable-index in summation"
                    (* end case *))
                in
                    mapp := saved;
                    result
                end

            (* iterList: combine the surviving operand ids with rator.
             * An empty addition is the literal 0, a single operand is used
             * as-is, and anything longer becomes one mkMultiple node.
             *)
            fun iterList (([], _), DstOp.addSca) = mkReal 0
              | iterList (([id1], code), _) = (id1, code)
              | iterList ((ids, code), rator) = let
                  val (vB, B) = mkMultiple (lhs, ids, rator, realTy)
                  in (vB, code@B) end

            (* Clause order matters: the Img*Krn summation and the
             * vector/scalar division must be tried before the generic
             * Sum and Div clauses.
             *)
            in (case body
                of E.Field _ => errField()
                 | E.Partial _ => errField()
                 | E.Apply _ => errField()
                 | E.Probe _ => errField()
                 | E.Conv _ => errField()
                 | E.Krn _ => errField()
                 | E.Img _ => errField()
                 | E.Lift _ => errField()
                 | E.Sqrt e1 => mkSqrt (gen e1)
                 | E.Value v => mkReal (mapIndex (E.V v, !mapp))
                 | E.Const c => mkReal c
                 | E.Epsilon(i, j, k) => evalEps3 (!mapp, i, j, k)
                 | E.Eps2(i, j) => evalEps2 (!mapp, i, j)
                 | E.Delta(i, j) => evalDelta (!mapp, i, j)
                 | E.Tensor(id, ix) => indexTensor (!mapp, (lhs, params, args, id, ix, realTy))
                 | E.Neg e1 => negCheck0 (lhs, gen e1)
                 | E.Sub(e1, e2) => subCheck0 (lhs, gen e1, gen e2)
                 | E.Add es => iterList (addCheck0 (es, [], []), DstOp.addSca)
                 | E.Prod es => iterList (prodCheck0 (es, [], []), DstOp.prodSca)
                 | E.Sum(_, E.Prod(E.Img(_, _, _)::E.Krn(_, _, _)::_)) =>
                     (* image-kernel summations are delegated to FieldToLow *)
                     evalField (!mapp, (body, info))
                 | E.Div(e1 as E.Tensor(_, [_]), e2 as E.Tensor(_, [])) =>
                     (* a one-index tensor over a zero-index tensor is
                      * rewritten as the product (1/e2) * e1 *)
                     gen (E.Prod[E.Div(E.Const 1, e2), e1])
                 | E.Div(e1, e2) => divCheck0 (lhs, gen e1, gen e2)
                 | E.Sum(sumx, e1) => iterList (sumCheck (sumx, e1), DstOp.addSca)
                (* end case *))
            end
        in
            gen (E.body e)
        end

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0