Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/split-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/split-ein.sml

Parent Directory | Revision Log


Revision 2521 - (download) (annotate)
Thu Jan 9 02:17:07 2014 UTC (5 years, 7 months ago) by cchiw
File size: 8414 byte(s)
Added type Checker
(* Split EIN operators before the code-generation pass.
 *
 * A compound EIN body is broken into a list of smaller EIN definitions:
 * each nested compound subexpression is lifted into its own EIN, bound to a
 * fresh name ("Q0", "Q1", ...), and replaced in the parent body by a new
 * tensor parameter (E.TEN 1) whose argument is that fresh name.
 *
 * The handle* functions share one result shape:
 *     (newbies, (params', body', args'))
 * where newbies is a list of lifted definitions (name, ein, args) and the
 * triple is the rewritten parent.  params and args are kept as parallel
 * lists: every new parameter is appended together with its argument name,
 * and a new parameter's tensor id is its position in params.
 *
 * Index convention assumed here: the EIN index has one entry per free index
 * variable and those variables are E.V 0 .. E.V (n-1), n = length index
 * (handleSumProd numbers the summation variables from n upward and appends
 * their bounds to the index).
 *)
structure split = struct
    local
    structure E = Ein
    structure P=Printer
    structure shift=shift
    in

(* Debugging aid: prints "asn=<ein>(arg1,...,argk)" on stdout. *)
fun printA(asn,e,args)=print(String.concat[asn,"=",P.printerE(e),"(",
    (String.concatWith "," args),") \n"])

(* Wraps params, index and body into an EIN record. *)
fun createEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}

(* Concatenates a list of lists (equivalent to foldr op@ []). *)
fun flat xs = List.concat xs

(* Source of unique names for lifted EINs; read and advanced by fresh. *)
val counter=ref 0

(* Returns the next unique name "Q<k>" and advances counter.
 * The argument is ignored; callers write `fresh 1`. *)
fun fresh _=let
    val k = !counter
    in
        counter := k+1;
        "Q"^Int.toString k
    end

(* Builds a lifted definition (id, ein, args') for subexpression e using the
 * given index.  params/args are the enclosing EIN's; shift.cleanParams maps
 * (e, params, args) to the new EIN's params, body and args (presumably
 * dropping parameters e does not use -- TODO confirm). *)
fun createnewb (params,index,args,(id,e))=let
    val (p',b',args')=shift.cleanParams(e,params,args)
    in (id, createEin(p',index, b'), args')
end

(* Like createnewb, but for the (name, body, index) triples produced by
 * hProd, where each lifted factor carries its own index ix. *)
fun createnewP (params,args,(id,e,ix))=let
    val (p',b',args')=shift.cleanParams(e,params,args)
    in (id, createEin(p',ix, b'), args')
end


(* Outermost operator is Neg: the body is E.Neg e.
 * If e is an Add, Sub, Prod, Div or Sum it is lifted into a new EIN (same
 * index) and the body becomes the negation of a new tensor parameter bound
 * to that EIN's name; otherwise the body E.Neg e is returned unchanged. *)
fun handleNeg(params, index,e,args)=let
    val ix=List.tabulate(length index, fn i=> E.V i)
    fun replace e'= let
        val q=fresh 1
        (* the new parameter goes after the existing ones, so its id is
         * length params; params and args must grow together *)
        val t=E.Tensor(length params,ix)
        val (p',b',args')=shift.cleanParams(E.Neg t, params@[E.TEN 1], args@[q])
        in ([createnewb(params,index,args,(q,e'))],(p',b',args'))
        end
    in  (case e
        of E.Add _=> replace e
        | E.Sub _=> replace e
        | E.Prod _=> replace e
        | E.Div _=> replace e
        | E.Sum _=> replace e
        (* leaf operand: keep the negation *)
        | _=>([],(params, E.Neg e, args))
        (*end case*))
    end


(* Outermost operator is Add.
 * Nested Adds are flattened; every summand that is a Sub, Prod, Div, Neg or
 * Sum is lifted into a new EIN (same index) and replaced by a new tensor
 * parameter; all other summands are kept in place, in order. *)
fun handleAdd(params, index,list1,args)=let
    val id=ref (length params)
    val ix=List.tabulate(length index, fn i=> E.V i)
    (* next new tensor parameter; ids continue after the existing params *)
    fun mkTensor _=let val idx= !id in (id:=idx+1; [E.Tensor(idx,ix)]) end

    (* lift e; the accumulator is (terms, lifted, newParams, newArgs) *)
    fun foundOp(e,es,(terms,lifted,newParams,newArgs))=let
        val q=fresh 1
        in  (es,(terms@(mkTensor 0), lifted@[(q, e)],newParams@[E.TEN 1],newArgs@[q]))
        end

    fun sort([], m)=m
      | sort(e::es,m)=(case e
            of E.Add p => sort(p@es, m)
            | E.Sub _=>sort (foundOp(e, es,m))
            | E.Prod _=>sort (foundOp(e, es,m))
            | E.Div _=>sort (foundOp(e, es,m))
            | E.Neg _=>sort (foundOp(e, es,m))
            | E.Sum _=>sort (foundOp(e, es,m))
            | _ => let
                val (terms,lifted,newParams,newArgs)=m
                in sort(es,(terms@[e],lifted,newParams,newArgs)) end
            (*end case *))

    val (lft, newbies,params',args')=sort(list1,([],[],[],[]))
    val (p',b',args'')= shift.cleanParams(E.Add(lft),params@params', args@args')
    val z=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies
    in
        (z,(p',b',args''))
    end


(* 1 if e's outermost operator is Neg, Add, Sub, Prod, Div or Sum, else 0. *)
fun findOp e=(case e
    of E.Neg _=>1
    | E.Add _=>1
    | E.Sub _=>1
    | E.Prod _=>1
    | E.Div _=>1
    | E.Sum _ =>1
    |  _=>0
(*end case*))


(* Outermost operator is Sub: each compound operand (per findOp) is lifted
 * into a new EIN with the same index and replaced by a new tensor parameter. *)
fun handleSub(params, index,e1,e2,args)=let
    val id=ref (length params)
    val ix=List.tabulate(length index, fn i=> E.V i)
    fun mkTensor _=let val idx= !id in (id:=idx+1; E.Tensor(idx,ix)) end

    (* returns (replacement, newbies, newParams, newArgs) *)
    fun subsort e=
        if findOp e = 0 then (e,[],[],[])
        else let val q=fresh 1 in (mkTensor 0, [(q, e)],[E.TEN 1],[q]) end

    (* e1 is processed before e2, so its parameter (if any) comes first *)
    val (lft1, newbies1,params1,args1)=subsort e1
    val (lft2, newbies2,params2,args2)=subsort e2
    val (p',b',args')= shift.cleanParams(E.Sub(lft1,lft2), params@params1@params2, args@args1@args2)
    val z=List.map (fn(e)=> createnewb(params,index,args,e) ) (newbies1@newbies2)
    in
        (z,(p',b',args'))
    end


(* Outermost operator is Div: a compound numerator is lifted into a new EIN
 * with the full index; a compound denominator is lifted into a new EIN with
 * an empty index and referenced as a scalar tensor E.Tensor(id, []). *)
fun handleDiv(params, index,e1,e2,args)=let
    val id=ref (length params)
    val ix=List.tabulate(length index, fn i=> E.V i)
    fun mkTensor _=let val idx= !id in (id:=idx+1; E.Tensor(idx,ix)) end
    fun mkSca _=let val idx= !id in (id:=idx+1; E.Tensor(idx,[])) end

    (* returns (replacement, newbies, newParams, newArgs); nextfn builds the
     * replacement tensor *)
    fun divsort(e,nextfn)=
        if findOp e = 0 then (e,[],[],[])
        else let val q=fresh 1 in (nextfn 0, [(q, e)],[E.TEN 1],[q]) end

    val (lft1, newbies1,params1,args1)=divsort(e1,mkTensor)
    val (lft2, newbies2,params2,args2)=divsort(e2,mkSca)
    val (p',b',args')= shift.cleanParams(E.Div(lft1,lft2),params@params1@params2, args@args1@args2)

    val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
    val z2=List.map (fn(e)=> createnewb(params,[],args,e) ) newbies2
    in
        (z1@z2,(p',b',args'))
    end


(* Shared core of handleProd and handleSumProd.
 * Flattens nested Prods in list1 and lifts every factor that is an Add, Sub,
 * Div, Neg or Sum into a new EIN.  For a lifted factor, shift.cleanIndex
 * (e, n, index) yields the replacement tensor's index variables, the lifted
 * EIN's index and its body.
 * Returns (factors, newbies, newParams, newArgs), where newbies are
 * (name, body, index) triples for createnewP.  args is not used. *)
fun hProd(params, index,list1,args)=let
    val id=ref (length params)
    val n=length index

    fun mkPTensor e=let
        val idx= !id
        val (ix,index',e')=shift.cleanIndex(e, n, index)
        in (id:=idx+1; ([E.Tensor(idx,ix)],index',e')) end

    fun foundOp (e,es,(factors,lifted,newParams,newArgs))=let
        val q=fresh 1
        val (t,ix,e')=mkPTensor e
        in (es,(factors@t, lifted@[(q, e',ix)],newParams@[E.TEN 1],newArgs@[q])) end

    fun sort([], m)=m
      | sort(e::es,m)=(case e
            of E.Add _ => sort (foundOp(e,es,m))
            | E.Sub _=>sort (foundOp(e, es,m))
            | E.Prod p=>sort (p@es,m)
            | E.Div _=>sort (foundOp(e, es,m))
            | E.Neg _=>sort (foundOp(e, es,m))
            | E.Sum _=>sort (foundOp(e, es,m))
            | _ => let
                val (factors,lifted,newParams,newArgs)=m
                in sort(es,(factors@[e],lifted,newParams,newArgs)) end
            (*end case *))
    in
        sort(list1,([],[],[],[]))
    end

(* Outermost operator is Prod: split its factors via hProd. *)
fun handleProd(params, index,list1,args)=let
    val (lft, newbies,params',args')=hProd(params, index,list1,args)
    val (p',b',args'')= shift.cleanParams(E.Prod lft, params@params', args@args')
    val z=List.map (fn(e)=> createnewP(params,args,e) ) newbies
    in
        (z,(p',b',args''))
    end

(* Outermost operator is a Sum over a Prod: E.Sum(sx, E.Prod list1).
 * Only split when the summation variables are exactly E.V n, E.V (n+1), ...
 * (n = length ind), each with lower bound 0; their upper bounds are then
 * appended to the index while splitting the product.  Otherwise the
 * expression is returned unchanged.
 * NOTE(review): ub is appended to the index as-is; confirm whether the index
 * expects the upper bound or the extent (ub+1). *)
fun handleSumProd(params, ind,sx,list1,args)=let
    val n=length ind

    (* SOME bounds when sx has the expected shape, NONE otherwise *)
    fun g(acc,[],_)=SOME acc
      | g(acc,(E.V s,0,ub)::es,n')=if s=n' then g(acc@[ub],es,n'+1) else NONE
      | g _ =NONE
    in
        case g([],sx,n)
         of NONE=> ([],(params,E.Sum(sx, E.Prod(list1)),args))
          | SOME bounds=>let
                val index=ind@bounds
                val (lft, newbies,params',args')=hProd(params, index,list1,args)
                val (p',b',args'')= shift.cleanParams(E.Sum(sx,E.Prod lft), params@params', args@args')
                val z=List.map (fn(e)=> createnewP(params,args,e) ) newbies
                in
                    (z,(p',b',args''))
                end
    end


(* Splits one EIN definition a single level deep.
 * (id, ein, args) is an EIN bound to name id and applied to args.  Nested
 * Sub/Div/Neg forms are first rewritten algebraically, then the outermost
 * operator is dispatched to its handler.
 * Returns (newbies, (id, ein', args')): the lifted subexpressions (not yet
 * split themselves) and the rewritten definition.
 * Raises Fail "Invalid Field here " for Field, Partial, Apply, Probe, Conv,
 * Krn and Img bodies. *)
fun genfn(id,Ein.EIN{params, index, body},args)= let
    (* leave the (possibly already simplified) body b unsplit *)
    fun notDone b=([],(params,b,args))
    fun gen b=(case b
        of  E.Field _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Partial _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Apply _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Probe _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Conv _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Krn _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Img _=> raise Fail(concat["Invalid Field here "]   )
        | E.Const _=> notDone b
        | E.Tensor(_,[])=> notDone b
        | E.Neg(E.Neg e)=> gen e
        | E.Neg e=>  handleNeg(params, index,e,args)
        | E.Add a => handleAdd(params, index,a,args)
        (* (a-b)-(c-d) = (a+d)-(b+c) *)
        | E.Sub(E.Sub(a,b),E.Sub(c,d))=> gen(E.Sub(E.Add[a,d],E.Add[b,c]))
        (* (a-b)-e2 = a-(b+e2) *)
        | E.Sub(E.Sub(a,b),e2)=>gen (E.Sub(a,E.Add[b,e2]))
        (* e1-(c-d) = (e1-c)+d *)
        | E.Sub(e1,E.Sub(c,d))=>gen(E.Add([E.Sub(e1,c),d]))
        | E.Sub(e1,e2)=>handleSub(params, index,e1,e2,args)
        (* (a/b)/(c/d) = (a*d)/(b*c) *)
        | E.Div(E.Div(a,b),E.Div(c,d))=> gen(E.Div(E.Prod[a,d],E.Prod[b,c]))
        (* (a/b)/c = a/(b*c) *)
        | E.Div(E.Div(a,b),c)=> gen(E.Div(a, E.Prod[b,c]))
        (* a/(b/c) = (a*c)/b *)
        | E.Div(a,E.Div(b,c))=> gen(E.Div(E.Prod[a,c],b))
        | E.Div(e1,e2)=>handleDiv(params, index,e1,e2,args)
        | E.Prod e=> handleProd(params, index,e,args)
        (* need to shift sum indices *)
        | E.Sum(sx,E.Prod e)=> handleSumProd(params, index,sx,e,args)
        | _=> notDone b
    (*end case*))

    val (newbies,(p,b,arg))= gen body
    in (newbies, (id, createEin(p,index, b), arg))
    end


(* Recursively splits an EIN definition (id, ein, args).  Lifted
 * subexpressions are split first, so every definition in the result comes
 * before the one that uses it; the rewritten input is last. *)
fun splitein e=let
    val (newbies, e')= genfn e
    in
        flat(List.map splitein newbies) @ [e']
    end


end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0