Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/high-il/app-ein.sml
ViewVC logotype

View of /branches/ein16/src/compiler/high-il/app-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log

Revision 4410 - (download) (annotate)
Fri Aug 12 18:28:32 2016 UTC (2 years, 11 months ago) by cchiw
File size: 7129 byte(s)
added inverse
(* substitution  179
 * Apply EIN-operator arguments to an EIN operator.
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

(* App: splices one EIN operator into an argument position of another.
 * The entry point visible here is app, which replaces parameter place of an
 * operator e1 by an operator e2 and renumbers parameter ids and index
 * variables accordingly.
 *
 * NOTE(review): this revision is incomplete and does not compile as shown.
 * Several let-expressions lack their in/end, and the substitution branch of
 * rewrite inside app is empty. See the NOTE(review) comments below.
 *)
structure App = struct

    structure E = Ein
    structure P = Printer

    (* Dictionaries are represented as functions from a key to an option.
     * insert (key, value) d returns a new dictionary that answers SOME value
     * for key and defers to d for every other key.
     *)
    fun insert (key, value) d =fn s =>
        if s = key then SOME value
        else d s
    (* lookup k d: query dictionary d for key k. *)
    fun lookup k d = d k
    (* empty: the dictionary with no entries; every lookup yields NONE. *)
    val empty =fn key =>NONE

    (* mapId (i, dict, shift): renumber the integer id i.
     * Returns the entry for i in dict when present, otherwise i+shift.
     *)
    fun mapId(i ,dict,shift)=(case (lookup i dict)
        of NONE =>i+shift
        | SOME j=>j
        (*end case*))

    (* mapIndex (v, dict, shift): renumber an EIN index.
     * Returns the entry for v in dict when present. Otherwise v must be an
     * index variable E.V i, and E.V (i+shift) is returned.
     * NOTE(review): an unmapped index that is not E.V fails the val pattern
     * below and raises Bind.
     *)
    fun mapIndex(v ,dict,shift)= (case (lookup v dict)
        of NONE =>let val E.V(i)=v in E.V(i+shift) end
        | SOME j=>j
        (*end case*))

    (* mapId2 (i, dict, shift): same as mapId, but prints a diagnostic when i
     * has no entry in dict before falling back to i+shift.
     *)
    fun mapId2(i ,dict,shift)= (case (lookup i dict)
        of NONE =>(print "Err out of range";i+shift)
        | SOME j=>j
        (*end case*))

    (* rewriteSubst (e, subId, mx, paramShift, sumShift, newArgs, done):
     * rewrite the body e of the operator being spliced in so that it fits
     * the enclosing operator.
     *   mx       - index list at the use site; the n-th entry of mx replaces
     *              index variable E.V n of e.
     *   sumShift - lower bound on the offset added to index variables of e
     *              that mx does not cover.
     *   subId    - dictionary used to renumber parameter ids of e.
     *   newArgs, done - see mapParam below.
     * NOTE(review): paramShift is not used in the visible code.
     *)
    fun rewriteSubst(e,subId,mx,paramShift,sumShift,newArgs,done)=let
        (* insertIndex (mx, n, dict, shift): bind E.V n to the n-th entry of
         * mx, for every entry. Returns the dictionary together with the
         * offset e-n computed from the last entry, or 0 when mx is empty.
         * NOTE(review): for an E.C entry the offset is also taken as e-n,
         * mixing the constant into the index offset; confirm intended.
         *)
        fun insertIndex([],_,dict,shift)=(dict,shift)
        | insertIndex(e1::es, n,dict,_)=(case e1
            of E.V e=> insertIndex(es, n+1, insert(E.V n ,E.V e) dict,e-n)
            | E.C (e,flag) => insertIndex(es, n+1, insert(E.V n ,E.C (e,flag)) dict,e-n)
            (*end case*))

        val (subMu,shift)=insertIndex(mx,0,empty,0)
        (* Offset for index variables not covered by mx. *)
        val shift'=Int.max(sumShift, shift)
        (* mapMu: remap an index variable through subMu, shifting unmapped
         * variables by shift'; any other index is returned unchanged.
         *)
        fun  mapMu(E.V i)= mapIndex((E.V i), subMu, shift')
            | mapMu c = c 
        fun mapAlpha mx=List.map mapMu mx
        (* NOTE(review): mapSingle is unfinished. Its let has no in/end,
         * presumably in v end, and it is not called in the visible code.
         *)
        fun mapSingle(i)=let
            val E.V v=mapIndex(E.V i,subMu, shift')
        (* mapSum: remap the first component of each triple in a summation
         * list, as used by E.Sum and E.PowEmb below.
         *)
        fun mapSum []=[]
            | mapSum ((a,b,c)::e)=[((mapMu a),b,c)]@mapSum(e)

        (* NOTE(review): this definition is immediately shadowed by the
         * mapParam below and is therefore dead.
         *)
        fun mapParam(id)= mapId2(id, subId, 0)
        (* mapParam id: resolve parameter id of e. If the id-th variable of
         * newArgs already occurs in done, compared with HighIL.Var.same, its
         * position in done is returned. Otherwise id is renumbered through
         * subId.
         *)
        fun mapParam id = let
            val vA=List.nth(newArgs,id)
            fun iter([],_)=mapId2(id, subId, 0)
            | iter(e1::es,n)=
                if(HighIL.Var.same(e1,vA)) then n
                else iter(es,n+1)
            in iter(done,0) end
        (* apply: structurally rewrite an EIN body, remapping parameter ids
         * with mapParam and index variables with mapMu, mapAlpha and mapSum.
         * Raises Fail on Value, Img and Krn terms.
         *)
        fun apply e=(case e
            of E.B _          => e
            | E.Tensor(id, mx)          => E.Tensor(mapParam id,mapAlpha mx)
            | E.G(E.Delta (i,j))        => E.G(E.Delta(mapMu i,mapMu j))
            | E.G(E.Epsilon(i, j, k))   => E.G(E.Epsilon(mapMu i, mapMu j, mapMu k))
            | E.G(E.Eps2(i, j))         => E.G(E.Eps2(mapMu i, mapMu j))
            | E.Field(id, mx)           => E.Field(mapParam id,mapAlpha mx)
            | E.Lift e1                 => E.Lift(apply e1)
            | E.Conv (v,mx,h,ux)        => E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)
            | E.Partial mx              => E.Partial (mapAlpha mx)
            | E.Apply(e1, e2)           => E.Apply(apply e1, apply e2)
            | E.Probe(f, pos)           => E.Probe(apply f, apply pos)
            | E.Value _                 => raise Fail "expression before expand"
            | E.Img  _                  => raise Fail "expression before expand"
            | E.Krn _                   => raise Fail "expression before expand"
            | E.Sum(c,esum)             => E.Sum(mapSum c, apply esum)
            | E.Op1(E.PowEmb(sx,n),e1)  => E.Op1(E.PowEmb(mapSum sx,n),apply e1)
            | E.Op1(op1,e1)             => E.Op1(op1,apply e1)
            | E.Op2(op2,e1,e2)          => E.Op2(op2,apply e1,apply e2)
            | E.Opn(opn,e1)             => E.Opn(opn,List.map apply e1)
            (*end case*))
        (* NOTE(review): the let opened on the fun rewriteSubst line is never
         * closed; this expression presumably belongs after in, followed by end.
         *)
                apply e

    (* rewriteParams (params, params2, place): splice the parameter list
     * params2 into params in place of the parameter at position place.
     *   params' = the first place params @ params2 @ the params after place.
     *   origId  maps place+k to place+n2-1+k for k = 1 .. number of params
     *           after place, so ids after the replaced one move by n2-1.
     *   subId   maps k-1 to place+k-1 for k = 1 .. n2, so id j of params2
     *           becomes place+j.
     * NOTE(review): the let has no in/end. The caller app destructures a
     * 4-tuple including paramShift, but no paramShift is computed here.
     * n and nbeg are unused.
     *)
    fun rewriteParams(params, params2, place)=let
        val beg=List.take(params,place)
        val next=List.drop(params,place+1)
        val params'=beg@params2@next
        val n= length(params)
        val n2=length(params2)
        val nbeg=length(beg)
        val nnext=length(next)
        (* createDict (n, shift1, shift2, dict): add the entries
         * k+shift1 mapped to k+shift2, for k = n down to 1.
         *)
        fun createDict(0,shift1, shift2,dict)= dict
          | createDict(n,shift1, shift2,dict)=createDict(n-1,shift1,shift2, insert(n+shift1,n+shift2) dict)
        val origId=createDict(nnext,place,place+n2-1,empty)
        val subId=createDict(n2,~1,place-1,empty)

    (* app (E.EIN{params, index, body}, place, e2, newArgs, done): replace
     * parameter place of the operator e1 = EIN{params, index, body} by the
     * operator e2. Returns a pair (g, e). Here e has the spliced parameter
     * list params' and the rewritten body, and g is the final value of the
     * counter changed.
     * NOTE(review): incomplete. Several lets lack in, and the matching branch
     * of rewrite is empty. Nothing in the visible code increments changed,
     * so g is always 0 here. Debug output is printed unconditionally.
     *)
    fun app(E.EIN{params, index, body},place,e2,newArgs,done)=let
        val e1=E.EIN{params=params, index=index, body=body}
val _ = print("\n\n**original: "^P.printerE(e1))
val _ = print("\n\n**replacing at: "^Int.toString(place)^"-"^P.printerE (e2))
        val changed = ref 0
        val params2=E.params e2
        val index2=E.index e2
        val body2=E.body e2
        val (params',origId,substId,paramShift)=rewriteParams(params,params2,place)
        (* err: failure message showing both operators; mx is ignored. *)
        fun err(mx)=String.concat["\ne1:",P.printerE(e1),"\ne2:",P.printerE(e2)]

        (* sumIndex: index variable recorded by apply from the most recent
         * Sum or PowEmb; starts at length index.
         *)
        val sumIndex=ref (length index)
        (* rewrite (id, mx, e): handle a Tensor or Field term e with
         * parameter id and index list mx. A term whose id equals place is the
         * one being replaced, and its index count must match index2. Any
         * other term is renumbered through origId.
         * NOTE(review): the let has no in. The then-branch for a matching
         * index count is missing, presumably a call to rewriteSubst on body2.
         * The E.B value after raise is unreachable.
         *)
        fun rewrite(id,mx ,e)=let
val _ = print (String.concat["\n\tid", Int.toString(id), "\n\tmx-", Int.toString(length(mx)),  "\n\tindex2-", Int.toString(length(index2)),"\n\te:",P.printbody(e)])
            (* x: current contents of sumIndex. *)
            val ref x=sumIndex
            if(id=place) then
                if(length(mx)=length(index2)) then
                (* NOTE(review): then-branch missing in this revision. *)
                else ( raise Fail(err mx);E.B(E.Const 0))
            else (case e
                of E.Tensor(id,mx) => E.Tensor(mapId(id,origId,0), mx)
                | E.Field(id,mx) =>    E.Field(mapId(id,origId,0), mx)
                |  _ => raise Fail"Id error:Term to be replaced is not a Tensor or Fields"
                (*end case*))
        (* sumI: the index variable of the last triple in a summation list.
         * NOTE(review): raises Subscript on an empty list and Bind when the
         * last index is not an E.V.
         *)
        fun sumI(e)=let
            val (E.V v,_,_)=List.nth(e, length(e)-1)
            in v end

        (* apply: walk the body of e1. Every Tensor and Field term is sent to
         * rewrite, and Conv parameter ids are renumbered through origId.
         * Sum and PowEmb store their last index variable in sumIndex before
         * descending.
         * NOTE(review): sumIndex is never restored on leaving a Sum, so the
         * value carries over into sibling subterms.
         *)
fun apply b=(case b
            of E.B _                    => b
            | E.Tensor(id, mx)          => rewrite (id,mx,b)
            | E.G _                     => b
            | E.Field(id, mx)           => rewrite (id,mx,b)
            | E.Lift e1                 => E.Lift(apply e1)
            | E.Conv (v,mx,h,ux)        => E.Conv(mapId(v, origId,0), mx, mapId(h,origId,0), ux)
            | E.Partial mx              => b
            | E.Apply(e1, e2)           => E.Apply(apply e1, apply e2)
            | E.Probe(f, pos)           => E.Probe(apply f, apply pos)
            | E.Value _                 => raise Fail "expression before expand"
            | E.Img  _                  => raise Fail "expression before expand"
            | E.Krn _                   => raise Fail "expression before expand"
            | E.Sum(c,esum)             => (sumIndex:=sumI(c); E.Sum(c, apply esum))
            | E.Op1(E.PowEmb(sx,n),e1)    => (sumIndex:=sumI(sx);E.Op1(E.PowEmb(sx,n),apply e1))
            | E.Op1(op1, e1)            => E.Op1(op1,apply e1)
            | E.Op2(op2, e1,e2)         => E.Op2(op2,apply e1,apply e2)
            | E.Opn(opn, es)            => E.Opn(opn,List.map apply es)
        (*end case*))

        val body''=apply body

        (* g: final value of changed. *)
        val ref g=changed
            ( g,E.EIN{params=params', index=index, body=body''})

  end; (* NOTE(review): labelled local in the original, but no local is open; the let/end nesting above is incomplete *)

    end (* presumably closes structure App; originally labelled local *)

ViewVC Help
Powered by ViewVC 1.0.0