Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-il/EpsHelpers.sml
 [diderot] / branches / charisee / src / compiler / high-il / EpsHelpers.sml

# View of /branches/charisee/src/compiler/high-il/EpsHelpers.sml

Tue Nov 25 03:40:24 2014 UTC (4 years, 8 months ago) by cchiw
File size: 5131 byte(s)
`edit split-ein`
```structure EpsHelpers = struct

local

structure E = Ein
structure P=Printer
structure F=Filter
in

(* err str: abort with a Fail exception reporting an ill-formed EIN operator.
 * The caller-supplied detail `str` follows a ": " separator so the message
 * does not run into the fixed prefix. Never returns. *)
fun err str = raise Fail (String.concat ["Ill-formed EIN Operator: ", str])

(* doubleEps (e1, e2): decide whether a product of two epsilon terms can be
 * rewritten as deltas.
 *
 * Both arguments are expected to be E.Epsilon terms. Each index triple is
 * considered under its three cyclic rotations (lead, (s, t)). The first
 * rotation of e1 whose lead also leads some rotation (lead, (u, v)) of e2
 * selects the contraction. The search order is e1 rotations outer and e2
 * rotations inner, i.e. a=d, a=e, a=f, b=d, ..., c=f.
 *
 * Returns (1, [Delta(s,u), Delta(t,v)], [Delta(s,v), Delta(t,u)]) when a
 * shared index is found, and (0, [], []) otherwise.
 * Raises Fail when either argument is not an E.Epsilon.
 *)
fun doubleEps (e1, e2) =
      let
        (* the two delta products produced for the selected contraction *)
        fun toDeltas (s, t, u, v) =
              (1,
               [E.Delta (E.V s, E.V u), E.Delta (E.V t, E.V v)],
               [E.Delta (E.V s, E.V v), E.Delta (E.V t, E.V u)])

        (* cyclic rotations of a triple, each split as (lead, (second, third)) *)
        fun rotations (x, y, z) = [(x, (y, z)), (y, (z, x)), (z, (x, y))]

        (* walk the left rotations in order, looking each lead up among the right ones *)
        fun pairUp ([], _) = (0, [], [])
          | pairUp ((lead, (s, t)) :: more, rhs) =
              (case List.find (fn (lead', _) => lead' = lead) rhs
                of SOME (_, (u, v)) => toDeltas (s, t, u, v)
                 | NONE => pairUp (more, rhs))
      in
        case (e1, e2)
          of (E.Epsilon (a, b, c), E.Epsilon (d, e, f)) =>
               pairUp (rotations (a, b, c), rotations (d, e, f))
           | _ => raise Fail "None Epsilon Arguement"
      end

(* distributeEps (epsAll, sumexp): look for a pair of epsilons that doubleEps
 * can turn into deltas.
 *
 * Adjacent pairs of epsAll are tried left to right. On the first success the
 * result is (1, otherEps, d1, d2, [], sumexp): the pair is removed and sumexp
 * is kept whole.
 *
 * If no adjacent pair matches, the last epsilon is tried against the leading
 * factor of an embedded summation. This only happens when sumexp is exactly
 * [E.Sum (sx, E.Prod (first :: ps))]. On success the result is
 * (1, otherEps, d1, d2, sx, ps), and the caller is expected to hoist sx to the
 * outer summation.
 *
 * Anything else yields (0, [], [], [], [], []).
 *
 * NOTE(review): only adjacent pairs are compared (e1/e3 is never tried).
 * doubleEps raises Fail if the embedded product's leading factor is not an
 * epsilon. Confirm both are intended.
 *)
fun distributeEps (epsAll, sumexp) =
      let
        val noChange = (0, [], [], [], [], [])

        (* pair the final epsilon with the head of the single embedded sum, if any *)
        fun withEmbedded (lastEps, skipped) =
              case sumexp
                of [E.Sum (sx, E.Prod (inner :: others))] =>
                     (case doubleEps (lastEps, inner)
                       of (1, d1, d2) => (1, skipped, d1, d2, sx, others)
                        | _ => noChange)
                 | _ => noChange

        (* skipped: epsilons already passed over, kept in original order *)
        fun walk ([], _) = noChange
          | walk ([lastEps], skipped) = withEmbedded (lastEps, skipped)
          | walk (first :: second :: remaining, skipped) =
              (case doubleEps (first, second)
                of (1, d1, d2) => (1, skipped @ remaining, d1, d2, [], sumexp)
                 | _ => walk (second :: remaining, skipped @ [first]))
      in
        walk (epsAll, [])
      end

(* epsToDels e: rewrite one epsilon pair in `e` as a difference of delta
 * products.
 *
 * `e` is split with F.filterEps into (epsAll, rest, sumexp), and the split is
 * handed to distributeEps.
 *   - No pair found: (0, E.Const 0, [], epsAll, rest @ sumexp).
 *   - Pair found:    (1, E.Sub (P d1, P d2), sx, [], []), where
 *     P ds = E.Prod (unusedEps @ ds @ rest @ ps), and sx is the summation
 *     index list returned by distributeEps.
 *)
fun epsToDels e =
      let
        val (epsAll, rest, sumexp) = F.filterEps e
      in
        case distributeEps (epsAll, sumexp)
          of (0, _, _, _, _, _) => (0, E.Const 0, [], epsAll, rest @ sumexp)
           | (_, unusedEps, dA, dB, sx, ps) =>
               let
                 (* both terms share every factor except the delta pair *)
                 fun term ds = E.Prod (unusedEps @ ds @ rest @ ps)
               in
                 (1, E.Sub (term dA, term dB), sx, [], [])
               end
      end

(* reduceDelta (eps, dels, es): absorb delta terms into the expressions of es
 * by renaming indices.
 *
 * Each expression of es is visited in order and offered the pending deltas in
 * turn. A delta E.Delta (i, j) is absorbed when:
 *   - the expression is a one-index E.Tensor or E.Field whose index is j, or
 *   - j occurs in the index list of an E.Apply (E.Partial _, _), or
 *   - j occurs in the last index list of an E.Probe (E.Conv (_, _, _, _), _).
 * The first occurrence of j is then replaced by i, the delta is dropped, and
 * the rewritten expression is not offered further deltas.
 *
 * Any other expression kind is passed through unchanged.
 *
 * Returns (n, E.Prod (eps @ unabsorbedDeltas @ processedEs)), where n is the
 * number of absorbed deltas.
 * Raises Fail if a non-delta element of dels is reached while expressions
 * remain.
 *)
fun reduceDelta (eps, dels, es) =
      let
        (* outcome of offering one delta to one expression *)
        datatype 'a outcome = Renamed of 'a | Missed | Opaque

        (* replace the first occurrence of j in xs by i; NONE when j is absent *)
        fun replaceFirst (xs, i, j) =
              let
                fun go ([], _) = NONE
                  | go (x :: rest, seen) =
                      if x = j then SOME (List.revAppend (seen, i :: rest))
                      else go (rest, x :: seen)
              in
                go (xs, [])
              end

        fun tryRename (i, j, ex) =
              case ex
                of E.Tensor (id, [tx]) =>
                     if j = tx then Renamed (E.Tensor (id, [i])) else Missed
                 | E.Field (id, [tx]) =>
                     if j = tx then Renamed (E.Field (id, [i])) else Missed
                 | E.Apply (E.Partial alpha, body) =>
                     (case replaceFirst (alpha, i, j)
                       of SOME alpha' => Renamed (E.Apply (E.Partial alpha', body))
                        | NONE => Missed)
                 | E.Probe (E.Conv (v, alpha, h, dx), t) =>
                     (case replaceFirst (dx, i, j)
                       of SOME dx' => Renamed (E.Probe (E.Conv (v, alpha, h, dx'), t))
                        | NONE => Missed)
                 | _ => Opaque

        (* n: deltas absorbed so far; pending: deltas not yet offered to the
         * head of remaining; tried: deltas the head already rejected;
         * finished: expressions fully processed, in order *)
        fun walk (n, pending, tried, [], finished) = (n, tried @ pending, finished)
          | walk (n, [], [], remaining, finished) = (n, [], finished @ remaining)
          | walk (n, [], tried, ex :: remaining, finished) =
              walk (n, tried, [], remaining, finished @ [ex])
          | walk (n, (delta as E.Delta (i, j)) :: pending, tried, ex :: remaining, finished) =
              (case tryRename (i, j, ex)
                of Renamed ex' =>
                     walk (n + 1, tried @ pending, [], remaining, finished @ [ex'])
                 | Missed =>
                     walk (n, pending, tried @ [delta], ex :: remaining, finished)
                 | Opaque =>
                     walk (n, tried @ [delta] @ pending, [], remaining, finished @ [ex]))
          | walk _ = raise Fail("Non-Delta in distribute function")

        val (absorbed, leftover, processed) = walk (0, dels, [], es, [])
      in
        (absorbed, E.Prod (eps @ leftover @ processed))
      end

(* matchEps (num, indices, rest, epsIds): return 1 once `num` (the running
 * match count) reaches 2, and 0 when the ids run out first.
 *
 * Each id of epsIds, in order, consumes the first still-unused E.V index
 * carrying that id. Non-variable indices are discarded while scanning. A
 * count of 0 with a single id left returns 0 immediately.
 *
 * NOTE(review): `rest` joins the pool only after the first id has been
 * scanned, so the first id is never matched against it. Confirm this is
 * intended.
 *)
local
  (* scan pool for E.V target; kept collects unmatched variables (seeded
   * with the caller's extra indices). Returns (found, nextPool). *)
  fun takeVar (_, [], kept) = (false, kept)
    | takeVar (target, E.V p :: more, kept) =
        if p = target then (true, kept @ more)
        else takeVar (target, more, kept @ [E.V p])
    | takeVar (target, _ :: more, kept) = takeVar (target, more, kept)
in
fun matchEps (2, _, _, _) = 1 (* matched 2 *)
  | matchEps (_, _, _, []) = 0
  | matchEps (0, _, _, [_]) = 0
  | matchEps (num, pool, rest, target :: others) =
      let
        val (hit, pool') = takeVar (target, pool, rest)
      in
        matchEps (if hit then num + 1 else num, pool', [], others)
      end
end

end

end (* local *)```