Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-il/app-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-il/app-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2605 - (download) (annotate)
Wed Apr 30 01:46:09 2014 UTC (7 years, 2 months ago) by cchiw
File size: 5721 byte(s)
code cleanup
(* app-ein.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure App = struct

  local

    structure E = Ein

  in

  (* Substitution of one EIN operator into another.
   *
   * `app (outer, place, inner)` replaces every Tensor/Field term of `outer`
   * whose parameter id equals `place` by (a renumbered copy of) the body of
   * `inner`.  `inner`'s parameters are spliced into `outer`'s parameter list
   * at position `place`, and parameter ids / indices are renumbered so that
   * the combined term refers to the combined parameter list.
   *)

  (***** A minimal functional dictionary *****
   * A dictionary is a function `key -> value option`.  `insert` layers one new
   * binding over an existing dictionary, so a lookup walks back through every
   * insertion (linear in the number of insertions); a newer binding shadows
   * an older one for the same key.
   *)

  (* insert (key, value) d : the dictionary `d` extended with key |-> value *)
    fun insert (key, value) d = fn s =>
          if s = key then SOME value else d s

  (* lookup k d : the value bound to `k` in `d`, or NONE *)
    fun lookup k d = d k

  (* the dictionary with no bindings *)
    val empty = fn _ => NONE

  (* mapId (i, dict, shift) : the id `dict` assigns to `i`; an id with no
   * entry in `dict` is offset by `shift` instead.
   *)
    fun mapId (i, dict, shift) = (case lookup i dict
           of NONE => i + shift
            | SOME j => j
          (* end case *))

  (* mapIndex (v, dict, shift) : like `mapId`, but for index variables.  An
   * unmapped index must be of the form `E.V i` (anything else raises Bind)
   * and becomes `E.V (i + shift)`.
   *)
    fun mapIndex (v, dict, shift) = (case lookup v dict
           of NONE => let val E.V i = v in E.V (i + shift) end
            | SOME mu => mu
          (* end case *))

  (* mapId2 (i, dict, shift) : like `mapId`, but an id missing from `dict` is
   * reported by printing "Err out of range" (no trailing newline) before
   * falling back to `i + shift`.
   *)
    fun mapId2 (i, dict, shift) = (case lookup i dict
           of NONE => (print "Err out of range"; i + shift)
            | SOME j => j
          (* end case *))

  (* rewriteSubst (e, subId, mx, paramShift, sumShift) : renumber the body `e`
   * of the term being substituted so that it fits at a use site.
   *   e          - body of the substituted EIN term
   *   subId      - dictionary from `e`'s parameter ids to their new ids
   *   mx         - index list at the use site; `e`'s n-th index, `E.V n`,
   *                is replaced by the n-th element of `mx`
   *   paramShift - not used here; kept for interface compatibility
   *   sumShift   - lower bound for the offset applied to every index of `e`
   *                that is not bound by `mx`
   * Raises Fail on Value/Krn/Img terms ("expression before expand").
   *)
    fun rewriteSubst (e, subId, mx, paramShift, sumShift) = let
        (* Bind `E.V n` to the n-th element of the index list.  The returned
         * shift is (value - position) of the LAST element only; for an empty
         * list it stays 0.
         *)
          fun insertIndex ([], _, dict, shift) = (dict, shift)
            | insertIndex (mu :: rest, n, dict, _) = (case mu
                 of E.V k => insertIndex (rest, n + 1, insert (E.V n, E.V k) dict, k - n)
                  | E.C k => insertIndex (rest, n + 1, insert (E.V n, E.C k) dict, k - n)
                (* end case *))
          val (subMu, shift) = insertIndex (mx, 0, empty, 0)
        (* offset for indices of `e` that `mx` does not bind *)
          val shift' = Int.max (sumShift, shift)
          fun mapMu (E.V i) = mapIndex (E.V i, subMu, shift')
            | mapMu c = c
          fun mapAlpha alpha = List.map mapMu alpha
        (* Epsilon stores bare ints; the mapped index must again be an `E.V` *)
          fun mapSingle i = let val E.V v = mapIndex (E.V i, subMu, shift') in v end
        (* only the index component of each summation triple is renumbered *)
          fun mapSum sx = List.map (fn (mu, x, y) => (mapMu mu, x, y)) sx
          fun mapParam id = mapId2 (id, subId, 0)
          fun apply e = (case e
                 of E.Const _ => e
                  | E.Value _ => raise Fail "expression before expand"
                  | E.Krn _ => raise Fail "expression before expand"
                  | E.Img _ => raise Fail "expression before expand"
                  | E.Tensor (id, alpha) => E.Tensor (mapParam id, mapAlpha alpha)
                  | E.Field (id, alpha) => E.Field (mapParam id, mapAlpha alpha)
                  | E.Delta (i, j) => E.Delta (mapMu i, mapMu j)
                  | E.Epsilon (i, j, k) => E.Epsilon (mapSingle i, mapSingle j, mapSingle k)
                  | E.Sum (sx, e1) => E.Sum (mapSum sx, apply e1)
                  | E.Neg e1 => E.Neg (apply e1)
                  | E.Lift e1 => E.Lift (apply e1)
                  | E.Add es => E.Add (List.map apply es)
                  | E.Sub (e1, e2) => E.Sub (apply e1, apply e2)
                  | E.Prod es => E.Prod (List.map apply es)
                  | E.Div (e1, e2) => E.Div (apply e1, apply e2)
                  | E.Partial alpha => E.Partial (mapAlpha alpha)
                  | E.Apply (e1, e2) => E.Apply (apply e1, apply e2)
                  | E.Conv (v, alpha, h, beta) =>
                      E.Conv (mapParam v, mapAlpha alpha, mapParam h, mapAlpha beta)
                  | E.Probe (f, pos) => E.Probe (apply f, apply pos)
                (* end case *))
          in
            apply e
          end

  (* rewriteParams (params, params2, place) : splice `params2` into `params`
   * in place of the element at position `place` (0-based).  Returns
   * (params', origId, subId, nbeg) where
   *   params' - params[0..place-1] @ params2 @ params[place+1..]
   *   origId  - maps each outer id after `place` (place+1 .. place+nnext) to
   *             its new id, shifted by (length params2 - 1); ids up to and
   *             including `place` have no entry
   *   subId   - maps each id j of params2 (0 .. length params2 - 1) to place + j
   *   nbeg    - number of outer params kept before the splice point
   * Raises Subscript (from List.take/List.drop) if `place` is out of range.
   *)
    fun rewriteParams (params, params2, place) = let
          val beg = List.take (params, place)
          val next = List.drop (params, place + 1)
          val n2 = length params2
        (* bind n + shift1 |-> n + shift2 for n = count, count-1, ..., 1 *)
          fun createDict (0, _, _, dict) = dict
            | createDict (n, shift1, shift2, dict) =
                createDict (n - 1, shift1, shift2, insert (n + shift1, n + shift2) dict)
          val origId = createDict (length next, place, place + n2 - 1, empty)
          val subId = createDict (n2, ~1, place - 1, empty)
          in
            (beg @ params2 @ next, origId, subId, length beg)
          end

  (* splitEin ein : the (params, index, body) fields of an EIN term *)
    fun splitEin (E.EIN{params, index, body}) = (params, index, body)

  (* app (ein, place, e2) : substitute the EIN term `e2` for parameter `place`
   * of `ein`.  Every Tensor/Field whose id is `place` is replaced by
   * `e2`'s renumbered body; every other Tensor/Field/Conv parameter id is
   * renumbered to account for the spliced-in parameters.  The free index
   * list of `ein` is kept as is.
   * Returns (changed, ein') where `changed` is 1 if at least one
   * substitution happened and 0 otherwise.
   * Raises Fail "Wrong size for Subst" when a use site's index list does not
   * have the same length as `e2`'s index list, and Fail on Value/Krn/Img
   * terms ("expression before expand").
   *)
    fun app (E.EIN{params, index, body}, place, e2) = let
          val changed = ref 0
          val (params2, index2, body2) = splitEin e2
          val (params', origId, substId, paramShift) = rewriteParams (params, params2, place)
          val err = "Wrong size for Subst"
        (* Starts as the number of free indices and is overwritten with the
         * last index of every Sum entered.  NOTE(review): it is never restored
         * when a Sum is left, so terms visited after a Sum (e.g. later
         * operands of an Add/Prod) see that Sum's value -- confirm intended.
         *)
          val sumIndex = ref (length index)
        (* renumber the parameter id of a term that is not being substituted *)
          fun renumber e = (case e
                 of E.Tensor (id, mx) => E.Tensor (mapId (id, origId, 0), mx)
                  | E.Field (id, mx) => E.Field (mapId (id, origId, 0), mx)
                  | _ => raise Fail "Id error:Term to be replaced is not a Tensor or Fields"
                (* end case *))
          fun rewrite (id, mx, e) =
                if id <> place
                  then renumber e
                else if length mx = length index2
                  then (changed := 1; rewriteSubst (body2, substId, mx, paramShift, !sumIndex))
                  else raise Fail err
        (* the last index of a summation list; it must be an `E.V` *)
          fun sumI sx = let val (E.V v, _, _) = List.last sx in v end
          fun apply e = (case e
                 of E.Value _ => raise Fail "expression before expand"
                  | E.Krn _ => raise Fail "expression before expand"
                  | E.Img _ => raise Fail "expression before expand"
                  | E.Const _ => e
                  | E.Delta _ => e
                  | E.Epsilon _ => e
                  | E.Partial _ => e
                  | E.Tensor (id, mx) => rewrite (id, mx, e)
                  | E.Field (id, mx) => rewrite (id, mx, e)
                  | E.Sum (sx, e1) => (sumIndex := sumI sx; E.Sum (sx, apply e1))
                  | E.Lift e1 => E.Lift (apply e1)
                  | E.Neg e1 => E.Neg (apply e1)
                  | E.Add es => E.Add (List.map apply es)
                  | E.Sub (e1, e2) => E.Sub (apply e1, apply e2)
                  | E.Prod es => E.Prod (List.map apply es)
                  | E.Div (e1, e2) => E.Div (apply e1, apply e2)
                  | E.Apply (e1, e2) => E.Apply (apply e1, apply e2)
                (* NOTE(review): only renumbered; a Conv whose v or h equals
                 * `place` is left pointing at `place` -- confirm intended. *)
                  | E.Conv (v, mx, h, ux) =>
                      E.Conv (mapId (v, origId, 0), mx, mapId (h, origId, 0), ux)
                  | E.Probe (f, pos) => E.Probe (apply f, apply pos)
                (* end case *))
          val body' = apply body
          in
            (!changed, E.EIN{params=params', index=index, body=body'})
          end

  end (* local *)

end (* structure App *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0