Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-il/app-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-il/app-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2521 - (download) (annotate)
Thu Jan 9 02:17:07 2014 UTC (5 years, 7 months ago) by cchiw
File size: 5346 byte(s)
Added type Checker
(* app-ein.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure App = struct

    local
   
    structure E = Ein
    structure P = Printer
    in




 
(* insert (key, value) d
 * Extends the function-encoded dictionary d with the binding key -> value.
 * The extended dictionary answers SOME value for key and hands every other
 * query on to d, so a newer binding for a key shadows an older one.
 *)
fun insert (key, value) d query =
      if query = key
        then SOME value
        else d query
 
(* lookup k d: query the function-encoded dictionary d at key k. *)
fun lookup key dict = dict key
(* empty: the dictionary without bindings; every query answers NONE. *)
fun empty _ = NONE

(* mapId (i, dict, shift)
 * Renumbers the integer id i: the entry dict holds for i wins; an id
 * without an entry is moved up by shift.
 *)
fun mapId (i, dict, shift) = (case lookup i dict
       of SOME j => j
        | NONE => i + shift
      (* end case *))

(* mapIndex (mu, dict, shift)
 * Index-variable counterpart of mapId.  For a variable index E.V i the
 * result is the entry stored for it in dict or, when there is none,
 * E.V (i + shift).  Any other index (e.g. a constant E.C _) is returned
 * unchanged rather than raising Match; this agrees with how rewriteSubst
 * treats non-variable indices before it calls this function.
 *)
fun mapIndex (E.V i, dict, shift) = (case lookup (E.V i) dict
       of SOME mu => mu
        | NONE => E.V (i + shift)
      (* end case *))
  | mapIndex (mu, _, _) = mu


(* mapId2 (i, dict, shift)
 * Same renumbering as mapId, except that an id without an entry in dict is
 * reported on standard output before the shifted id is returned.
 *)
fun mapId2 (i, dict, shift) = (case lookup i dict
       of SOME j => j
        | NONE => (print "Err out of range"; i + shift)
      (* end case *))





(* rewriteSubst (e, subId, mx, paramShift, sumShift)
 * Renames the parameter ids and the index variables of the expression e.
 *   e          - expression to rewrite
 *   subId      - dictionary from the parameter ids used in e to new ids
 *   mx         - index list; its n-th entry becomes the replacement for the
 *                index variable E.V n of e
 *   paramShift - not used by this function
 *   sumShift   - lower bound on the shift given to index variables of e that
 *                have no replacement built from mx
 * Returns the rewritten expression.
 *)
fun rewriteSubst (e, subId, mx, paramShift, sumShift) = let
      (* Bind E.V n to the n-th entry of mx.  The shift handed back is the one
       * computed from the last entry (0 when mx is empty).
       * NOTE(review): for a constant entry E.C c the shift is c - n, i.e. it
       * is derived from the constant's value -- confirm this is intended.
       *)
      fun bindIndices ([], _, dict, shift) = (dict, shift)
        | bindIndices (E.V v :: rest, n, dict, _) =
            bindIndices (rest, n + 1, insert (E.V n, E.V v) dict, v - n)
        | bindIndices (E.C c :: rest, n, dict, _) =
            bindIndices (rest, n + 1, insert (E.V n, E.C c) dict, c - n)
      val (indexDict, mxShift) = bindIndices (mx, 0, empty, 0)
      val shift = Int.max (sumShift, mxShift)
      (* variable indices go through the dictionary; anything else is kept *)
      fun renameMu (E.V i) = mapIndex (E.V i, indexDict, shift)
        | renameMu other = other
      fun renameAlpha alpha = List.map renameMu alpha
      (* Rename an index given as a bare integer.  The pattern binding raises
       * Bind when the replacement is not of the form E.V _.
       *)
      fun renameInt i = let
            val E.V v = mapIndex (E.V i, indexDict, shift)
            in
              v
            end
      (* only the first component of each summation entry is renamed *)
      fun renameSum sx = List.map (fn (mu, b, c) => (renameMu mu, b, c)) sx
      fun renameParam id = mapId2 (id, subId, 0)
      fun apply (E.Tensor (id, alpha)) = E.Tensor (renameParam id, renameAlpha alpha)
        | apply (E.Field (id, alpha)) = E.Field (renameParam id, renameAlpha alpha)
        | apply (E.Krn (id, deltas, pos)) = E.Krn (renameParam id, deltas, apply pos)
        | apply (E.Delta (i, j)) = E.Delta (renameMu i, renameMu j)
        | apply (E.Epsilon (i, j, k)) = E.Epsilon (renameInt i, renameInt j, renameInt k)
        | apply (E.Sum (sx, inner)) = E.Sum (renameSum sx, apply inner)
        | apply (E.Neg a) = E.Neg (apply a)
        | apply (E.Add es) = E.Add (List.map apply es)
        | apply (E.Sub (a, b)) = E.Sub (apply a, apply b)
        | apply (E.Prod es) = E.Prod (List.map apply es)
        | apply (E.Div (a, b)) = E.Div (apply a, apply b)
        | apply (E.Partial alpha) = E.Partial (renameAlpha alpha)
        | apply (E.Apply (a, b)) = E.Apply (apply a, apply b)
        | apply (E.Conv (v, alpha, h, beta)) =
            E.Conv (renameParam v, renameAlpha alpha, renameParam h, renameAlpha beta)
        | apply (E.Probe (f, pos)) = E.Probe (apply f, apply pos)
        | apply (E.Img (id, alpha, pos)) =
            E.Img (renameParam id, renameAlpha alpha, List.map apply pos)
        (* leaves that mention neither parameters nor indices *)
        | apply (leaf as E.Const _) = leaf
        | apply (leaf as E.Value _) = leaf
      in
        apply e
      end




(* rewriteParams (params, params2, place)
 * Parameter substitution: the entry of params at position place is replaced
 * by the whole list params2.
 * Returns (params', origId, subId, nbeg) where
 *   params' - the spliced parameter list
 *   origId  - dictionary sending each original id after place to its
 *             position in params'
 *   subId   - dictionary sending position k of params2 to place + k
 *   nbeg    - number of entries kept in front of the splice
 * List.take / List.drop raise Subscript when place is not a valid position
 * of params.
 *)
fun rewriteParams (params, params2, place) = let
      val front = List.take (params, place)
      val back = List.drop (params, place + 1)
      (* dictionary with the entries (k + keyOff) -> (k + valOff),
       * inserted for k = count down to 1
       *)
      fun shiftDict (count, keyOff, valOff) = let
            fun loop (0, dict) = dict
              | loop (k, dict) = loop (k - 1, insert (k + keyOff, k + valOff) dict)
            in
              loop (count, empty)
            end
      val nSub = length params2
      val origId = shiftDict (length back, place, place + nSub - 1)
      val subId = shiftDict (nSub, ~1, place - 1)
      in
        (front @ params2 @ back, origId, subId, length front)
      end


(* splitEin: unpack an EIN operator into the triple (params, index, body). *)
fun splitEin (E.EIN {params = ps, index = ix, body = b}) = (ps, ix, b)

(* app (ein, place, e2)
 * Substitutes the EIN operator e2 for the parameter at position place of
 * ein.  Every Tensor or Field term of the body whose parameter id equals
 * place is replaced by the body of e2 as rewritten by rewriteSubst; the
 * remaining parameter ids are renumbered for the spliced parameter list
 * computed by rewriteParams.
 * Returns (changed, ein') where changed is 1 when at least one term was
 * replaced and 0 otherwise.
 * Raises Fail "Wrong size for Subst" when a matching term does not carry as
 * many indices as the index list of e2.
 *)
fun app (E.EIN {params, index, body}, place, e2) = let
      val changed = ref 0
      val (argParams, argIndex, argBody) = splitEin e2
      val (params', origId, substId, paramShift) = rewriteParams (params, argParams, place)
      (* Starts at the number of indices of ein and is overwritten on entry to
       * every Sum; it is not restored afterwards, so terms visited later see
       * the most recent value.
       * NOTE(review): confirm that carrying the value over is intended.
       *)
      val sumIndex = ref (length index)
      (* renumber an id that belongs to the original parameter list *)
      fun remap id = mapId (id, origId, 0)
      (* replace a term that refers to the substituted parameter *)
      fun substitute mx =
            if length mx = length argIndex
              then (
                changed := 1;
                rewriteSubst (argBody, substId, mx, paramShift, !sumIndex))
              else raise Fail "Wrong size for Subst"
      (* Index variable of the last summation entry.  List.nth raises
       * Subscript on an empty list and the pattern binding raises Bind when
       * that entry's index is not of the form E.V _.
       *)
      fun lastSumVar sx = let
            val (E.V v, _, _) = List.nth (sx, length sx - 1)
            in
              v
            end
      fun apply (E.Tensor (id, mx)) =
            if id = place then substitute mx else E.Tensor (remap id, mx)
        | apply (E.Field (id, mx)) =
            if id = place then substitute mx else E.Field (remap id, mx)
        | apply (E.Krn (id, deltas, pos)) = E.Krn (remap id, deltas, apply pos)
        | apply (E.Sum (sx, inner)) = (
            sumIndex := lastSumVar sx;
            E.Sum (sx, apply inner))
        | apply (E.Neg a) = E.Neg (apply a)
        | apply (E.Add es) = E.Add (List.map apply es)
        | apply (E.Sub (a, b)) = E.Sub (apply a, apply b)
        | apply (E.Prod es) = E.Prod (List.map apply es)
        | apply (E.Div (a, b)) = E.Div (apply a, apply b)
        | apply (E.Apply (a, b)) = E.Apply (apply a, apply b)
        | apply (E.Conv (v, mx, h, ux)) = E.Conv (remap v, mx, remap h, ux)
        | apply (E.Probe (f, pos)) = E.Probe (apply f, apply pos)
        | apply (E.Img (id, mx, pos)) = E.Img (remap id, mx, List.map apply pos)
        (* terms without a parameter id are returned as they are *)
        | apply (leaf as E.Const _) = leaf
        | apply (leaf as E.Delta _) = leaf
        | apply (leaf as E.Value _) = leaf
        | apply (leaf as E.Epsilon _) = leaf
        | apply (leaf as E.Partial _) = leaf
      val body' = apply body
      in
        (!changed, E.EIN {params = params', index = index, body = body'})
      end

         
  end; (* local *)

    end (* structure App *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0