Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/lift-ein.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/lift-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3383 - (download) (annotate)
Mon Nov 9 02:39:26 2015 UTC (4 years, 11 months ago) by cchiw
File size: 12603 byte(s)
cut vector fields
(*created to lift out field terms only*)
structure LiftEin =struct

    local

structure E = Ein
structure DstIL = MidIL
structure DstTy = MidILTypes
structure DstV = DstIL.Var
structure P= Printer
structure cleanP=cleanParams
structure cleanI=cleanIndex
    in



    (* numFlag: when 1, lift consults LiftSet to reuse an already-lifted
     * subexpression (common-subexpression removal); any other value makes
     * lift always emit a fresh binding. *)
    val numFlag=1   (*remove common subexpression*)
    (* testing: read by testp; 0 makes testp return immediately. *)
    val testing=0

    (* Thin local aliases for helpers defined in other structures.
     * mkEin used to be defined twice (via E and via Ein, which are the same
     * structure); a single definition is kept. *)
    fun mkEin e = E.mkEin e
    fun mkEinApp (rator, args) = DstIL.EINAPP (rator, args)

    (* EIN application whose operator has no params, no indices and the
     * constant body 0, applied to no arguments. *)
    val einappzero = mkEinApp (mkEin ([], [], E.Const 0), [])

    (* pair y with the constant-zero EIN application *)
    fun setEinZero y = (y, einappzero)

    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e
    (* itos: render an int as a string (SML notation, so negatives use ~) *)
    val itos = Int.toString
    (* err: abort by raising Fail with the given message *)
    fun err msg = raise (Fail msg)
    (* cnt: running counter consumed by genName *)
    val cnt = ref 0

    (* incUse: add one to the use count stored in a MidIL variable *)
    fun incUse (DstIL.V{useCnt, ...}) = useCnt := 1 + !useCnt

    (* genName: build "<prefix>_<n>" from the current counter value n,
     * advancing the counter so the next call gets n+1 *)
    fun genName prefix = let
        val n = !cnt
        val () = cnt := n + 1
        in
            prefix ^ "_" ^ Int.toString n
        end
    (* testp: takes a list of string fragments and always returns 1.
     * When testing is non-zero the fragments are concatenated, but the
     * resulting string is discarded. *)
    fun testp n =
        if testing = 0
            then 1
            else (ignore (String.concat n); 1)


    (* cut: lift a probe whose convolution carries a single constant index.
     *
     * The caller (liftfields) passes a term of the form
     *   E.Probe(E.Conv(V,[E.C _],h,dx),pos)
     * cut builds a new probe in which that constant index is replaced by the
     * variable index E.V newvx, binds the result to a fresh variable, and
     * returns a tensor term that selects the original constant component.
     *
     * Arguments:
     *   name        prefix for the generated variable name
     *   e           the probe expression being cut
     *   params,index,sx,argsOrig
     *               params/index of the enclosing EIN, the summation indices
     *               in scope, and the enclosing argument list
     *   fieldset    LiftSet of already-lifted bindings
     *   cntinplace,cntlift
     *               running counters (new bindings / reused bindings)
     *   newvx       id of the variable index that replaces the constant
     *
     * Returns (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift'):
     *   Re       replacement term E.Tensor(id, c1::tshape), where id is the
     *            position of the new parameter and c1 the original index
     *   Rparams  params extended with the new tensor parameter
     *   Rargs    argsOrig extended with the variable holding the lifted probe
     *   newbies  bindings the caller must emit ([] when a binding was reused)
     *
     * The pattern bindings below are refutable: Bind is raised if the cleaned
     * body is not a Probe over a Conv with exactly one index.
     *)
    fun cut(name,e,params,index,sx,argsOrig,fieldset,cntinplace,cntlift,newvx) = let
        (* clean and rewrite current body *)
        val (tshape,sizes,body) = cleanIndex(e,index,sx)
        val id = length params
        val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
        val (_, DstIL.EINAPP(ein,args)) =
              cleanParams(M,body,params@[E.TEN(1,sizes)],sizes,argsOrig@[M])

        (* shift the probe's index from constant c1 to the variable newvx *)
        val E.Probe(E.Conv(V,[c1],h,dx),pos) = Ein.body ein
        (* NOTE(review): the extent 3 of the added index is hard-coded;
         * presumably the length of the vector field -- confirm. *)
        val index1 = Ein.index ein @ [3]
        val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)

        (* clean to get body indices in order *)
        val (_,_,body1) = cleanIndex(body1_unshifted,index1,[])
        val lhs1 = DstV.new ("L", DstTy.TensorTy index1)
        val code1 = (lhs1, mkEinApp(mkEin(Ein.params ein,index1,body1),args))

        (* Reuse an existing binding when LiftSet already has one, exactly as
         * lift does: on a hit the existing variable becomes the argument and
         * no new code is emitted.  Previously the lookup result was ignored,
         * so code1 was emitted even when cntlift was incremented.
         * NOTE(review): assumes LiftSet.rtnVarN has the same contract as
         * LiftSet.rtnVar (SOME v = variable of an equivalent, already emitted
         * binding) -- confirm. *)
        val (fieldset',lhs0,newbies,cntinplace',cntlift') =
              (case LiftSet.rtnVarN(fieldset,code1)
                of (fs,NONE)   => (fs,lhs1,[code1],cntinplace+1,cntlift)
                 | (fs,SOME v) => (fs,v,[],cntinplace,cntlift+1)
              (*end case*))

        (* probe that tensor at the constant position c1 *)
        val Re = E.Tensor(id,c1::tshape)
        val Rparams = params@[E.TEN(1,index1)]
        val Rargs = argsOrig@[lhs0]
    in
        (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')
    end


    (* lift : name * ein_exp * params * index * sum_ids * args
     *          * fieldset * cntinplace * cntlift
     *        -> ein_exp * params * args * code
     *          * fieldset * cntinplace * cntlift
     *
     * Lifts the subexpression e out of its enclosing EIN operator:
     *   - cleans the index and params of the subexpression,
     *   - appends a new tensor parameter to params,
     *   - returns Re = E.Tensor(id,tshape), the term that replaces e, where
     *     id is the position of the new parameter.
     *
     * When numFlag is 1 the lifted application is looked up in fieldset via
     * LiftSet.rtnVar:
     *   NONE    -> the new binding is returned in code; cntinplace+1
     *   SOME v  -> v is used as the argument, code is []; cntlift+1
     * For any other numFlag the new binding is always returned and the
     * fieldset and counters are passed through unchanged.
     *)
    fun lift(name,e,params,index,sx,args,fieldset,cntinplace,cntlift) = let
        val (tshape,sizes,body) = cleanIndex(e,index,sx)
        val id = length params
        val Rparams = params@[E.TEN(1,sizes)]
        val Re = E.Tensor(id,tshape)
        val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
        val einapp as (_,einapp0) = cleanParams(M,body,Rparams,sizes,args@[M])

        val (Rargs,newbies,fieldset',cntinplace',cntlift') = (case numFlag
            of 1 => (case LiftSet.rtnVar(fieldset,M,einapp0)
                of (fs,NONE) => let
                    (* shape check only: raises Bind unless the cleaned
                     * right-hand side is an EINAPP *)
                    val MidIL.EINAPP _ = einapp0
                    in
                        (args@[M],[einapp],fs,cntinplace+1,cntlift)
                    end
                 | (fs,SOME v) => (args@[v],[],fs,cntinplace,cntlift+1)
                (*end case*))
             | _ => (args@[M],[einapp],fieldset,cntinplace,cntlift)
            (*end case*))
        in
            (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')
        end


(*
fun ff(name,e,params,index,sx,args,fieldset,cntinplace,cntlift,newvx)=let
    val (tshape,sizes,body)=cleanIndex(e,index,sx)
    in (case body
of    E.Probe(E.Conv(_,[E.C _ ],_,[]),pos)
=> liftFieldVec (0,e,fieldset)
| E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))
=> liftFieldVec (1,e,fieldset)
| E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))
=> liftFieldVec (2,e,fieldset)
| E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))
=> liftFieldVec (3,e,fieldset)

*)
    (* liftfields: pull probe subterms out of the body of an EIN application.
     *
     * Takes a binding (y, EINAPP(ein0,args0)) and walks the body of ein0.
     * Every probe subterm found by the walk is replaced by a tensor
     * parameter:
     *   - a probe whose convolution has a single constant index goes
     *     through cut (the new variable index id is length dx),
     *   - any other probe, and a summation directly over a probe, goes
     *     through lift.
     * A body that is itself a probe, or a summation directly over a probe,
     * is left untouched (see scan).
     *
     * Returns (einapp,code): the rewritten binding for y after a final
     * cleanParams pass, and the list of bindings produced by lift/cut.
     * Also prints the in-place/lifted counters.
     *
     * The walk threads data = (params,args,code,fieldset,cntinplace,cntlift)
     * left to right through the expression.  sx holds the summation indices
     * of the enclosing E.Sum terms at the current position.
     *)
    fun liftfields(y,DstIL.EINAPP(ein0,args0)) = let
        val sx = ref []
        val index = Ein.index ein0

        (* replace b via lift and append the bindings it produced *)
        fun liftProbe (b,(params,args,code,fieldset,cntinplace,cntlift)) = let
              val (body',params',args',code',fieldset',cntinplace',cntlift') =
                    lift("probe",b,params,index,!sx,args,fieldset,cntinplace,cntlift)
              in
                (body',(params',args',code@code',fieldset',cntinplace',cntlift'))
              end

        (* replace b via cut, using newvx as the new variable index *)
        fun cutProbe (b,newvx,(params,args,code,fieldset,cntinplace,cntlift)) = let
              val (body',params',args',code',fieldset',cntinplace',cntlift') =
                    cut("cut",b,params,index,!sx,args,fieldset,cntinplace,cntlift,newvx)
              in
                (body',(params',args',code@code',fieldset',cntinplace',cntlift'))
              end

        fun rewrite (b,data) = (case b
            of E.Const _        => (b,data)
            | E.Tensor _        => (b,data)
            | E.Field _         => (b,data)
            | E.Krn _           => (b,data)
            | E.Delta  _        => (b,data)
            | E.Value  _        => (b,data)
            | E.Epsilon  _      => (b,data)
            | E.Eps2  _         => (b,data)
            | E.Partial _       => (b,data)
            | E.Apply _         => (b,data)
            | E.Conv  _         => (b,data)
            | E.Img  _          => (b,data)
            | E.PowInt  _       => (b,data)
            | E.PowReal  _      => (b,data)
            | E.ConstR  _       => (b,data)
            | E.Neg e1          => unary (E.Neg,e1,data)
            | E.Lift e1         => unary (E.Lift,e1,data)
            | E.Sqrt e1         => unary (E.Sqrt,e1,data)
            | E.Cosine e1       => unary (E.Cosine,e1,data)
            | E.ArcCosine e1    => unary (E.ArcCosine,e1,data)
            | E.Sine e1         => unary (E.Sine,e1,data)
            | E.ArcSine e1      => unary (E.ArcSine,e1,data)
            (* constant-index probe: the new variable index is numbered
             * after the derivative indices, i.e. length dx *)
            | E.Probe(E.Conv(_,[E.C _],_,dx),_) => cutProbe (b,length dx,data)
            | E.Probe _         => liftProbe (b,data)
            (* must precede the general E.Sum case *)
            | E.Sum(_,E.Probe _) => liftProbe (b,data)
            | E.Sub(e1,e2)      => binary (E.Sub,e1,e2,data)
            | E.Div(e1,e2)      => binary (E.Div,e1,e2,data)
            | E.Add es          => nary (E.Add,es,data)
            | E.Prod es         => nary (E.Prod,es,data)
            | E.Sum(sx1,e1)     => let
                (* extend the summation context for the body, then restore *)
                val outer = !sx
                val () = sx := sx1@outer
                val (e1',data') = rewrite (e1,data)
                in
                    sx := outer;
                    (E.Sum(sx1,e1'),data')
                end
            (*end case*))

        (* rewrite the operand of a one-argument constructor *)
        and unary (mk,e1,data) = let
              val (e1',data') = rewrite (e1,data)
              in
                (mk e1',data')
              end

        (* rewrite both operands, left first *)
        and binary (mk,e1,e2,data) = let
              val (e1',data1) = rewrite (e1,data)
              val (e2',data2) = rewrite (e2,data1)
              in
                (mk (e1',e2'),data2)
              end

        (* rewrite a list of operands left to right; the results are
         * collected in reverse and flipped once at the end *)
        and nary (mk,es,data) = let
              fun step (e,(acc,d)) = let
                    val (e',d') = rewrite (e,d)
                    in
                      (e'::acc,d')
                    end
              val (revEs,data') = List.foldl step ([],data) es
              in
                (mk (List.rev revEs),data')
              end

        (* a body that is already just a probe (possibly under one sum)
         * is returned unchanged *)
        fun scan (body as E.Probe _,data) = (body,data)
          | scan (body as E.Sum(_,E.Probe _),data) = (body,data)
          | scan e = rewrite e

        val data0 = (Ein.params ein0,args0,[],LiftSet.LiftSet.empty,0,0)
        val (body',(params',args',code',_,cntinplace',cntlift')) =
              scan (Ein.body ein0,data0)
        val _ = print(String.concat["\n in place",Int.toString(cntinplace'),"- l",Int.toString(cntlift')])
        val einapp = cleanParams(y,body',params',index,args')
        in
            (einapp,code')
        end


    (* testLift: run liftfields on a binding and return its result
     * (rewritten binding, list of lifted bindings) unchanged.
     * When the number of lifted bindings is not 1, testp is called with a
     * short report; its result is ignored. *)
    fun testLift e1 = let
        val result as (_,lifted) = liftfields e1
        val n = length lifted
        val _ = if n = 1
                  then 1
                  else testp ["-- Returning :", Int.toString n]
        in
            result
        end
end; (* local *)

end (* structure LiftEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0