Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2838 - (download) (annotate)
Tue Nov 25 03:40:24 2014 UTC (4 years, 8 months ago) by cchiw
File size: 12050 byte(s)
edit split-ein
(* Currently under construction 
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)




structure Split = struct

    local
   
    structure E = Ein
    structure mk= mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P=Printer
    structure F=Filter
    structure T=TransformEin
    structure Var = MidIL.Var
    structure cleanP=cleanParams
    structure cleanI=cleanIndex

    (* Debug switch read by testp and gettest: 0 silences their diagnostic
     * printing, any other value enables it.
     *)
    val testing=1
    in
 

    (* setEin: bundles params, index and body into an EIN operator *)
    fun setEin (params, index, body) =
          E.EIN{params = params, index = index, body = body}

    (* assignEinApp: pairs y with the application of a freshly built EIN operator to args *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP (setEin (params, index, body), args))

    (* setEinZero: same as assignEinApp, with the body fixed to the constant 0 *)
    fun setEinZero (y, params, index, args) =
          assignEinApp (y, params, index, E.Const 0, args)

    (* short local names for the parameter/index cleaning passes *)
    val cleanParams = cleanP.cleanParams
    val cleanIndex = cleanI.cleanIndex

    (* itos: int -> string *)
    val itos = Int.toString

    (* err: aborts by raising Fail with the given message *)
    fun err msg = raise (Fail msg)
    (* counter backing genName; advanced once per generated name *)
    val cnt = ref 0

    (* genName: string -> string
     * Returns prefix ^ "_" ^ <current counter value> and then bumps the
     * counter, so successive calls yield distinct names.
     *)
    fun genName prefix = let
        val serial = !cnt
        val () = cnt := serial + 1
        in
          prefix ^ "_" ^ Int.toString serial
        end
    (* testp: string list -> int
     * Prints the concatenation of n unless testing is 0; always returns 1.
     *)
    fun testp n =
          if testing = 0
            then 1
            else (print (String.concat n); 1)

    (* printEINAPP: var * rhs -> string
     * Renders one assignment as a line of text.  EINAPP and OP right-hand
     * sides show the variable's type, its name, the operator and the
     * comma-separated arguments; any other right-hand side is only tagged
     * as "non-einapp".
     *)
    fun printEINAPP (id, rhs) = let
        fun argList vs = String.concatWith " , " (List.map Var.toString vs)
        fun ty () = DstTy.toString (Var.ty id)
        fun nm () = Var.toString id
        in
          case rhs
           of DstIL.EINAPP(rator, args) => let
                val shown = argList args
                in
                  String.concat [ty (), "<", nm (), "> ==", P.printerE rator, shown, "\n"]
                end
            | DstIL.OP(rator, args) => let
                val shown = argList args
                in
                  String.concat [ty (), "<", nm (), "> =", DstOp.toString rator, shown, "\n"]
                end
            | _ => String.concat [nm (), "<", ty (), "> non-einapp\n"]
          (* end case *)
        end


    (* mkreplacement: (params, args, tshape, sizes, body)
     *                  -> (ein_exp, params, args, code)
     * Builds the stand-in for a lifted sub-expression:
     *   - appends a new tensor parameter E.TEN(1,sizes) to params,
     *   - appends a fresh variable of type TensorTy sizes to args,
     *   - returns E.Tensor(slot,tshape) as the replacement expression, where
     *     slot is the position of the new parameter,
     *   - returns, as a one-element code list, the assignment produced by
     *     cleanParams for the fresh variable and the lifted body.
     *)
    fun mkreplacement (params, args, tshape, sizes, body) = let
        (* the new parameter goes at the end, so its slot is the old length *)
        val slot = List.length params
        val extParams = params @ [E.TEN (1, sizes)]
        val replacement = E.Tensor (slot, tshape)
        val lifted = DstV.new (genName ("TLifted_" ^ itos slot), DstTy.TensorTy sizes)
        val extArgs = args @ [lifted]
        val stmt = cleanParams (lifted, body, extParams, sizes, extArgs)
        in
          (replacement, extParams, extArgs, [stmt])
        end

    (* lift: (ein_exp, params, index, sum_ids, args)
     *         -> (ein_exp, params, args, code)
     * Runs cleanIndex on the expression to obtain its tensor shape, sizes
     * and cleaned body, then hands those to mkreplacement, which yields the
     * replacement tensor together with the extended params/args and the
     * code for the lifted piece.
     *)
    fun lift (e, params, index, sx, args) = let
        val (tshape, sizes, body) = cleanIndex (e, index, sx)
        in
          mkreplacement (params, args, tshape, sizes, body)
        end

    (* simplelift: (ein_exp, params, index, args)
     *               -> (ein_exp, params, args, code)
     * lift with an empty list of summation indices.
     *)
    fun simplelift (e, params, index, args) = lift (e, params, index, [], args)

           
    
    (* isOp: ein_exp -> int
     * Classifies a sub-expression for simpleOp and prodOps:
     *   0 - Field, Conv, Apply, Lift: the callers replace it with E.Const 0
     *   1 - Neg, Add, Sub, Prod, Div, Sum, Probe: an operator; the callers
     *       lift it out into its own assignment
     *   2 - anything else: the callers leave it in place
     * Raises Fail for Partial, Krn, Value and Img, which are not expected
     * at this stage.
     *)
    fun isOp e =(case e
        of E.Field _  => 0
        | E.Conv _    => 0
        | E.Apply _   => 0
        | E.Lift _    => 0
        | E.Neg _     => 1
        | E.Add _     => 1
        | E.Sub _     => 1
        | E.Prod _    => 1
        | E.Div _     => 1
        | E.Sum _     => 1
        | E.Probe _   => 1
        | E.Partial _ => err(" Partial used after normalize")
        | E.Krn _     => err("Krn used before expand")
        | E.Value _   => err("Value used before expand")
        (* message names the constructor actually matched (was "Probe") *)
        | E.Img _     => err("Img used before expand")
        | _           => 2
    (*end case*))


    (* simpleOp: (ein_exp, params, index, args) -> (ein_exp, params, args, code)
     * Dispatches on the isOp classification of e1:
     *   0     - e1 is replaced by E.Const 0; params/args untouched, no code
     *   2     - e1 is returned unchanged; params/args untouched, no code
     *   other - e1 is lifted out via simplelift
     *)
    fun simpleOp (e1, params, index, args) = let
        val tag = isOp e1
        in
          if tag = 0 then (E.Const 0, params, args, [])
          else if tag = 2 then (e1, params, args, [])
          else simplelift (e1, params, index, args)
        end


    (* simpleOps: (ein_exp list, params, index, args)
     *              -> (ein_exp list, params, args, code)
     * Applies simpleOp to every expression from left to right, threading the
     * growing params/args through the calls and collecting the rewritten
     * expressions and the generated code in order.
     *)
    fun simpleOps (list1, params, index, args) = let
        fun step (e, (done, ps, xs, code)) = let
            val (e', ps', xs', code') = simpleOp (e, ps, index, xs)
            in
              (done @ [e'], ps', xs', code @ code')
            end
        in
          List.foldl step ([], params, args, []) list1
        end

    (* prodOps: (ein_exp list, params, index, sum_ids, args)
     *            -> (ein_exp list, params, args, code)
     * Walks the factors left to right.  Using the isOp classification, a
     * factor is replaced by E.Const 0 (code 0), kept as is (code 2), or
     * lifted out with lift using the summation indices sx (any other code).
     * params/args are threaded through so every lift sees the parameters
     * added by the previous ones.
     *)
    fun prodOps(list1,params,index,sx,args)=let
        fun m([],rest,params,args,code)=(rest,params,args,code)
          | m(e1::es,rest,params,args,code)=(case (isOp e1)
            of 0 => m(es,rest@[E.Const 0],params,args,code)
             | 2 => m(es,rest@[e1],params,args,code)
             (* wildcard keeps the match exhaustive, mirroring simpleOp *)
             | _ => let
                val (e1',params',args',code')= lift(e1,params,index,sx,args)
                in
                    m(es,rest@[e1'],params',args',code@code')
                end
            (*end case*))
        in
            m(list1,[],params,args,[])
        end

           
    (* isZero: (var, ein_exp, params, index, args) -> var * einapp
     * When cleanI.isZero reports 1 for the body, y is bound to a constant-0
     * EIN with an empty index list; otherwise the assignment is built by
     * cleanParams from the given body.
     * NOTE(review): the zero case drops `index` on purpose or by accident?
     * split does the same for its own zero result -- confirm.
     *)
    fun isZero (y, body, params, index, args) =
          if cleanI.isZero body = 1
            then setEinZero (y, params, [], args)
            else cleanParams (y, body, params, index, args)

    (* handleNeg: (var, ein_exp, params, index, args) -> (var * einapp) * code
     * Simplifies the operand with simpleOp, negates the result and builds
     * the assignment for y through isZero.  The code list holds whatever
     * simpleOp lifted out.
     *)
    fun handleNeg (y, e1, params, index, args) = let
        val (operand, ps, xs, lifted) = simpleOp (e1, params, index, args)
        in
          (isZero (y, E.Neg operand, ps, index, xs), lifted)
        end

   (* handleSub: (var, ein_exp, ein_exp, params, index, args)
    *              -> (var * einapp) * code
    * Simplifies both operands with simpleOp (left first, threading the
    * params/args it produces into the right one), rebuilds the subtraction
    * and builds the assignment for y through isZero.
    *)
    fun handleSub(y,e1,e2,params,index,args)=let
        (* two explicit calls instead of matching simpleOps' result against a
         * two-element list pattern, which was a non-exhaustive binding *)
        val (e1',params1,args1,code1)= simpleOp(e1,params,index,args)
        val (e2',params2,args2,code2)= simpleOp(e2,params1,index,args1)
        val body =E.Sub(e1',e2')
        val einapp= isZero(y,body,params2,index,args2)
    in
        (einapp,code1@code2)
    end

    (* handleDiv: (var, ein_exp, ein_exp, params, index, args)
     *              -> (var * einapp) * code
     * Simplifies numerator and denominator with simpleOp, numerator first.
     * The denominator is simplified against an empty index list and sees the
     * params/args produced by the numerator.  The quotient is then turned
     * into the assignment for y through isZero.
     *)
    fun handleDiv (y, e1, e2, params, index, args) = let
        val (num, psN, xsN, codeN) = simpleOp (e1, params, index, args)
        val (den, psD, xsD, codeD) = simpleOp (e2, psN, [], xsN)
        in
          (isZero (y, E.Div (num, den), psD, index, xsD), codeN @ codeD)
        end

           
    (* handleAdd: (var, ein_exp list, params, index, args)
     *              -> (var * einapp) * code
     * Simplifies every summand with simpleOps, rebuilds the addition and
     * builds the assignment for y through isZero.
     *)
    fun handleAdd (y, e1, params, index, args) = let
        val (terms, ps, xs, lifted) = simpleOps (e1, params, index, args)
        in
          (isZero (y, E.Add terms, ps, index, xs), lifted)
        end

    (* handleProd: (var, ein_exp list, params, index, args)
     *               -> (var * einapp) * code
     * Processes the factors with prodOps (no summation indices), rebuilds
     * the product and builds the assignment for y through isZero.
     *)
    fun handleProd (y, e1, params, index, args) = let
        val (factors, ps, xs, lifted) = prodOps (e1, params, index, [], args)
        in
          (isZero (y, E.Prod factors, ps, index, xs), lifted)
        end

   (* handleSumProd: (var, ein_exp list, params, index, sum_ids, args)
    *                  -> (var * einapp) * code
    * Handles a summation over a product: the factors are processed by
    * prodOps with the summation indices sx, the result is wrapped back into
    * E.Sum(sx, E.Prod ...) and turned into the assignment for y by isZero.
    *)
    fun handleSumProd(y,e1,params,index,sx,args)=let
        val (e1',params',args',code)=  prodOps(e1,params,index,sx,args)
        val body= E.Sum(sx,E.Prod e1')
        val einapp= isZero(y,body,params',index,args')
    in
        (einapp,code)
    end

    (* split: var * rhs -> (var * rhs) * code
     * Performs one splitting step on an assignment.  For an EINAPP
     * right-hand side the body is inspected:
     *   - Conv, Field, Apply, Lift and Sum over Conv become a constant-0 EIN;
     *   - Probe, Delta, Epsilon, Tensor, Const and the remaining Sum forms
     *     are returned unchanged;
     *   - Neg, Sub, Div, Add, Prod and Sum over Prod go to the matching
     *     handle* function, which may lift operands into new assignments;
     *   - Sum over Neg/Add/Sub/Div is pushed inside the operator and nested
     *     Sums are merged, after which the result is rewritten again;
     *   - Partial, Krn, Value and Img raise Fail.
     * The second component is the list of new assignments produced by
     * lifting.  Any non-EINAPP right-hand side is returned unchanged with no
     * new code.
     *)
    fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
        val zero=   (setEinZero(y,params,[],args),[])
        val default=((y,einapp),[])
        (* the order of the Sum cases matters: specific bodies come before
         * the catch-all E.Sum(sx,_) *)
        fun rewrite b=(case b
            of E.Probe _              => default
            | E.Conv _                => zero
            | E.Field _               => zero
            | E.Apply _               => zero
            | E.Lift _                => zero
            | E.Delta _               => default
            | E.Epsilon _             => default
            | E.Tensor _              => default
            | E.Const _               => default
            | E.Neg e1                => handleNeg(y,e1,params,index,args)
            | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)
            | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)
            | E.Sum(_,E.Probe _)      => default
            | E.Sum(_,E.Conv _)       => zero
            | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)
            | E.Sum(sx,E.Neg n)       => rewrite (E.Neg(E.Sum(sx,n)))
            | E.Sum(sx,E.Add a)       => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))
            | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
            (* NOTE(review): this also wraps the denominator in the Sum; a sum
             * of quotients is not in general the quotient of the sums --
             * confirm this is intended for the denominators that occur here *)
            | E.Sum(sx,E.Div(e1,e2))  => rewrite(E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))
            | E.Sum(c1, E.Sum (c2,e)) => rewrite (E.Sum (c1@c2,e))
            | E.Sum(sx,_)             => default
            | E.Add e1                => handleAdd(y,e1,params,index,args)
            | E.Prod e1               => handleProd(y,e1,params,index,args) 
            | E.Partial _             => err(" Partial used after normalize")
            | E.Krn _                 => err("Krn used before expand")
            | E.Value _               => err("Value used before expand")
            (* message names the constructor actually matched (was "Probe") *)
            | E.Img _                 => err("Img used before expand")
            (*end case *))
        in
            rewrite body
        end
        |split(y,app) =((y,app),[])

           
    (* iterMultiple: (var * rhs) * code -> (var * rhs) * code
     * Applies split to every assignment in newbies2 and, recursively, to the
     * assignments each of those splits produces.  einapp2 is passed through
     * untouched.  In the resulting code list, the assignments produced while
     * splitting come before the (split) assignments of newbies2 themselves.
     *)
    fun iterMultiple (einapp2, newbies2) = let
        (* done: split statements of the current list, in order
         * pending: everything lifted out while splitting them *)
        fun itercode ([], done, pending) = (done, pending)
          | itercode (stmt :: stmts, done, pending) = let
              val (stmt', fresh) = split stmt
              val (freshDone, freshPending) = itercode (fresh, [], [])
              in
                itercode (stmts, done @ [stmt'], freshPending @ freshDone @ pending)
              end
        val (done, pending) = itercode (newbies2, [], [])
        in
          (einapp2, pending @ done)
        end


    (* iterSplit: var * rhs -> (var * rhs) * code
     * Splits the assignment once, then lets iterMultiple keep splitting the
     * assignments that step produced.
     *)
    fun iterSplit (y, einapp) = iterMultiple (split (y, einapp))


 
    (* gettest: var * rhs -> (var * rhs) * code
     * Entry point around iterSplit.  With testing = 0 it only forwards to
     * iterSplit; otherwise it also prints the original assignment, the
     * rewritten one and the newly created assignments between separator
     * lines.
     *)
    fun gettest einapp =
          if testing = 0
            then iterSplit einapp
            else let
              val star = "\n*************\n"
              val () = print star
              val result as (einapp2, newbies) = iterSplit einapp
              val a = printEINAPP einapp2
              val b = String.concatWith ",\n\t" (List.map printEINAPP newbies)
              in
                print (String.concat [printEINAPP einapp, "=>", a, " newbies\n\t", b, "\n", a, star]);
                result
              end


  end; (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0