Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-il/app-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-il/app-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3700 - (download) (annotate)
Mon Mar 28 23:12:00 2016 UTC (5 years, 5 months ago) by cchiw
File size: 6758 byte(s)
merge in c-util-basevis
(* substitution  179
 * Apply EIN operator arguments to an EIN operator.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure App = struct

    local
   
    structure E = Ein
    structure P = Printer
    in

    (* Functional dictionaries: a dictionary is a function that returns
     * SOME v for a bound key and NONE for an unbound one. *)

    (* insert (key, value) d: a dictionary that binds key to value and defers
     * to d for every other key, so a newer binding shadows an older one. *)
    fun insert (k, v) dict = let
          fun extended probe = if probe <> k then dict probe else SOME v
          in
            extended
          end

    (* lookup k d: the binding of k in d, if any. *)
    fun lookup key dict = dict key

    (* The dictionary with no bindings. *)
    val empty = fn _ => NONE

    (* mapId (i, dict, shift): translate the id i through dict; when dict has
     * no binding for i, fall back to i offset by shift. *)
    fun mapId (i, dict, shift) = let
          val found = lookup i dict
          in
            case found of SOME j => j | NONE => i + shift
          end

    (* mapIndex (v, dict, shift): translate the index v through dict; an
     * unbound variable E.V i becomes E.V (i + shift). An unbound index that
     * is not an E.V raises Bind. *)
    fun mapIndex (v, dict, shift) = (case lookup v dict
           of SOME bound => bound
            | NONE => let
                val E.V i = v
                in
                  E.V (i + shift)
                end
          (* end case *))

    (* mapId2 (i, dict, shift): like mapId, but an id missing from dict is
     * treated as suspicious: a diagnostic is printed before falling back to
     * i offset by shift. *)
    fun mapId2 (i, dict, shift) = let
          val found = lookup i dict
          in
            case found
             of SOME j => j
              | NONE => (print "Err out of range"; i + shift)
          end

    (* rewriteSubst (e, subId, mx, paramShift, sumShift, newArgs, done)
     *
     * Rewrite e (the body of the EIN operator being substituted in) so that
     * it can replace a Tensor/Field term whose index list is mx:
     *   - index variable E.V n, for n a position in mx, becomes the n-th
     *     element of mx (which may be an E.V or an E.C);
     *   - any other index variable E.V i becomes E.V (i + shift'), where
     *     shift' = Int.max (sumShift, shift) and shift is computed by
     *     insertIndex below;
     *   - parameter ids are translated through subId via mapId2, which
     *     prints a message and keeps the id unchanged when it is unbound.
     * paramShift, newArgs, and done are not used here.
     *
     * Raises Fail on E.Value, E.Img, and E.Krn terms. Raises Bind if an
     * Epsilon/Eps2 index (see mapSingle) maps to a constant.
     *)
    fun rewriteSubst(e,subId,mx,paramShift,sumShift,newArgs,done)=let
        (* Build the map  position n |-> mx[n]  and return it together with
         * e - n computed for the LAST element of mx (0 when mx is empty),
         * where e is that element's variable number or constant value.
         * NOTE(review): for an E.C element the shift is derived from the
         * constant's value -- confirm this is intended. *)
        fun insertIndex([],_,dict,shift)=(dict,shift)
        | insertIndex(e1::es, n,dict,_)=(case e1
            of E.V e=> insertIndex(es, n+1, insert(E.V n ,E.V e) dict,e-n)
             | E.C e => insertIndex(es, n+1, insert(E.V n ,E.C e) dict,e-n)
            (*end case*))

        val (subMu,shift)=insertIndex(mx,0,empty,0)
        (* offset applied to index variables of e that are not bound by mx *)
        val shift'=Int.max(sumShift, shift)
        (* Map an index through subMu; constants are left unchanged. *)
        fun  mapMu(E.V i)= mapIndex((E.V i), subMu, shift')
            | mapMu c = c 
        fun mapAlpha mx=List.map mapMu mx
        (* Map a bare variable number; the result must again be a variable,
         * otherwise the binding below raises Bind. *)
        fun mapSingle(i)=let
            val E.V v=mapIndex(E.V i,subMu, shift')
            in
                v
            end
        (* Map the index (first component) of each summation triple. *)
        fun mapSum []=[]
            | mapSum ((a,b,c)::e)=[((mapMu a),b,c)]@mapSum(e)
        fun mapParam(id)= mapId2(id, subId, 0)
        (* Structural traversal applying the index and parameter maps. *)
        fun apply e=(case e
            of E.B _          => e
            | E.Tensor(id, mx)          => E.Tensor(mapParam id,mapAlpha mx)
            | E.G(E.Delta (i,j))        => E.G(E.Delta(mapMu i,mapMu j))
            | E.G(E.Epsilon(i, j, k))   => E.G(E.Epsilon(mapSingle i, mapSingle j, mapSingle k))
            | E.G(E.Eps2(i, j))         => E.G(E.Eps2(mapSingle i, mapSingle j))
            | E.Field(id, mx)           => E.Field(mapParam id,mapAlpha mx)
            | E.Lift e1                 => E.Lift(apply e1)
            | E.Conv (v,mx,h,ux)        => E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)
            | E.Partial mx              => E.Partial (mapAlpha mx)
            | E.Apply(e1, e2)           => E.Apply(apply e1, apply e2)
            | E.Probe(f, pos)           => E.Probe(apply f, apply pos)
            | E.Value _                 => raise Fail "expression before expand"
            | E.Img  _                  => raise Fail "expression before expand"
            | E.Krn _                   => raise Fail "expression before expand"
            | E.Sum(c,esum)             => E.Sum(mapSum c, apply esum)
            (* must precede the general Op1 case below *)
            | E.Op1(E.PowEmb(sx,n),e1)  => E.Op1(E.PowEmb(mapSum sx,n),apply e1)
            | E.Op1(op1,e1)             => E.Op1(op1,apply e1)
            | E.Op2(op2,e1,e2)          => E.Op2(op2,apply e1,apply e2)
            | E.Opn(opn,e1)             => E.Opn(opn,List.map apply e1)
            (*end case*))
            in
                apply e
            end

    (* rewriteParams (params, params2, place): splice params2 into params in
     * place of the parameter at position `place`.
     *
     * Returns (params', origId, subId, nbeg) where
     *   params' = params[0..place-1] @ params2 @ params[place+1..];
     *   origId  maps each original id place+k (1 <= k <= number of params
     *           after place) to its new position place+k+length(params2)-1;
     *           ids below place are left unbound (their position is kept);
     *   subId   maps each id j of params2 (0 <= j < length params2) to
     *           j+place;
     *   nbeg    is the number of leading params kept (= place).
     * Raises Subscript if place is not a valid index into params.
     *)
    fun rewriteParams (params, params2, place) = let
          val front = List.take (params, place)
          val back = List.drop (params, place + 1)
          val nSub = List.length params2
          (* bind (k + lo) |-> (k + hi) for k = count, count-1, ..., 1 *)
          fun bindRange (count, lo, hi) = let
                fun loop (0, d) = d
                  | loop (k, d) = loop (k - 1, insert (k + lo, k + hi) d)
                in
                  loop (count, empty)
                end
          val spliced = front @ params2 @ back
          val origId = bindRange (List.length back, place, place + nSub - 1)
          val subId = bindRange (nSub, ~1, place - 1)
          in
            (spliced, origId, subId, List.length front)
          end


    (* app (E.EIN{params, index, body}, place, e2, newArgs, done)
     *
     * Substitute the EIN operator e2 for parameter number `place` of the
     * given EIN operator.
     *
     * Returns (changed, ein') where
     *   - changed is 1 if at least one Tensor/Field term with id = place was
     *     replaced by a rewritten copy of e2's body, and 0 otherwise;
     *   - ein' has params' (params with the entry at `place` replaced by
     *     e2's params, as built by rewriteParams), the original index list,
     *     and the rewritten body.
     * Tensor/Field/Conv terms that refer to other parameters have their ids
     * renumbered through origId so that they index params' correctly.
     * newArgs and done are only forwarded to rewriteSubst.
     *
     * Raises Fail if a term with id = place has an index list whose length
     * differs from e2's index list, and on E.Value/E.Img/E.Krn terms.
     *)
    fun app(E.EIN{params, index, body},place,e2,newArgs,done)=let
        (* set to 1 once any substitution has been performed *)
        val changed = ref 0
        val params2=E.params e2
        val index2=E.index e2
        val body2=E.body e2
        (* substId translates e2's parameter ids, origId the remaining ids
         * of the outer operator; paramShift is unused by rewriteSubst *)
        val (params',origId,substId,paramShift)=rewriteParams(params,params2,place)
        (*val err=String.concat["Wrong size for Subst:",
                P.printbody body,"-with-",P.printbody body2,"@",Int.toString place]
*)
(* Error message for a term whose index list mx does not match index2. *)
fun err(mx)=String.concat["\n***\nWrong size for Subst:",
P.printbody body,"-with-",P.printbody body2,"@",Int.toString place,
"index2:", Int.toString( length index2), "mx:",Int.toString(length mx)]


        (* Passed to rewriteSubst as sumShift. Starts at length index and is
         * overwritten with the last summation variable whenever apply enters
         * an E.Sum or a PowEmb.
         * NOTE(review): the value is never restored on leaving that subtree,
         * and SML evaluates tuple components left to right, so later sibling
         * subtrees (e.g. e2 of an Op2) see the value set by an earlier Sum --
         * confirm this is intended. *)
        val sumIndex=ref (length index)
        (* rewrite (id, mx, e) for a Tensor/Field term e with parameter id and
         * index list mx: if id = place, return e2's body rewritten for mx;
         * otherwise renumber id through origId. apply only calls this on
         * Tensor/Field terms, so the Id error branch is not reached from it. *)
        fun rewrite(id,mx ,e)=let
            val ref x=sumIndex
            in 
            if(id=place) then
                if(length(mx)=length(index2)) then
                    (changed:=1; rewriteSubst(body2,substId,mx,paramShift,x,newArgs,done))
                (* the E.B(E.Const 0) after the raise is never evaluated *)
                else ( raise Fail(err(mx));E.B(E.Const 0))
            else (case e
                of E.Tensor(id,mx) => E.Tensor(mapId(id,origId,0), mx)
                | E.Field(id,mx) =>    E.Field(mapId(id,origId,0), mx)
                |  _ => raise Fail"Id error:Term to be replaced is not a Tensor or Fields"
                (*end case*))
            end
        (* Variable number of the last summation triple in e; raises
         * Subscript on an empty list and Bind if that index is not an E.V. *)
        fun sumI(e)=let
            val (E.V v,_,_)=List.nth(e, length(e)-1)
            in v end

        (* Structural traversal of the outer body. E.G and E.Partial are left
         * unchanged; Conv ids are renumbered but never substituted.
         * NOTE(review): a Conv whose id equals place is kept as is. *)
        fun apply b=(case b
            of E.B _                    => b
            | E.Tensor(id, mx)          => rewrite (id,mx,b)
            | E.G _                     => b
            | E.Field(id, mx)           => rewrite (id,mx,b)
            | E.Lift e1                 => E.Lift(apply e1)
            | E.Conv (v,mx,h,ux)        => E.Conv(mapId(v, origId,0), mx, mapId(h,origId,0), ux)
            | E.Partial mx              => b
            | E.Apply(e1, e2)           => E.Apply(apply e1, apply e2)
            | E.Probe(f, pos)           => E.Probe(apply f, apply pos)
            | E.Value _                 => raise Fail "expression before expand"
            | E.Img  _                  => raise Fail "expression before expand"
            | E.Krn _                   => raise Fail "expression before expand"
            | E.Sum(c,esum)             => (sumIndex:=sumI(c); E.Sum(c, apply esum))
            | E.Op1(E.PowEmb(sx,n),e1)    => (sumIndex:=sumI(sx);E.Op1(E.PowEmb(sx,n),apply e1))
            | E.Op1(op1, e1)            => E.Op1(op1,apply e1)
            | E.Op2(op2, e1,e2)         => E.Op2(op2,apply e1,apply e2)
            | E.Opn(opn, es)            => E.Opn(opn,List.map apply es)
        (*end case*))
        val body''=apply body
        val ref g=changed
        in
            ( g,E.EIN{params=params', index=index, body=body''})
        end


  end; (* local *)

    end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0