Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/lift-ein.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/lift-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3559 - (download) (annotate)
Fri Jan 8 20:24:06 2016 UTC (3 years, 7 months ago) by cchiw
File size: 13694 byte(s)
added hard limit to float size
(* LiftEin: lifts field terms out of a MidIL EIN application.
 *
 * liftfields (y, EINAPP(ein, args)) walks the body of ein and replaces probe
 * sub-terms with references to fresh tensor parameters; while the amount of
 * emitted code is at most codeLimit, operator operands are lifted the same way.
 * Every lifted term becomes its own binding in the returned code list.
 * With numFlag = 1, identical lifted terms are shared through einSet instead of
 * being emitted twice.
 *)
structure LiftEin = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    in

    (* 1 = remove common subexpressions among lifted terms (via einSet). *)
    val numFlag = 1
    (* Nonzero = print the trace messages passed to testp. *)
    val testing = 0

    fun mkEin e = E.mkEin e
    (* EIN application of the scalar constant 0, with no arguments. *)
    val einappzero = DstIL.EINAPP(mkEin([],[],E.B(E.Const 0)),[])
    fun setEinZero y = (y,einappzero)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e
    fun mkEinApp (rator,args) = DstIL.EINAPP(rator,args)
    fun itos i = Int.toString i
    fun err str = raise Fail str

    (* Counter that keeps the names produced by genName unique. *)
    val cnt = ref 0
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

    (* genName prefix: returns "prefix_N" for the current counter N and bumps it. *)
    fun genName prefix = let
        val n = !cnt
        in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
        end

    (* testp msgs: when testing is nonzero, prints the concatenation of msgs.
     * Always returns 1 so it can be used in value position.
     *)
    fun testp n = (case testing
        of 0 => 1
         | _ => (print (String.concat n); 1)
        (*end case*))

    (* cut (name, e, params, index, sx, argsOrig, fieldset, cntinplace, cntlift, newvx)
     * Lifts e, a probe whose field has a single index c1 (liftfields calls this
     * for E.Probe(E.Conv(_,[E.C _],_,dx),_)).  The lifted EIN replaces c1 by the
     * index variable E.V newvx and appends an axis of size 3 to its index, so
     * the lifted tensor ranges over that axis; e itself is replaced by
     * E.Tensor(id,[c1]@tshape), which selects c1 from the new parameter id.
     * liftfields passes newvx = length dx, the first index variable after the
     * derivative indices.
     * Returns (replacement, params', args', newCode, fieldset', cntinplace', cntlift'):
     *  - if einSet already holds an identical term, its variable is used as the
     *    new argument, newCode = [] and cntlift is incremented;
     *  - otherwise newCode = [L = EINAPP(...)] and cntinplace is incremented.
     * Raises Bind if cleanParams does not yield an EINAPP whose body is a probe.
     * NOTE(review): the size-3 axis hard-codes a 3-component field -- TODO confirm.
     *)
    fun cut(name,e,params,index,sx,argsOrig,fieldset,cntinplace,cntlift,newvx) = let
        (* Clean and rewrite the current body. *)
        val (tshape,sizes,body) = cleanIndex(e,index,sx)
        val id = length params
        val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
        val (_, DstIL.EINAPP(ein,args)) =
              cleanParams(M,body,params@[E.TEN(1,sizes)],sizes,argsOrig@[M])

        (* Shift the field index in the probe body from the constant c1 to E.V newvx. *)
        val E.Probe(E.Conv(V,[c1],h,dx),pos) = Ein.body ein
        val index1 = Ein.index ein @ [3]
        val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
        (* Clean to get the body indices in order. *)
        val (_,_,body1) = cleanIndex(body1_unshifted,index1,[])
        val lhs1 = DstV.new ("L", DstTy.TensorTy index1)
        val ein1 = mkEin(Ein.params ein,index1,body1)
        val code1 = (lhs1,mkEinApp(ein1,args))

        (* Reuse an identical lifted term if einSet already has one. *)
        val (fieldset',cntinplace',cntlift',Rargs,newbies) =
              (case einSet.rtnVarN(fieldset,code1)
                of (fs,NONE)   => (fs,cntinplace+1,cntlift,argsOrig@[lhs1],[code1])
                 | (fs,SOME v) => (fs,cntinplace,cntlift+1,argsOrig@[v],[])
              (*end case*))

        (* The original term becomes entry c1 of the new tensor parameter. *)
        val Re = E.Tensor(id,[c1]@tshape)
        val Rparams = params@[E.TEN(1,index1)]
        in
            (Re,Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')
        end

    (* lift (name, e, params, index, sx, args, fieldset, cntinplace, cntlift)
     * Lifts e into its own EIN application bound to a fresh variable M (named
     * from name), after cleaning its index (cleanIndex) and params (cleanParams).
     * e is replaced by E.Tensor(id,tshape), where id is the parameter appended
     * to params.
     * Returns (replacement, params', args', newCode, fieldset', cntinplace', cntlift'):
     *  - numFlag = 1 and einSet already holds an identical term: that variable
     *    is the new argument, newCode = [] and cntlift is incremented;
     *  - numFlag = 1 otherwise: M is the new argument, newCode = [M = einapp]
     *    and cntinplace is incremented;
     *  - numFlag <> 1: M is the new argument, newCode = [M = einapp], and
     *    fieldset and the counters are returned unchanged.
     *)
    fun lift(name,e,params,index,sx,args,fieldset,cntinplace,cntlift) = let
        val (tshape,sizes,body) = cleanIndex(e,index,sx)
        val id = length params
        val Rparams = params@[E.TEN(1,sizes)]
        val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
        val einapp as (_,einapp0) = cleanParams(M,body,Rparams,sizes,args@[M])
        val (Rargs,newbies,fieldset',cntinplace',cntlift') =
              if numFlag = 1
                then (case einSet.rtnVar(fieldset,M,einapp0)
                       of (fs,NONE)   => (args@[M],[einapp],fs,cntinplace+1,cntlift)
                        | (fs,SOME v) => (args@[v],[],fs,cntinplace,cntlift+1)
                     (*end case*))
                else (args@[M],[einapp],fieldset,cntinplace,cntlift)
        in
            (E.Tensor(id,tshape),Rparams,Rargs,newbies,fieldset',cntinplace',cntlift')
        end

    (* isOp e: true for the term kinds filterOps lifts (Op1, Op2, Opn, Sum, Probe). *)
    fun isOp e = (case e
        of E.Op1 _   => true
         | E.Op2 _   => true
         | E.Opn _   => true
         | E.Sum _   => true
         | E.Probe _ => true
         | _         => false
        (*end case*))

    (* filterOps (es, rest, data, index, sx)
     * Walks es left to right, appending each expression to rest; a term for
     * which isOp holds is first lifted with lift and replaced by its tensor
     * reference.  data is (params, args, code, fieldset, cntinplace, cntlift),
     * and the code of each lift is appended to code.  Returns (rest, data).
     *)
    fun filterOps ([],rest,data,_,_) = (rest,data)
      | filterOps (e3::es,rest,data,index,sx) =
          if isOp e3
            then let
                val (params3,args3,code3,fieldset3,cntinplace3,cntlift3) = data
                val (body4,params4,args4,code4,fieldset4,cntinplace4,cntlift4) =
                      lift("op1_e3_",e3,params3,index,sx,args3,fieldset3,cntinplace3,cntlift3)
                val data4 = (params4,args4,code3@code4,fieldset4,cntinplace4,cntlift4)
                in
                    filterOps(es,rest@[body4],data4,index,sx)
                end
            else filterOps(es,rest@[e3],data,index,sx)

    (* Upper bound on the number of lifted bindings; see isLimit. *)
    val codeLimit = 4000

    (* isLimit data: true while the code list in data has at most codeLimit
     * bindings; beyond that, liftfields stops lifting operator operands. *)
    fun isLimit (_,_,code,_,_,_) = length code <= codeLimit

    (* liftfields (y, EINAPP(ein0, args0))
     * Rewrites the body of ein0, lifting field terms out with cut and lift.
     * Returns (einapp, code, fieldset): einapp is the cleaned (cleanParams)
     * rewrite of y = EINAPP(ein0, args0), code the lifted bindings in the
     * order they were emitted, and fieldset the einSet of lifted terms.
     * A body that is itself a probe, or a sum of a probe, is not rewritten.
     *)
    fun liftfields (y,DstIL.EINAPP(ein0,args0)) = let
        (* Summation indices bound by the E.Sum terms enclosing the current term. *)
        val sx = ref []
        val index = Ein.index ein0

        (* Runs one lifting step f on the state in data; appends the code it emits. *)
        fun liftWith (data,f) = let
            val (params,args,code,fieldset,cntinplace,cntlift) = data
            val (body',params',args',code',fieldset',cntinplace',cntlift') =
                  f (params,args,fieldset,cntinplace,cntlift)
            in
                (body',(params',args',code@code',fieldset',cntinplace',cntlift'))
            end

        fun rewrite (b,data) = (case b
            of E.B _       => (b,data)
             | E.Tensor _  => (b,data)
             | E.Field _   => (b,data)
             | E.Krn _     => (b,data)
             | E.G _       => (b,data)
             | E.Value _   => (b,data)
             | E.Partial _ => (b,data)
             | E.Apply _   => (b,data)
             | E.Conv _    => (b,data)
             | E.Img _     => (b,data)
             | E.Lift e1 => let
                 val (e1',data') = rewrite (e1,data)
                 in
                     (E.Lift e1',data')
                 end
             | E.Probe(E.Conv(_,[E.C _],_,dx),_) =>
                 (* Constant field index: cut it, using the first index variable
                  * after the derivative indices dx. *)
                 liftWith (data, fn (params,args,fieldset,cntinplace,cntlift) =>
                     cut("cut",b,params,index,!sx,args,fieldset,cntinplace,cntlift,length dx))
             | E.Probe _ => let
                 val lifted as (body',_) =
                       liftWith (data, fn (params,args,fieldset,cntinplace,cntlift) =>
                           lift("probe",b,params,index,!sx,args,fieldset,cntinplace,cntlift))
                 (* NOTE(review): looks like a debugging leftover; confirm whether
                  * Printer.printbody has side effects before removing it. *)
                 val _ = P.printbody body'
                 in
                     lifted
                 end
             | E.Sum(_,E.Probe _) =>
                 liftWith (data, fn (params,args,fieldset,cntinplace,cntlift) =>
                     lift("probe",b,params,index,!sx,args,fieldset,cntinplace,cntlift))
             | E.Op1(op1,e1) => let
                 val ([e1'],data') = rewriteOperands ([e1],data)
                 in
                     (E.Op1(op1,e1'),data')
                 end
             | E.Op2(op2,e1,e2) => let
                 val ([e1',e2'],data') = rewriteOperands ([e1,e2],data)
                 in
                     (E.Op2(op2,e1',e2'),data')
                 end
             | E.Opn(opn,es) => let
                 val (es',data') = rewriteOperands (es,data)
                 in
                     (E.Opn(opn,es'),data')
                 end
             | E.Sum(sx1,e1) => let
                 (* Bind sx1 while rewriting the summand, then restore. *)
                 val outer = !sx
                 val _ = sx := sx1 @ outer
                 val (e1',data') = rewrite (e1,data)
                 in
                     sx := outer;
                     (E.Sum(sx1,e1'),data')
                 end
            (*end case*))

        (* Rewrites the operands es left to right, threading data through; if
         * isLimit holds for the incoming data, the rewritten operator operands
         * are then lifted into tensor parameters by filterOps. *)
        and rewriteOperands (es,data) = let
            fun step (e,(acc,d)) = let
                val (e',d') = rewrite (e,d)
                in
                    (e'::acc,d')
                end
            val (revEs,data') = List.foldl step ([],data) es
            val es' = List.rev revEs
            in
                if isLimit data
                    then filterOps(es',[],data',index,!sx)
                    else (es',data')
            end

        (* A top-level probe, or sum of a probe, is left as is. *)
        fun scan (E.Probe p,data) = (E.Probe p,data)
          | scan (E.Sum(sx1,E.Probe p),data) = (E.Sum(sx1,E.Probe p),data)
          | scan e = rewrite e

        val data = (Ein.params ein0,args0,[],einSet.EinSet.empty,0,0)
        val (body',(params',args',code',fieldset',cntinplace',cntlift')) =
              scan (Ein.body ein0,data)
        val _ = testp ["\n in place",Int.toString(cntinplace'),"- l",Int.toString(cntlift')]
        val einapp = cleanParams(y,body',params',index,args')
        in
            (einapp,code',fieldset')
        end

    (* testLift e1: runs liftfields on e1 and, when the number of lifted
     * bindings is not 1, reports that count through testp. *)
    fun testLift e1 = let
        val result as (_,code,_) = liftfields e1
        val n = length code
        val _ = if n = 1 then 1 else testp ["-- Returning :", Int.toString n]
        in
            result
        end

    end (* local *)

end (* structure LiftEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0