Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/ein/check-ein.sml
ViewVC logotype

View of /branches/ein16/src/compiler/ein/check-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3682 - (download) (annotate)
Thu Feb 18 20:13:18 2016 UTC (4 years, 5 months ago) by cchiw
File size: 6071 byte(s)
creating stable branch that represents ein ir
(* A summation index is bound only within the expression embedded in its
 * summation. This allows several summations to reuse the same index, such as
 *
 *   Σ_i (A_i B_i) + Σ_i (C_i D_i)
 *   Σ_i (A_i B_i) * Σ_i (C_i D_i)
 *)
 
structure CheckEin = struct
    local
    structure E = Ein
    structure P=Printer

    in

    (* Debug switch read by testp: 0 keeps it silent, any other value makes it print. *)
    val testing=0

    (* testp: string list -> int
     * Prints the concatenation of the given strings when testing is non-zero.
     * Always returns 1.
     *)
    fun testp n =
        if testing = 0
          then 1
          else (print (String.concat n); 1)
    (* err: string list -> 'a
     * Raises Fail whose message is the given fragments concatenated after
     * the "check-ein: " prefix. Never returns.
     *)
    fun err str = raise Fail (String.concat ("check-ein: " :: str))

    (* shapeAlpha: keeps the ids carried by the variable indices (E.V) of an
     * index list, in their original order, and drops the constant indices (E.C).
     *)
    fun shapeAlpha alpha =
        List.mapPartial (fn E.V v => SOME v | E.C _ => NONE) alpha

    (* shapeExp: number of non-constant indices of an expression.
     * A Tensor counts the variable indices found by shapeAlpha; Const, Value
     * and ConstR have shape 0. Any other expression is reported through err
     * together with its printed form.
     *)
    fun shapeExp exp = (case exp
        of E.Tensor(_, alpha) => List.length (shapeAlpha alpha)
         | E.Const _          => 0
         | E.Value _          => 0
         | E.ConstR _         => 0
         | _                  => err ["Shape of: ", P.printbody exp]
        (*end case*))

    (* iter: checks that every expression on the list has the same shape
     * (as computed by shapeExp).
     * The second component is the shape seen so far; ~1 means that no shape
     * has been fixed yet, in which case the next expression's shape is adopted.
     * Returns 1 when the list is exhausted; reports a mismatch through err.
     *)
    fun iter (exps, n) = (case exps
        of [] => 1
         | e1::rest => let
             val shape = shapeExp e1
             in
               (* when n matches, shape and n are the same value, so passing shape on is equivalent *)
               if n = ~1 orelse n = shape
                 then iter (rest, shape)
                 else err ["Uneven buildDict"]
             end
        (*end case*))

    (* Dictionary used for the index map, represented as a lookup function
     * from key to optional value.
     *   empty            - dictionary in which every key maps to NONE
     *   lookup k d       - the value bound to k in d, if any
     *   insert (k, v) d  - d extended so that k maps to SOME v; an earlier
     *                      binding of k is shadowed, every other key is
     *                      delegated to d
     *)
    val empty = fn _ => NONE

    fun lookup key dict = dict key

    fun insert (key, value) d s =
        if s = key
          then SOME value
          else d s

    (* mkMapp: int * dict -> dict
     * Creates the map for the indices: binds each of the keys n-1, n-2, ..., 0
     * to the value 0 on top of the given dictionary.
     *)
    fun mkMapp (n, mapp) =
        if n = 0
          then mapp
          else mkMapp (n-1, insert (n-1, 0) mapp)
    
    (* removeMapp: sum_index_id list * dict -> dict
     * Rebinds the id of every summation index on the list to ~1, which is the
     * value mkSumMapp accepts as "free to be summed over again".
     * A summation entry whose index is not a variable (E.V) is reported
     * through err with the same message mkSumMapp uses, instead of escaping
     * as an unhandled Match exception.
     *)
    fun removeMapp([],mapp)=mapp
      | removeMapp((E.V n,_,_)::sx,mapp)=let
        val dict=insert(n, ~1) mapp
        in removeMapp(sx,dict) end
      | removeMapp(_::_,_)=err["non-variable index in summation"]

    (* mkSumMapp: sum_index list * dict * 'a -> dict
     * Registers every summation index on the list with count 0.
     * An index may be registered only if it has no entry yet or if its entry
     * is ~1 (the value removeMapp leaves behind); any other entry is reported
     * as a repeated summation. A summation entry whose index is not a
     * variable (E.V) is reported through err as well.
     * The third argument is passed along unchanged and never inspected.
     * NOTE(review): mkMapp binds outer indices to 0, so summing over an outer
     * index also ends in the "More than one summation" message - confirm
     * that wording is what is wanted for that case.
     *)
    fun mkSumMapp ([], mapp, _) = mapp
      | mkSumMapp ((E.V n, _, _)::rest, mapp, b) = let
          val available = (case lookup n mapp
                 of NONE => true
                  | SOME count => (count = ~1)
                (*end case*))
          in
            if available
              then mkSumMapp (rest, insert (n, 0) mapp, b)
              else err ["More than one summation for Index<", Int.toString n, ">"]
          end
      | mkSumMapp (_::_, _, _) = err ["non-variable index in summation"]
                
    (* checkMu: index list * dict -> dict
     * Records one use of every variable index (E.V) on the list by adding 1
     * to the count stored for it; constant indices (E.C) are skipped.
     * An index without any entry in the dictionary is reported through err.
     * NOTE(review): an index that removeMapp reset to ~1 still has an entry,
     * so it is bumped to 0 here rather than rejected - confirm that using a
     * summation index outside its summation is meant to pass.
     *)
    fun checkMu (ixs, dict) = let
        fun bump (E.C _, d) = d
          | bump (E.V v, d) = (case lookup v d
               of NONE => err ["Unknown Index<", Int.toString v, ">"]
                | SOME count => insert (v, count+1) d
              (*end case*))
        in
          List.foldl bump dict ixs
        end
                
     (* checkKrnDel: index list * dict -> dict
      * For every variable index (E.V) on the list that already has an entry,
      * overwrites that entry with 2, whatever it held before; constant
      * indices (E.C) are skipped. An index without an entry is reported
      * through err.
      *)
     fun checkKrnDel (ixs, dict) = let
         fun mark (E.C _, d) = d
           | mark (E.V v, d) =
               if Option.isSome (lookup v d)
                 then insert (v, 2) d
                 else err ["Unknown Index<", Int.toString v, ">"]
         in
           List.foldl mark dict ixs
         end
        
    (* checkMapp: int * dict -> int
     * Checks that the outer indices are in the map: every key n-1, n-2, ..., 0
     * must have an entry. Returns 1 when all are present; the first key
     * without an entry is reported through err.
     * The message names the key that was actually looked up (n-1), matching
     * the 0-based ids printed by the other index diagnostics in this file.
     * NOTE(review): only the presence of an entry is tested, not its count.
     * mkMapp binds every outer index to 0 up front, so as called from
     * checkEIN this lookup looks like it can never miss - confirm whether a
     * "was used at least once" test was intended.
     *)
    fun checkMapp(0,_)= 1
        | checkMapp(n,newMapp)=let
             val n'=n-1
             in (case (lookup n' newMapp)
                of NONE  => err["Did not find Outer Index:",Int.toString(n')]
                | SOME _ =>checkMapp (n',newMapp)
                (*end case*))
             end
        
    (* mappBody: body * dict -> dict
     * Walks an EIN body and returns the index dictionary updated for the
     * indices the body mentions.
     *  - Const leaves the dictionary untouched
     *  - Delta, Epsilon, Eps2, Value and Tensor record their indices with checkMu
     *  - Neg, Sqrt, Cosine, ArcCosine, Sine and ArcSine recurse into their operand
     *  - Add and Sub first run iter over their operands, then thread the
     *    dictionary through them from left to right
     *  - Prod and Div only thread the dictionary (no iter check)
     *  - Img records its indices with checkMu, then walks its position list
     *  - Krn marks the second component of each delta pair with checkKrnDel,
     *    then walks its position expression
     *  - Sum registers its indices with mkSumMapp, walks the summand, and
     *    finally resets the indices with removeMapp
     *  - Conv, Field, Partial, Apply, Lift and Probe are reported through err,
     *    as is any form not listed here
     *)
    fun mappBody (body, mapp) = let
        (* thread a dictionary through a list of sub-expressions, left to right *)
        fun walkAll (exps, dict) = List.foldl mappBody dict exps
        in (case body
            of E.Const _            => mapp
             | E.Delta(i, j)        => checkMu ([i, j], mapp)
             | E.Epsilon(i, j, k)   => checkMu ([E.V i, E.V j, E.V k], mapp)
             | E.Eps2(i, j)         => checkMu ([E.V i, E.V j], mapp)
             | E.Value v            => checkMu ([E.V v], mapp)
             | E.Tensor(_, ix)      => checkMu (ix, mapp)
             | E.Neg e              => mappBody (e, mapp)
             | E.Sqrt e             => mappBody (e, mapp)
             | E.Add es             => (iter (es, ~1); walkAll (es, mapp))
             | E.Sub(e1, e2)        => (iter ([e1, e2], ~1); walkAll ([e1, e2], mapp))
             | E.Prod es            => walkAll (es, mapp)
             | E.Div(e1, e2)        => walkAll ([e1, e2], mapp)
             | E.Conv _             => err ["Conv- Should have been expanded"]
             | E.Field _            => err ["Field- Should have been expanded"]
             | E.Partial _          => err ["Partial- Should have been expanded"]
             | E.Apply _            => err ["Apply- Should have been expanded"]
             | E.Lift _             => err ["Lift- Should have been expanded"]
             | E.Probe _            => err ["Probe- Should have been expanded"]
             | E.Img(_, alpha, pos) => walkAll (pos, checkMu (alpha, mapp))
             | E.Krn(_, delta, pos) =>
                 mappBody (pos, checkKrnDel (List.map (fn (_, ix) => ix) delta, mapp))
             | E.Sum(sx, e)         =>
                 removeMapp (sx, mappBody (e, mkSumMapp (sx, mapp, body)))
             | E.Cosine e           => mappBody (e, mapp)
             | E.ArcCosine e        => mappBody (e, mapp)
             | E.Sine e             => mappBody (e, mapp)
             | E.ArcSine e          => mappBody (e, mapp)
             | _                    => err ["missing ", P.printbody body]
            (*end case*))
        end

    (* checkEIN: entry point of the check.
     * Builds the initial map with one entry per outer index of e (mkMapp),
     * walks the body of e with mappBody, and finally runs checkMapp over the
     * resulting map. Returns the result of checkMapp; any failure surfaces
     * as the Fail exception raised through err.
     *)
    fun checkEIN e = let
        val nOuter = List.length (Ein.index e)
        val initial = mkMapp (nOuter, empty)
        val final = mappBody (Ein.body e, initial)
        in
            checkMapp (nOuter, final)
        end
    (*handle ex => (print(concat["\n *** error check-ein  \n",P.printerE(e)]); raise ex)*)


    end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0