Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/filter-ein.sml
 [diderot] / branches / charisee / src / compiler / ein / filter-ein.sml

# View of /branches/charisee/src/compiler/ein/filter-ein.sml

Wed Apr 30 01:46:09 2014 UTC (7 years, 2 months ago) by cchiw
File size: 7253 byte(s)
`code cleanup`
```sml
structure Filter = struct

local

structure E = Ein
structure P=Printer
in

(* err str: abort with a Fail exception whose message is the fixed prefix
 * "Ill-formed EIN Operator: " followed by the caller's detail text str.
 * The ": " separator keeps the prefix from running into the detail text
 * (without it the message read e.g. "...OperatorFound non-Partial in Apply").
 *)
fun err str=raise Fail (String.concat["Ill-formed EIN Operator: ",str])

(* Flattens an Add constructor; the result is a pair (change, expression).
 *
 * NOTE(review): this definition is incomplete as shown. The clauses below
 * begin with `|flatten` with no preceding `fun flatten` clause, and the body
 * closes with `val ... in case ... end` with no visible `let`, so the
 * enclosing function header (and probably a first `flatten` clause that
 * splices nested E.Add terms) appears to have been lost. Restore it from
 * version control before relying on this code.
 *)
(* flatten(i, terms) = (change, terms'):
 *   - a constant c with c <> 0 is kept in place;
 *   - a zero constant is dropped and change is set to 1;
 *   - every other term is kept, in its original order. *)
|flatten(i,((E.Const c):: l'))=
if (c>0 orelse 0>c) then let
val (b,a)=flatten(i,l') in (b,[E.Const c]@a) end
else flatten(1,l')
| flatten(i,[])=(i,[])
| flatten (i,e::l') =  let
val(b,a)=flatten(i,l') in (b,[e]@a) end

(* run flatten over the term list e, starting with change = 0 *)
val (b,a)=flatten(0,e)
in case a
(* NOTE(review): an empty result (every term was a zero constant) yields
 * E.Const 1 here, although a sum of zeros is 0 -- confirm intent. *)
of [] => (1,E.Const(1))
| [e] => (1,e)
(* NOTE(review): no clause handles two or more remaining terms, so such an
 * input raises Match at runtime; a clause like `| es => (b, E.Add es)` may
 * have been lost together with the header. *)
(* end case *)
end

(* mkProd es: build a flattened product from the factor list es.
 * Returns (change, expr). change is 1 when the product was rewritten
 * (single factor, nested E.Prod spliced in, constant folded). Otherwise the
 * result is E.Prod of the factors with change 0.
 * Simplifications:
 *   - nested E.Prod factors are spliced into the outer product;
 *   - any zero constant makes the whole product E.Const 0;
 *   - the multiplicative identity E.Const 1 is dropped;
 *   - an empty product is E.Const 1, and a single factor is returned bare.
 * The constant test used to be inverted: any NONZERO constant collapsed the
 * product to 0, while zero constants were silently dropped.
 *)
fun mkProd [e]=(1,e)
  | mkProd(e)=let
      (* flatten(i, es) = (change, factors); change 3 marks a zero factor *)
      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')
        | flatten(i,((E.Const c)::l'))=
            if c = 0 then (3,[E.Const 0])          (* zero annihilates *)
            else if c = 1 then flatten(1,l')       (* drop the identity *)
            else let val (a,b)=flatten(i,l') in (a,(E.Const c)::b) end
        | flatten(i,[])=(i,[])
        | flatten(i,e::l')= let val (a,b)=flatten(i,l') in (a,e::b) end
      val (change,a)=flatten(0,e)
      in
        if change=3 then (1,E.Const 0)
        else case a
          of [] => (1,E.Const 1)
           | [e] => (1,e)
           | es => (change,E.Prod es)
          (* end case *)
      end

(* filterGreek es: partition a list of product factors into four groups and
 * return them as (scalars, epsilons, deltas, rest), each in the order the
 * factors were encountered.
 * Nested E.Prod factors are spliced in place before classification.
 * "Scalars" are rank-0 fields, rank-0 convolutions, probes of either of
 * those, rank-0 tensors and constants.
 *)
fun filterGreek e= let
    fun isScalar x = (case x
        of E.Field(_,[]) => true
         | E.Conv(_,[],_,[]) => true
         | E.Probe(E.Field(_,[]),_) => true
         | E.Probe(E.Conv(_,[],_,[]),_) => true
         | E.Tensor(_,[]) => true
         | E.Const _ => true
         | _ => false
        (* end case *))
    (* accumulators are built in reverse and restored at the end *)
    fun loop ([], scas, epss, dels, rest) =
          (List.rev scas, List.rev epss, List.rev dels, List.rev rest)
      | loop (E.Prod p :: xs, scas, epss, dels, rest) =
          loop (p @ xs, scas, epss, dels, rest)
      | loop (x :: xs, scas, epss, dels, rest) =
          if isScalar x then loop (xs, x :: scas, epss, dels, rest)
          else (case x
            of E.Epsilon _ => loop (xs, scas, x :: epss, dels, rest)
             | E.Delta _ => loop (xs, scas, epss, x :: dels, rest)
             | _ => loop (xs, scas, epss, dels, x :: rest)
            (* end case *))
    in
      loop (e, [], [], [], [])
    end

(* filterField es: split a list of product factors into (front, back).
 * E.Lift, E.Epsilon and E.Delta factors go to front; every other factor
 * goes to back. Nested E.Prod factors are spliced in first, and relative
 * order is preserved in both lists.
 * Original author's note: Lift indicates a Tensor, so a factor is either
 * Lift, delta, epsilon, or contains a Field.
 *)
fun filterField e= let
    fun go ([], front, back) = (List.rev front, List.rev back)
      | go (x :: xs, front, back) = (case x
          of E.Prod p => go (p @ xs, front, back)
           | E.Lift _ => go (xs, x :: front, back)
           | E.Epsilon _ => go (xs, x :: front, back)
           | E.Delta _ => go (xs, x :: front, back)
           | _ => go (xs, front, x :: back)
          (* end case *))
    in
      go (e, [], [])
    end

(* filterPartial es: concatenate the index lists carried by a list of
 * E.Partial expressions, preserving order. Any element that is not an
 * E.Partial aborts via err with "Found non-Partial in Apply".
 *)
fun filterPartial es =
    List.foldr
      (fn (E.Partial d, acc) => d @ acc
        | _ => err "Found non-Partial in Apply")
      [] es

(* findeps(eps, es, rest): scan the factors es from left to right.
 *   - E.Epsilon factors are appended to eps;
 *   - nested E.Prod factors are spliced in place;
 *   - E.Field and E.Tensor factors are appended to rest.
 * The scan stops at the first factor of any other kind:
 *   - E.Sum(sx, E.Prod(E.Epsilon _ :: ps)) adds its leading epsilon to eps,
 *     appends ps and then the unscanned factors to rest, and returns sx as
 *     the third component;
 *   - anything else is appended to rest, followed by every unscanned
 *     factor, and the third component is [].
 * An exhausted list yields (eps, rest, []).
 *)
fun findeps(eps,[],rest) = (eps,rest,[])
  | findeps(eps,x::xs,rest) = (case x
      of E.Epsilon _ => findeps(eps@[x],xs,rest)
       | E.Prod ps => findeps(eps,ps@xs,rest)
       | E.Field _ => findeps(eps,xs,rest@[x])
       | E.Tensor _ => findeps(eps,xs,rest@[x])
       | E.Sum(sx,E.Prod((lead as E.Epsilon _)::ps)) =>
           (eps@[lead],rest@ps@xs,sx)
       | _ => (eps,rest@(x::xs),[])
      (* end case *))

(* filterSca(c, es): pull scalar factors out of a summation over c.
 * Nested E.Prod factors are spliced in. Scalars (rank-0 fields, rank-0
 * convolutions, probes of either of those, rank-0 tensors, constants) are
 * moved in front of E.Sum(c, ...); all other factors stay inside it.
 * Returns (0, E.Sum ...) when no scalar was found, and
 * (1, E.Prod(scalars @ [E.Sum ...])) otherwise. A single inner factor is
 * not wrapped in E.Prod.
 *)
fun filterSca(c,e)= let
    fun isScalar x = (case x
        of E.Field(_,[]) => true
         | E.Conv(_,[],_,[]) => true
         | E.Probe(E.Field(_,[]),_) => true
         | E.Probe(E.Conv(_,[],_,[]),_) => true
         | E.Tensor(_,[]) => true
         | E.Const _ => true
         | _ => false
        (* end case *))
    (* rebuild the summation around the non-scalar factors *)
    fun sumOf [x] = E.Sum(c,x)
      | sumOf xs = E.Sum(c,E.Prod xs)
    fun finish ([], inner) = (0, sumOf inner)
      | finish (outer, inner) = (1, E.Prod(outer @ [sumOf inner]))
    fun loop ([], outer, inner) = finish (List.rev outer, List.rev inner)
      | loop (E.Prod p :: xs, outer, inner) = loop (p @ xs, outer, inner)
      | loop (x :: xs, outer, inner) =
          if isScalar x then loop (xs, x :: outer, inner)
          else loop (xs, outer, x :: inner)
    in
      loop (e, [], [])
    end

(* findIndex(v, searchspace): SOME x for the first element x of searchspace
 * that equals v, or NONE when v does not occur. Despite the name, the
 * matching element is returned, not its position. *)
fun findIndex(v,searchspace) = let
    fun scan [] = NONE
      | scan (x :: rest) = if x = v then SOME x else scan rest
    in
      scan searchspace
    end

(* foundSx(c, e): does the index c occur anywhere in the expression e?
 * Returns SOME of the first matching index found (subexpressions are
 * searched left to right), or NONE.
 * Raises Fail if the search reaches an E.Krn or E.Img, which are not
 * expected before expansion.
 *)
fun foundSx(c,e)=let
    (* try each subexpression in turn, stopping at the first hit *)
    fun firstHit [] = NONE
      | firstHit (x :: xs) = (case foundSx(c,x)
          of SOME s => SOME s
           | NONE => firstHit xs
          (* end case *))
    (* look for c directly in a list of indices *)
    fun among ids = findIndex(c,ids)
    in
      case e
       of E.Krn _ => raise Fail"Krn used pre expansion"
        | E.Img _ => raise Fail"Img used pre expansion"
        | E.Value _ => NONE
        | E.Const _ => NONE
        (* shortcuts for forms with no indices to search *)
        | E.Tensor(_,[]) => NONE
        | E.Conv(_,[],_,[]) => NONE
        | E.Probe(E.Conv(_,[],_,[]),E.Tensor(_,[])) => NONE
        | E.Tensor(_,shape) => among shape
        | E.Field(_,shape) => among shape
        | E.Delta(i,j) => among [i,j]
        | E.Epsilon(i,j,k) => among [E.V i,E.V j,E.V k]
        | E.Partial shape => among shape
        | E.Conv(_,alpha,_,dx) => among (alpha@dx)
        | E.Neg a => foundSx(c,a)
        | E.Lift a => foundSx(c,a)
        | E.Sum(_,a) => foundSx(c,a)
        | E.Apply(x,y) => firstHit [x,y]
        | E.Sub(x,y) => firstHit [x,y]
        | E.Div(x,y) => firstHit [x,y]
        | E.Probe(x,y) => firstHit [x,y]
        | E.Prod xs => firstHit xs
        | E.Add xs => firstHit xs
      (* end case *)
    end

(* pushSum(c, p): push the summation over c = (v, lb, ub) inward through
 * the product factors p.
 * Nested E.Prod factors are spliced in, then each factor is tested with
 * foundSx for the index v. Factors without v move outside the sum; factors
 * with v stay inside E.Sum([c], ...). A single inner factor is not wrapped
 * in E.Prod. Relative order is preserved, and no other rewriting is done.
 * Returns (0, E.Sum ...) when every factor mentions v (or p is empty), and
 * (1, E.Prod ...) otherwise.
 * NOTE(review): when no factor mentions v, the summation is dropped and
 * only E.Prod of the factors is returned. Confirm that callers guarantee v
 * occurs, since a factor for the size of the summation range would
 * otherwise seem to be lost.
 *)
fun pushSum(c,p)= let
    val (v, _, _)=c
    fun wrap [x] = E.Sum([c],x)
      | wrap xs = E.Sum([c],E.Prod xs)
    fun build ([], dep) = (0, wrap dep)
      | build (indep, []) = (1, E.Prod indep)
      | build (indep, dep) = (1, E.Prod(indep @ [wrap dep]))
    fun split (indep, dep, []) = build (List.rev indep, List.rev dep)
      | split (indep, dep, E.Prod q :: rest) = split (indep, dep, q @ rest)
      | split (indep, dep, x :: rest) = (case foundSx(v,x)
          of NONE => split (x :: indep, dep, rest)
           | SOME _ => split (indep, x :: dep, rest)
          (* end case *))
    in
      split ([], [], p)
    end

(* splitSum(c, p): partition the product factors p by whether they mention
 * the summation index v of c = (v, lb, ub), as tested by foundSx.
 * Nested E.Prod factors are spliced in first. Returns
 * (withoutV, withV), each in original relative order.
 *)
fun splitSum(c,p)= let
    val (v, _, _)=c
    fun go (indep, dep, []) = (List.rev indep, List.rev dep)
      | go (indep, dep, E.Prod q :: rest) = go (indep, dep, q @ rest)
      | go (indep, dep, x :: rest) =
          if isSome (foundSx(v,x)) then go (indep, x :: dep, rest)
          else go (x :: indep, dep, rest)
    in
      go ([], [], p)
    end

end

end (* local *)
```