Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml

Parent Directory | Revision Log


Revision 2554 - (download) (annotate)
Sun Mar 2 19:58:48 2014 UTC (7 years, 1 month ago) by cchiw
File size: 10799 byte(s)
Value numbering
(* Split EIN operators into simpler EIN definitions before the code-generation process. *)
structure splitHtM = struct
    local
    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure shift=shiftHtM
    structure P=Printer
    structure Var = MidIL.Var
    structure HVar = HighIL.Var
    in


(* Render a MidIL definition "id == <ein><args>" as a string for tracing.
 * The argument variables are joined with " , " and appended directly after
 * the printed EIN expression. *)
fun printA (id, e, arg) = let
    val argText = String.concatWith " , " (map Var.toString arg)
    in
        concat [Var.toString id, " ==", P.printerE e, argText]
    end

(* Same as printA, but the argument variables are HighIL variables
 * (printed with HVar.toString); id is still printed as a MidIL variable. *)
fun printAA (id, e, arg) = let
    val argText = String.concatWith " , " (map HVar.toString arg)
    in
        concat [Var.toString id, " ==", P.printerE e, argText]
    end


(* Build an EIN operator from its parameter list, index shape, and body. *)
fun createEin (params, index, body) =
    E.EIN {body = body, index = index, params = params}
(* Concatenate a list of lists, preserving element order. *)
fun flat xs = List.concat xs
(* Module-wide counter that fresh uses to number new variables. *)
val counter = ref 0

(* Create a new MidIL variable of type ty named "Q<k>", where k is the next
 * value of the module-wide counter; the counter is advanced after the
 * variable is created. *)
fun fresh ty = let
    val k = !counter + 1
    val v = DstIL.Var.new ("Q" ^ Int.toString k, ty)
    in
        counter := k; v
    end

(* Turn a lifted (var, subexpression) pair into a definition
 * (var, EIN, args) whose index shape is the given index.  shift.cleanParams
 * prunes/renumbers params and args for the subexpression. *)
fun createnewb (params, index, args, (id, e)) = let
    val (cleanParams, cleanBody, cleanArgs) = shift.cleanParams (e, params, args)
    in
        (id, createEin (cleanParams, index, cleanBody), cleanArgs)
    end

(* Like createnewb, but the index shape travels with the lifted item as a
 * (var, subexpression, index) triple, as produced by hProd. *)
fun createnewP (params, args, (id, e, ix)) = let
    val (cleanParams, cleanBody, cleanArgs) = shift.cleanParams (e, params, args)
    in
        (id, createEin (cleanParams, ix, cleanBody), cleanArgs)
    end

(* Returns 1 when e's outermost constructor is one of Neg, Add, Sub, Prod,
 * Div, or Sum (a compound operator the handlers may lift out), else 0. *)
fun findOp (E.Neg _) = 1
  | findOp (E.Add _) = 1
  | findOp (E.Sub _) = 1
  | findOp (E.Prod _) = 1
  | findOp (E.Div _) = 1
  | findOp (E.Sum _) = 1
  | findOp _ = 0



(* Outside operator is Neg: -e1.
 * If e1 is compound (findOp e1 <> 0), it is lifted into a fresh MidIL
 * variable q and replaced by a reference to a new tensor parameter
 * (numbered after the existing params, indexed by V 0 .. V (n-1)).
 * Returns (lifted definitions, (params, body, args)) after shift.cleanParams. *)
fun handleNeg(params, index,e1,args)=let
    val nextId = ref (length params)
    val dims = List.tabulate (length index, fn v => E.V v)
    fun newTensor () = let
        val k = !nextId
        in
            nextId := k + 1; E.Tensor (k, dims)
        end
    val (inner, lifted, extraParams, extraArgs) =
        if findOp e1 = 0 then (e1, [], [], [])
        else let
            val q = fresh (DstTy.TensorTy index)
            in
                (newTensor (), [(q, e1)], [E.TEN (1, index)], [q])
            end
    val cleaned = shift.cleanParams (E.Neg inner, params @ extraParams, args @ extraArgs)
    val defs = List.map (fn def => createnewb (params, index, args, def)) lifted
    in
        (defs, cleaned)
    end



(* Superseded draft of handleNeg's body, kept for reference only:
let
    val id=ref (length params)
        val n=length index
        val ix=List.tabulate (n,fn v=> E.V(v))
    fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end

    fun replace e'= let
        val q=fresh(DstTy.TensorTy(index))
        val t=mkTensor 0        
        val newbie=createEin( params,index,e')
        in ([(q,newbie,args)],(params@[E.TEN(1,index)], E.Neg t,[q]))
        end

    fun sort e1= (case e1
        of E.Add  _=> replace e1
        | E.Sub _=> replace e1
        | E.Prod _=> replace e1
        | E.Div _=> replace e1
        | E.Sum _=> replace e1
        | _=>([],params, E.Neg e1, args)
        (*end case*))

    val (newbies1,params1,lft1,args1)=sort e
    val (p',b',args')= shift.cleanParams(lft1,params@params1, args@args1)
    val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
    in
        (z1,(p',b',args'))
    end

*)




(* Outside operator is Add.
 * Nested Adds are flattened into the summand list.  Every other compound
 * summand (findOp <> 0) is lifted into a fresh MidIL variable and replaced by
 * a reference to a new tensor parameter; simple summands are kept in order.
 * Returns (lifted definitions, (params, body, args)) after shift.cleanParams. *)
fun handleAdd(params, index,list1,args)=let
    val nextId = ref (length params)
    val dims = List.tabulate (length index, fn v => E.V v)
    fun newTensor () = let
        val k = !nextId
        in
            nextId := k + 1; E.Tensor (k, dims)
        end

    (* Replace summand e with a parameter reference and record the lifted item. *)
    fun lift (e, (terms, lifted, ps, qs)) = let
        val q = fresh (DstTy.TensorTy index)
        in
            (terms @ [newTensor ()], lifted @ [(q, e)], ps @ [E.TEN (1, index)], qs @ [q])
        end

    fun walk ([], acc) = acc
      | walk (E.Add inner :: rest, acc) = walk (inner @ rest, acc)
      | walk (e :: rest, acc as (terms, lifted, ps, qs)) =
          if findOp e = 0
            then walk (rest, (terms @ [e], lifted, ps, qs))
            else walk (rest, lift (e, acc))

    val (terms, lifted, extraParams, extraArgs) = walk (list1, ([], [], [], []))
    val cleaned = shift.cleanParams (E.Add terms, params @ extraParams, args @ extraArgs)
    val defs = List.map (fn def => createnewb (params, index, args, def)) lifted
    in
        (defs, cleaned)
    end






(* Outside operator is Sub: e1 - e2.
 * Each compound operand (findOp <> 0) is lifted into a fresh MidIL variable
 * and replaced by a reference to a new tensor parameter; e1 is processed
 * before e2.  Returns (lifted definitions, (params, body, args)) after
 * shift.cleanParams. *)
fun handleSub(params, index,e1,e2,args)=let
    val nextId = ref (length params)
    val dims = List.tabulate (length index, fn v => E.V v)
    fun newTensor () = let
        val k = !nextId
        in
            nextId := k + 1; E.Tensor (k, dims)
        end

    fun liftIfCompound e =
        if findOp e = 0 then (e, [], [], [])
        else let
            val q = fresh (DstTy.TensorTy index)
            in
                (newTensor (), [(q, e)], [E.TEN (1, index)], [q])
            end

    val (l1, lifted1, ps1, qs1) = liftIfCompound e1
    val (l2, lifted2, ps2, qs2) = liftIfCompound e2
    val cleaned = shift.cleanParams (E.Sub (l1, l2), params @ ps1 @ ps2, args @ qs1 @ qs2)
    val defs = List.map (fn def => createnewb (params, index, args, def)) (lifted1 @ lifted2)
    in
        (defs, cleaned)
    end

(* Outside operator is Div: e1 / e2.
 * Each compound operand (findOp <> 0) is lifted into a fresh MidIL variable
 * and replaced by a reference to a new tensor parameter.  The numerator is
 * referenced with the full index V 0 .. V (n-1); the denominator is a scalar,
 * referenced with no indices, and its lifted definition is built with index [].
 * Returns (lifted definitions, (params, body, args)) after shift.cleanParams. *)
fun handleDiv(params, index,e1,e2,args)=let
    val id=ref (length params)
    val n=length index
    val ix=List.tabulate (n,fn v=> E.V(v))
    (* Numerator placeholder: indexed like the whole expression. *)
    fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
    (* Denominator placeholder: scalar reference. *)
    fun mkSca _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,[])) end

    (* shape is the index shape of the lifted operand's value.  It fixes both the
     * fresh variable's type and the new parameter's declared shape, so it must
     * agree with how nextfn references that parameter. *)
    fun divsort(e,shape,nextfn)= let
        val s=findOp e
        in (case s
            of 0=>(e,[],[],[])
            | _=> let
                val q=fresh(DstTy.TensorTy(shape))
                in (nextfn 0, [(q, e)],[E.TEN(1,shape)],[q]) end
            (*end case*))
        end

    val (lft1, newbies1,params1,args1)=divsort(e1,index,mkTensor)
    val (lft2, newbies2,params2,args2)=divsort(e2,[],mkSca)
    val (p',b',args')= shift.cleanParams(E.Div(lft1,lft2),params@params1@params2, args@args1@args2)

    val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
    (* The lifted denominator is a scalar definition, hence index []. *)
    val z2=List.map (fn(e)=> createnewb(params,[],args,e) ) newbies2
    in
        (z1@z2,(p',b',args'))
    end




(* Shared worker for handleProd and handleSumProd.
 * Walks the factors of a product, flattening nested Prods.  Every factor
 * whose outermost operator is Add, Sub, Div, Neg, or Sum is lifted into a
 * fresh MidIL variable and replaced by a reference to a new tensor parameter.
 * Other factors are kept in order.
 * Returns (factors, lifted, newParams, newArgs); lifted holds
 * (var, subexpression, index shape) triples consumed by createnewP.
 * Raises Fail on a Probe factor, which should already have been expanded. *)
fun hProd(params, index,list1,args)=let
    val id=ref (length params)
    val n=length index

    (* Allocate the next parameter id and build the reference that replaces e.
     * NOTE(review): shift.cleanIndex is opaque here; presumably it returns
     * (indices of the reference, index shape of the lifted subexpression,
     * reindexed subexpression) -- confirm against shiftHtM. *)
    fun mkPTensor e=let
        val ref idx= id
        val (ix,index',e')=shift.cleanIndex(e, n, index)
        in (id:=(idx+1);([E.Tensor(idx,ix)],index',e')) end

    (* Lift one compound factor.  Going through mkPTensor advances the id
     * counter, so each lifted factor refers to its own new parameter. *)
    fun foundOp (e,es,(lft,newbies, params, args))=let
        val (p,ix,e')=mkPTensor e
        val q=fresh(DstTy.TensorTy(ix))
        in (es,(lft@p, newbies@[(q, e',ix)],params@[E.TEN(1,ix)],args@[q])) end

    fun sort([], m)=m
        | sort(e::es,m)=(case e
            of E.Add _ => sort (foundOp(e,es,m))
            | E.Sub _=>sort (foundOp(e, es,m))
            | E.Prod p=>sort (p@es,m)
            | E.Div _=>sort (foundOp(e, es,m))
            | E.Neg _=>sort (foundOp(e, es,m))
            | E.Sum _=>sort (foundOp(e, es,m))
            | E.Probe _=> raise Fail("Probe- Should have been expanded")
            | _ => let
                val (l,nb, p, a)=m
                in sort(es,(l@[e],nb,p,a)) end
            (*end case *))

    in
        sort(list1,([],[],[],[]))
    end

(* Outside operator is Prod: lift compound factors via hProd, rebuild the
 * product over the remaining/placeholder factors, and clean its parameters.
 * Returns (lifted definitions, (params, body, args)). *)
fun handleProd(params, index,list1,args)=let
    val (factors, lifted, extraParams, extraArgs) = hProd (params, index, list1, args)
    val cleaned = shift.cleanParams (E.Prod factors, params @ extraParams, args @ extraArgs)
    in
        (List.map (fn def => createnewP (params, args, def)) lifted, cleaned)
    end

(* Outside operator is Sum over a Prod.
 * The split is attempted only when the summation indices are exactly
 * V n, V (n+1), ... (n = length ind), each with lower bound 0.  Their upper
 * bounds then extend the index shape passed to hProd.  Otherwise the
 * expression is returned unchanged with no lifted definitions.
 * Prints trace text to stdout. *)
fun handleSumProd(params, ind,sx,list1,args)=let
    val n = length ind
    val _ = print (String.concat["\n In Sum Prod", "n",Int.toString(n)])

    (* Collect upper bounds while the indices are consecutive from `next`. *)
    fun outerBounds (acc, [], _) = SOME acc
      | outerBounds (acc, (E.V s, 0, ub) :: rest, next) =
          if s = next
            then (print "match"; outerBounds (acc @ [ub], rest, next + 1))
            else NONE
      | outerBounds _ = NONE
    in
        case outerBounds ([], sx, n)
          of NONE => ([], (params, E.Sum (sx, E.Prod list1), args))
           | SOME extra => let
                val (factors, lifted, extraParams, extraArgs) =
                    hProd (params, ind @ extra, list1, args)
                val cleaned = shift.cleanParams (E.Sum (sx, E.Prod factors),
                                                 params @ extraParams, args @ extraArgs)
                in
                    (List.map (fn def => createnewP (params, args, def)) lifted, cleaned)
                end
    end





(* Split one definition id = EIN{params,index,body}(args) at its outermost
 * operator.  Returns (newbie, (id, e', arg)): newbie is the list of new
 * (var, EIN, args) definitions lifted out by the handle* functions (empty
 * when nothing was split), and e' is the rewritten EIN for id.
 * Raises Fail for Field, Partial, Apply, and Probe bodies, which are not
 * expected at this stage. *)
fun genfn(id,Ein.EIN{params, index, body},args)= let

    (* Result meaning "leave this body as is; nothing lifted". *)
    val notDone=([],(params,body,args))
    (* Pattern order matters: the specific rewrites below must be tried
     * before the generic E.Neg / E.Sub / E.Div / E.Prod handlers. *)
    fun gen body=(case body
        of  E.Field _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Partial _ =>raise Fail(concat["Invalid Partial here "]   )
        | E.Apply _ =>raise Fail(concat["Invalid Apply here "]   )
        | E.Probe _ => raise Fail("Probe- Should have been expanded")
        | E.Conv _ =>notDone
        | E.Krn _ =>notDone
        | E.Img _=> notDone
        | E.Const _=> notDone
        | E.Tensor(id,[])=> notDone
        | E.Prod(E.Img _  :: _)=>notDone
        (* --e = e *)
        | E.Neg(E.Neg e)=> gen e
        | E.Neg e=>  handleNeg(params, index,e,args)
        (* the print calls below emit trace tags to stdout *)
| E.Add a => (print "Add";handleAdd(params, index,a,args))
        (* (a-b)-(c-d) = (a+d)-(b+c) *)
        | E.Sub(E.Sub(a,b),E.Sub(c,d))=> gen(E.Sub(E.Add[a,d],E.Add[b,c]))
        (* (a-b)-e2 = a-(b+e2) *)
        | E.Sub(E.Sub(a,b),e2)=>gen (E.Sub(a,E.Add[b,e2]))
        (* e1-(c-d) = (e1-c)+d *)
        | E.Sub(e1,E.Sub(c,d))=>gen(E.Add([E.Sub(e1,c),d]))
| E.Sub(e1,e2)=>(print "SUB";handleSub(params, index,e1,e2,args))
        (* (a/b)/(c/d) = (a*d)/(b*c) *)
        | E.Div(E.Div(a,b),E.Div(c,d))=> gen(E.Div(E.Prod[a,d],E.Prod[b,c]))
        (* (a/b)/c = a/(b*c) *)
        | E.Div(E.Div(a,b),c)=> gen(E.Div(a, E.Prod[b,c]))
        (* a/(b/c) = (a*c)/b *)
        | E.Div(a,E.Div(b,c))=> gen(E.Div(E.Prod[a,c],b))
        | E.Div(e1,e2)=>handleDiv(params, index,e1,e2,args)
| E.Prod e=> (print "PROD ";handleProd(params, index,e,args))
        
        | E.Sum(_,E.Prod(E.Img _ :: _ ))=>notDone
| E.Sum(sx,E.Prod e)=>(print "CAT"; handleSumProd(params, index,sx,e,args))

 
        | _=> notDone
    (*end case*))


    val (newbie,(p,b,arg))= gen body
    (* The rewritten definition keeps the original index shape. *)
    val e'=createEin(p,index, b)


    val f= (id,e',arg)
    in (newbie, f)
    end



(* Repeatedly split definition e until no further definitions are lifted.
 * Returns (flag, defs): defs lists every lifted definition (recursively
 * split, in order) followed by the rewritten e.  The flag is `change` when
 * nothing was lifted from e, and 1 otherwise. *)
fun splitIt (change,e)=let
    val (newbie, e') = genfn e
    in
        case newbie
          of [] => (change, [e'])
           | _ => let
                val pieces = List.map (fn e1 => let val (_, defs) = splitIt (1, e1) in defs end) newbie
                in
                    (1, List.concat pieces @ [e'])
                end
    end

(* Entry point: print the incoming definition, normalize it with
 * shiftHtM.clean, print the normalized form, then split it with splitIt.
 * Returns splitIt's (flag, definitions) result. *)
fun splitein(id,E.EIN{params,index,body},arg)=let
    val incoming = E.EIN{params=params,index=index,body=body}
    val _ = print (printA (id, incoming, arg))
    val _ = print "\n \t changed to =>\n \t"
    val (cleanParams, cleanIndex, cleanBody, cleanArgs) = shiftHtM.clean (params, index, body, arg)
    val normalized = createEin (cleanParams, cleanIndex, cleanBody)
    val _ = print (printA (id, normalized, cleanArgs))
    in
        splitIt (0, (id, normalized, cleanArgs))
    end





end (* local *)

end 

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0