Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/step2.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/step2.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2827 - (download) (annotate)
Tue Nov 11 00:18:38 2014 UTC (4 years, 9 months ago) by cchiw
File size: 8384 byte(s)
changed types for imgAddr and imgLd
(*general function for scalars*)
structure step2 = struct
    local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure E = Ein
    structure S3 = step3
    structure genKrn = genKrn
    structure tS = toStringEin

    in

    (* Index environments are functions from a key to an optional value.
     * [insert (key, value) d] returns an environment that maps key to value
     * and defers every other key to d.
     *)
    fun insert (key, value) d = fn s =>
        if s = key then SOME value else d s

    (* [lookup k d] is the binding of k in environment d, if any. *)
    fun lookup k d = d k

    (* The environment with no bindings. *)
    val empty = fn key => NONE

    (* Used for EIN constructs this scalar pass does not handle. *)
    fun err _ = raise Fail ("Invalid Field Here")

    (* Raise Fail carrying the given message. *)
    fun errS str = raise Fail (str)

    (*Helpers for scalars*)

    (* [mkCons (shape, rest)] allocates a fresh variable of type
     * TensorTy shape and a single CONS assignment that builds it from the
     * variables in rest.  Returns (variable, [assignment]).
     *)
    fun mkCons (shape, rest) = let
        val ty = DstTy.TensorTy shape
        val a = DstIL.Var.new ("Cons", ty)
        val code = DstIL.ASSGN (a, DstIL.CONS (ty, rest))
        in
            (a, [code])
        end

    (* The scalar type: a tensor with empty shape. *)
    val Sca = DstTy.TensorTy ([])
    val addR = DstOp.addSca

    (* Scalar binary operations, delegated to step3.aaV.  The string
     * lhs ^ "<op>" is passed through to aaV (presumably as a name hint for
     * the result variable -- see step3).
     *)
    fun mkProdSca (lhs, rest) = S3.aaV (DstOp.prodSca, rest, lhs ^ "prodSca", Sca)
    fun mkSubSca (lhs, rest) = S3.aaV (DstOp.subSca, rest, lhs ^ "subSca", Sca)
    fun mkDivSca (lhs, rest) = S3.aaV (DstOp.divSca, rest, lhs ^ "divSca", Sca)

    (* Combine all of ids with rator at scalar type (see step3.mkMultiple). *)
    fun mkMultipleSca (info, ids, rator) = S3.mkMultiple (info, ids, rator, Sca)

    (* Code for a literal constant (see step3.mkReal). *)
    fun mkReal n = S3.mkReal n

    (* [prodIter (origIndex, index, nextfn, args)] builds a tensor element by
     * element.  index holds the size of each index position; every index
     * tuple is enumerated, nextfn (mapp, args) generates (var, code) for one
     * element (mapp binds position 0, 1, ... to the current values), and the
     * elements are assembled with nested CONS operations.  origIndex is the
     * shape given to mkCons; its tail is used at each deeper nesting level.
     * Returns (var, code) for the whole tensor.
     * Raises Fail if a size is below 1 (the countdowns below would never
     * reach 0), or "index' is larger than origIndex" when origIndex runs
     * out before index does.
     *)
    fun prodIter (origIndex, index, nextfn, args) = let
        (* largest value taken by each index position *)
        val index' = List.map (fn e => e - 1) index
        val _ = if List.exists (fn e => e < 0) index'
                then errS "prodIter: tensor dimension must be at least 1"
                else ()

        (* generate one element with index position n bound to m *)
        fun get (n, m, mapp) = nextfn (insert (n, m) mapp, args)

        fun Iter (mapp, [], _, code, _, _) = let
              val (vF, code') = nextfn (mapp, args)
              in
                  (vF, code' @ code)
              end
          | Iter (mapp, [0], rest, code, shape, n) = let
              (* last element of the innermost position: emit the CONS *)
              val (vF, code') = get (n, 0, mapp)
              val (vE, consCode) = mkCons (shape, vF :: rest)
              in
                  (vE, code' @ code @ consCode)
              end
          | Iter (mapp, [c], rest, code, shape, n) = let
              (* innermost position: count down from c, prepending elements *)
              val (vE, elemCode) = get (n, c, mapp)
              in
                  Iter (mapp, [c - 1], vE :: rest, elemCode @ code, shape, n)
              end
          | Iter (mapp, b :: c, _, ccode, s :: shape, n) = let
              val n' = n + 1
              (* build each sub-tensor for position n = b, b-1, ..., 0 *)
              fun S (i, rest, code) = let
                    val (v', code') = Iter (insert (n, i) mapp, c, [], [], shape, n')
                    in
                        if i = 0
                        then let
                            val (vA, consCode) = mkCons (s :: shape, v' :: rest)
                            in
                                (vA, code' @ code @ consCode)
                            end
                        else S (i - 1, v' :: rest, code' @ code)
                    end
              val (vA, code') = S (b, [], [])
              in
                  (vA, code' @ ccode)
              end
          | Iter _ = raise Fail "index' is larger than origIndex"
        in
            Iter (empty, index', [], [], origIndex, 0)
        end

    (* [skeleton A] recognises a code list that is exactly one assignment of
     * the integer literal 0, 1 or ~1 and returns that value.  Any other code
     * list yields 9, meaning "not a known constant".
     *)
    fun skeleton A = (case A
        of [DstIL.ASSGN (_, DstIL.LIT (Literal.Int 0))] => 0
         | [DstIL.ASSGN (_, DstIL.LIT (Literal.Int 1))] => 1
         | [DstIL.ASSGN (_, DstIL.LIT (Literal.Int ~1))] => ~1
         | _ => 9
        (* end case *))

    (*Helper Functions for General functions*)

    (* Value bound to index v in mapp; raises Fail if v is unbound. *)
    fun findIX (v, mapp) = (case lookup v mapp
        of NONE => errS ("Outside Bound:" ^ Int.toString v)
         | SOME s => s
        (* end case *))

    (* Negate (vA, A), folding the literals 0, 1 and ~1; otherwise multiply
     * by a literal ~1.
     *)
    fun NegCheckO (lhs, (vA, A)) = (case skeleton A
        of 0 => mkReal 0
         | ~1 => mkReal 1
         | 1 => mkReal ~1
         | _ => let
              val (vB, B) = mkReal ~1
              val (vD, D) = mkProdSca (lhs, [vB, vA])
              in
                  (vD, A @ B @ D)
              end
        (* end case *))

    (* Subtract (vB, B) from (vA, A), simplifying when either operand is a
     * literal 0: 0-0 = 0, 0-b = ~1*b, a-0 = a.
     *)
    fun SubcheckO (lhs, (vA, A), (vB, B)) = (case (skeleton A, skeleton B)
        of (0, 0) => mkReal 0
         | (0, _) => let
              val (vD, D) = mkReal ~1
              val (vE, prodCode) = mkProdSca (lhs, [vD, vB])
              in
                  (vE, B @ D @ prodCode)
              end
         | (_, 0) => (vA, A)
         | _ => let
              val (vD, D) = mkSubSca (lhs, [vA, vB])
              in
                  (vD, A @ B @ D)
              end
        (* end case *))

    (* [generalfn (dict, (body, origargs, info))] generates LowIL code for
     * the scalar EIN expression body, taking index values from dict; lhs
     * (the first component of info) is forwarded to the operator helpers.
     * Returns (var, code).
     * Simplifications: literal-zero terms are dropped from additions and
     * summations, literal-one factors are dropped from products, a product
     * with a literal-zero factor becomes that zero, and 0/e becomes 0.
     * Field-level constructs (Field, Partial, Apply, Probe, Conv, Krn, Img,
     * Lift) raise Fail.
     *)
    fun generalfn (dict, (body, origargs, info as (lhs, _, _))) = let
        (* current index environment; temporarily extended while a
         * summation is expanded *)
        val mapp = ref dict

        fun gen body = let
            (* Generate the terms of an addition, dropping literal zeros.
             * Returns (ids, code); an all-zero addition yields one literal 0. *)
            fun AddcheckO ([], [], []) = let val (vA, A) = mkReal 0 in ([vA], A) end
              | AddcheckO ([], ids, code) = (ids, code)
              | AddcheckO (e1 :: es, ids, code) = let
                    val (a, b) = gen e1
                    in
                        (case skeleton b
                          of 0 => AddcheckO (es, ids, code)
                           | _ => AddcheckO (es, ids @ [a], code @ b)
                         (* end case *))
                    end

            (* Generate the factors of a product, dropping literal ones.  A
             * literal-zero factor short-circuits to that zero; an all-one
             * product yields one literal 1. *)
            fun ProdcheckO ([], [], []) = let val (vA, A) = mkReal 1 in ([vA], A) end
              | ProdcheckO ([], ids, code) = (ids, code)
              | ProdcheckO (e1 :: es, ids, code) = let
                    val (a, b) = gen e1
                    in
                        (case skeleton b
                          of 0 => ([a], b)
                           | 1 => ProdcheckO (es, ids, code)
                           | _ => ProdcheckO (es, ids @ [a], code @ b)
                         (* end case *))
                    end

            (* Expand the summation of e over the index ranges in sumx (each
             * an (E.V v, lb, ub) triple, inclusive bounds).  Returns the
             * (ids, code) of the non-zero terms.  The caller's environment
             * is restored afterwards so summation indices do not leak. *)
            fun Sumcheck (sumx, e) = let
                (* evaluate e under environment env; a literal-zero term
                 * contributes neither an operand nor code *)
                fun sumloop env = let
                    val _ = mapp := env
                    val (vA, A) = gen e
                    in
                        (case skeleton A
                          of 0 => ([], [])
                           | _ => ([vA], A)
                         (* end case *))
                    end

                (* (v, i, lb1): index v still to take the values
                 * lb1 + i, ..., lb1 (counting down); each value is bound on
                 * top of env so the environment chain does not grow with
                 * the number of iterations *)
                fun sumI1 (env, (v, 0, lb1), [], rest, code) = let
                      val (vD, pre) = sumloop (insert (v, lb1) env)
                      in
                          (vD @ rest, pre @ code)
                      end
                  | sumI1 (env, (v, i, lb1), [], rest, code) = let
                      val (vD, pre) = sumloop (insert (v, i + lb1) env)
                      in
                          sumI1 (env, (v, i - 1, lb1), [], vD @ rest, pre @ code)
                      end
                  | sumI1 (env, (v, 0, lb1), (E.V a, lb2, ub) :: sx, rest, code) =
                      sumI1 (insert (v, lb1) env, (a, ub - lb2, lb2), sx, rest, code)
                  | sumI1 (env, (v, s, lb1), (E.V a, lb2, ub) :: sx, rest, code) = let
                      val (rest', code') =
                            sumI1 (insert (v, s + lb1) env, (a, ub - lb2, lb2), sx, rest, code)
                      in
                          sumI1 (env, (v, s - 1, lb1), (E.V a, lb2, ub) :: sx, rest', code')
                      end
                  | sumI1 _ = raise Fail "None Variable-index in summation"

                (* an empty range would make the countdown above diverge *)
                val _ = List.app
                      (fn (_, lb, ub) => if ub < lb then errS "Empty summation range" else ())
                      sumx
                val saved = !mapp
                val result = (case sumx
                      of (E.V v, lb, ub) :: sx => sumI1 (saved, (v, ub - lb, lb), sx, [], [])
                       | [] => errS "Empty summation index list"
                       | _ => errS "None Variable-index in summation"
                      (* end case *))
                in
                    mapp := saved;
                    result
                end

            (* Fold the generated operands into a single variable: an empty
             * addition becomes literal 0, a single operand is used as is,
             * otherwise rator is applied across all operands. *)
            fun iterList (e, DstOp.addSca) = (case e
                  of ([], code) => let val (vA, A) = mkReal 0 in (vA, A) end
                   | ([id1], code) => (id1, code)
                   | (ids, code) => let
                        val (vB, B) = mkMultipleSca (info, ids, addR)
                        in
                            (vB, code @ B)
                        end
                  (* end case *))
              | iterList (e, rator) = (case e
                  of ([id1], code) => (id1, code)
                   | (ids, code) => let
                        val (vB, B) = mkMultipleSca (info, ids, rator)
                        in
                            (vB, code @ B)
                        end
                  (* end case *))

            in
                (case body
                  of E.Field _ => err 1
                   | E.Partial _ => err 1
                   | E.Apply _ => err 1
                   | E.Probe _ => err 1
                   | E.Conv _ => err 1
                   | E.Krn _ => err 1
                   | E.Img _ => err 1
                   | E.Lift _ => err 1
                   | E.Value v => mkReal (findIX (v, !mapp))
                   | E.Const c => mkReal c
                   | E.Epsilon (i, j, k) => S3.evalEps (!mapp, i, j, k)
                   | E.Delta (i, j) => S3.evalDelta2 (!mapp, i, j)
                   | E.Tensor (id, ix) => S3.mkSca (!mapp, (id, ix, info))
                   | E.Neg e => NegCheckO (lhs, gen e)
                   | E.Sub (e1, e2) => SubcheckO (lhs, gen e1, gen e2)
                   | E.Div (e1, e2) => let
                        val (vA, A) = gen e1
                        in
                            (* a literal-zero numerator makes the quotient 0;
                             * the denominator is then not generated *)
                            (case skeleton A
                              of 0 => mkReal 0
                               | _ => let
                                    val (vB, B) = gen e2
                                    val (vD, D) = mkDivSca (lhs, [vA, vB])
                                    in
                                        (vD, A @ B @ D)
                                    end
                             (* end case *))
                        end
                   | E.Add e => iterList (AddcheckO (e, [], []), addR)
                   | E.Prod e => iterList (ProdcheckO (e, [], []), DstOp.prodSca)
                  (* Disabled: probing an image/kernel product inside a sum.
                   | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let
                      val harg=List.nth(origargs,Hid)
                      val imgarg=List.nth(origargs,Vid)
                      val h=S3.getKernel(harg)
                      val v=S3.getImage(imgarg)
                      val imgargnew=List.nth(args,Vid)
                      val v=S3.getImage(imgarg,imgargnew)
                      in
                          genKrn.evalField(!mapp,(body,v,h,info))
                      end
                  *)
                   | E.Sum (sumx, e) => iterList (Sumcheck (sumx, e), addR)
                (* end case *))
            end

        in
            gen body
        end

    end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0