Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/filter-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/filter-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2605 - (download) (annotate)
Wed Apr 30 01:46:09 2014 UTC (7 years, 2 months ago) by cchiw
File size: 7253 byte(s)
code cleanup
(* Filter: list-partitioning and flattening helpers used when rewriting EIN
 * expressions.  Most functions take the operands of a product (an ein_exp
 * list) and either flatten nested E.Prod / E.Add nodes or sort the operands
 * into groups so later rewrites can treat each group separately.
 *
 * Several functions return a pair (change, exp); a nonzero `change` means the
 * input was rewritten, 0 means it was returned in (structurally) the same form.
 *)
structure Filter = struct

    local

    structure E = Ein
    structure P = Printer

    (* isScalar : ein_exp -> bool
     * True for the factors that filterGreek and filterSca move to the
     * "scalar" group: constants, and tensors, fields, convolutions, or probes
     * of fields/convolutions whose index lists are empty.
     *)
    fun isScalar (E.Field (_, [])) = true
      | isScalar (E.Conv (_, [], _, [])) = true
      | isScalar (E.Probe (E.Field (_, []), _)) = true
      | isScalar (E.Probe (E.Conv (_, [], _, []), _)) = true
      | isScalar (E.Tensor (_, [])) = true
      | isScalar (E.Const _) = true
      | isScalar _ = false

    in

    (* err : string -> 'a
     * Raise Fail with a message tagged as an ill-formed EIN operator.
     *)
    fun err str = raise Fail (String.concat ["Ill-formed EIN Operator: ", str])

    (* mkAdd : ein_exp list -> int * ein_exp
     * Build a sum from a list of terms, flattening nested E.Add nodes and
     * dropping zero constants (the additive identity).
     * Returns (change, exp): change is 1 when a nested E.Add was flattened or a
     * zero was dropped, 0 otherwise.  A singleton input yields its only element;
     * if no terms remain the result is E.Const 0 (the empty sum).
     *)
    fun mkAdd [e] = (1, e)
      | mkAdd es = let
          fun flatten (_, (E.Add l) :: rest) = flatten (1, l @ rest)
            | flatten (i, (E.Const c) :: rest) =
                if c <> 0
                  then let val (b, a) = flatten (i, rest) in (b, E.Const c :: a) end
                  else flatten (1, rest)
            | flatten (i, []) = (i, [])
            | flatten (i, e :: rest) =
                let val (b, a) = flatten (i, rest) in (b, e :: a) end
          val (change, terms) = flatten (0, es)
          in
            case terms
             of [] => (1, E.Const 0)   (* every term was zero *)
              | [t] => (1, t)
              | _ => (change, E.Add terms)
            (* end case *)
          end

    (* mkProd : ein_exp list -> int * ein_exp
     * Build a product from a list of factors, flattening nested E.Prod nodes.
     * A zero constant factor makes the whole product E.Const 0; constant-one
     * factors (the multiplicative identity) are dropped; other constants are
     * kept.  Returns (change, exp) as for mkAdd.  A singleton input yields its
     * only element; if no factors remain the result is E.Const 1 (the empty
     * product).
     *)
    fun mkProd [e] = (1, e)
      | mkProd es = let
          (* flag returned by flatten when a zero factor annihilates the product *)
          val zeroed = 3
          fun flatten (_, (E.Prod l) :: rest) = flatten (1, l @ rest)
            | flatten (i, (E.Const c) :: rest) =
                if c = 0
                  then (zeroed, [E.Const 0])
                else if c = 1
                  then flatten (1, rest)
                  else let val (b, a) = flatten (i, rest) in (b, E.Const c :: a) end
            | flatten (i, []) = (i, [])
            | flatten (i, e :: rest) =
                let val (b, a) = flatten (i, rest) in (b, e :: a) end
          val (change, factors) = flatten (0, es)
          in
            if change = zeroed
              then (1, E.Const 0)
              else (case factors
                 of [] => (1, E.Const 1)   (* every factor was one *)
                  | [f] => (1, f)
                  | _ => (change, E.Prod factors)
                (* end case *))
          end

    (* filterGreek : ein_exp list -> ein_exp list * ein_exp list * ein_exp list * ein_exp list
     * Flatten nested products in a list of factors and partition the factors
     * into (scalars, epsilons, deltas, others), preserving the relative order
     * within each group.  "Scalars" are the factors accepted by isScalar.
     *)
    fun filterGreek e = let
          fun filter ([], pre, eps, dels, post) = (pre, eps, dels, post)
            | filter (e1 :: es, pre, eps, dels, post) = (case e1
                 of E.Prod p => filter (p @ es, pre, eps, dels, post)
                  | E.Epsilon _ => filter (es, pre, eps @ [e1], dels, post)
                  | E.Delta _ => filter (es, pre, eps, dels @ [e1], post)
                  | _ => if isScalar e1
                      then filter (es, pre @ [e1], eps, dels, post)
                      else filter (es, pre, eps, dels, post @ [e1])
                (* end case *))
          in
            filter (e, [], [], [], [])
          end

    (* filterField : ein_exp list -> ein_exp list * ein_exp list
     * Flatten nested products and split the factors into (pre, post), where
     * pre holds E.Lift (a lifted tensor), E.Epsilon and E.Delta factors and
     * post holds everything else (e.g. factors involving fields).  Relative
     * order is preserved within each group.
     *)
    fun filterField e = let
          fun filter ([], pre, post) = (pre, post)
            | filter (e1 :: es, pre, post) = (case e1
                 of E.Prod p => filter (p @ es, pre, post)
                  | E.Lift _ => filter (es, pre @ [e1], post)
                  | E.Epsilon _ => filter (es, pre @ [e1], post)
                  | E.Delta _ => filter (es, pre @ [e1], post)
                  | _ => filter (es, pre, post @ [e1])
                (* end case *))
          in
            filter (e, [], [])
          end

    (* filterPartial : ein_exp list -> index list
     * Concatenate the index lists of a list of E.Partial operators.
     * Raises Fail (via err) if any element is not an E.Partial.
     *)
    fun filterPartial [] = []
      | filterPartial (E.Partial d1 :: es) = d1 @ filterPartial es
      | filterPartial _ = err "Found non-Partial in Apply"

    (* findeps : ein_exp list * ein_exp list * ein_exp list -> ein_exp list * ein_exp list * sum-index list
     * Scan a list of factors (flattening nested products), collecting leading
     * E.Epsilon factors onto `eps` and E.Field / E.Tensor factors onto `rest`.
     * Scanning stops at:
     *  - a summation whose body is a product starting with an epsilon: that
     *    epsilon joins `eps`, the remaining summation factors and unscanned
     *    input join `rest`, and the summation's index list is returned;
     *  - any other factor: it and the unscanned input join `rest`, and the
     *    returned index list is empty.
     *)
    fun findeps (eps, [], rest) = (eps, rest, [])
      | findeps (eps, e1 :: es, rest) = (case e1
           of E.Epsilon _ => findeps (eps @ [e1], es, rest)
            | E.Prod p => findeps (eps, p @ es, rest)
            | E.Field _ => findeps (eps, es, rest @ [e1])
            | E.Tensor _ => findeps (eps, es, rest @ [e1])
            | E.Sum (c, E.Prod (E.Epsilon eps1 :: ps)) =>
                (eps @ [E.Epsilon eps1], rest @ ps @ es, c)
            | _ => (eps, rest @ [e1] @ es, [])
          (* end case *))

    (* filterSca : sum-index list * ein_exp list -> int * ein_exp
     * Pull scalar factors (see isScalar) out of the product under a summation
     * over indices `c`.  Returns (0, E.Sum (c, body)) when no scalar factor was
     * found, otherwise (1, E.Prod (scalars @ [E.Sum (c, body)])).  A
     * single-factor body is not wrapped in E.Prod.
     *)
    fun filterSca (c, e) = let
          fun filter ([], [], [body]) = (0, E.Sum (c, body))
            | filter ([], [], body) = (0, E.Sum (c, E.Prod body))
            | filter ([], pre, [body]) = (1, E.Prod (pre @ [E.Sum (c, body)]))
            | filter ([], pre, body) = (1, E.Prod (pre @ [E.Sum (c, E.Prod body)]))
            | filter (e1 :: es, pre, post) = (case e1
                 of E.Prod p => filter (p @ es, pre, post)
                  | _ => if isScalar e1
                      then filter (es, pre @ [e1], post)
                      else filter (es, pre, post @ [e1])
                (* end case *))
          in
            filter (e, [], [])
          end

    (* findIndex : ''a * ''a list -> ''a option
     * SOME v if v occurs in searchspace, NONE otherwise.
     *)
    fun findIndex (v, searchspace) = List.find (fn x => x = v) searchspace

    (* foundSx : index * ein_exp -> index option
     * Does index c occur anywhere in e?  Returns SOME c on the first
     * occurrence found, NONE otherwise.  Raises Fail on E.Krn / E.Img, which
     * are not expected before expansion.
     *)
    fun foundSx (c, e) = let
          (* search sub-expressions left to right, stopping at the first hit *)
          fun firstOf [] = NONE
            | firstOf (e1 :: es) = (case foundSx (c, e1)
                 of NONE => firstOf es
                  | found => found
                (* end case *))
          in
            case e
             of E.Krn _ => raise Fail "Krn used pre expansion"
              | E.Img _ => raise Fail "Img used pre expansion"
              | E.Value _ => NONE
              | E.Const _ => NONE
              | E.Tensor (_, []) => NONE
              | E.Conv (_, [], _, []) => NONE
              | E.Probe (E.Conv (_, [], _, []), E.Tensor (_, [])) => NONE
              | E.Tensor (_, shape) => findIndex (c, shape)
              | E.Field (_, shape) => findIndex (c, shape)
              | E.Delta (i, j) => findIndex (c, [i, j])
              | E.Epsilon (i, j, k) => findIndex (c, [E.V i, E.V j, E.V k])
              | E.Partial shape => findIndex (c, shape)
              | E.Conv (_, alpha, _, dx) => findIndex (c, alpha @ dx)
              | E.Neg a => foundSx (c, a)
              | E.Lift a => foundSx (c, a)
              | E.Sum (_, a) => foundSx (c, a)
              | E.Apply (e1, e2) => firstOf [e1, e2]
              | E.Sub (e1, e2) => firstOf [e1, e2]
              | E.Div (e1, e2) => firstOf [e1, e2]
              | E.Probe (e1, e2) => firstOf [e1, e2]
              | E.Prod a => firstOf a
              | E.Add a => firstOf a
            (* end case *)
          end

    (* splitSum : (index * 'lb * 'ub) * ein_exp list -> ein_exp list * ein_exp list
     * Flatten nested products in p and split the factors into (s, keep):
     * s holds factors that do not mention the summation index of c, keep holds
     * those that do.  Relative order is preserved; nothing else is rewritten.
     *)
    fun splitSum (c, p) = let
          val (v, _, _) = c
          fun filter (s, keep, []) = (s, keep)
            | filter (s, keep, e1 :: es) = (case e1
                 of E.Prod q => filter (s, keep, q @ es)
                  | _ => (case foundSx (v, e1)
                       of NONE => filter (s @ [e1], keep, es)
                        | SOME _ => filter (s, keep @ [e1], es)
                      (* end case *))
                (* end case *))
          in
            filter ([], [], p)
          end

    (* pushSum : sum-index * ein_exp list -> int * ein_exp
     * Push a single summation (over index triple c) inward over the product p:
     * factors not mentioning c's index are moved outside E.Sum ([c], ...).
     * Returns (0, sum) when nothing could be moved out, else (1, product).
     * A single remaining factor is not wrapped in E.Prod.
     *)
    fun pushSum (c, p) = (case splitSum (c, p)
           of ([], [keep]) => (0, E.Sum ([c], keep))
            | ([], keep) => (0, E.Sum ([c], E.Prod keep))
            (* NOTE(review): no factor mentions the index, so the summation is
             * dropped entirely; mathematically this loses the multiplicity of the
             * index range -- confirm callers never reach this case. *)
            | (s, []) => (1, E.Prod s)
            | (s, [keep]) => (1, E.Prod (s @ [E.Sum ([c], keep)]))
            | (s, keep) => (1, E.Prod (s @ [E.Sum ([c], E.Prod keep)]))
          (* end case *))

    end (* local *)

end (* Filter *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0