Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-opt/apply.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-opt/apply.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3515, Sat Dec 19 03:01:31 2015 UTC revision 3521, Sat Dec 19 16:43:35 2015 UTC
# Line 10  Line 10 
10    
11  structure Apply : sig  structure Apply : sig
12    
13        val apply : Ein.ein * int * Ein.ein -> (bool * Ein.ein)
14    
15    end = struct    end = struct
16    
17      structure E = Ein      structure E = Ein
18    
19      fun insert (key, value) d = fn s =>      structure IMap = IntRedBlackMap
         if s = key then SOME value  
         else d s  
   
     fun lookup k d = d k  
     val empty =fn key =>NONE  
20    
21      fun mapId(i ,dict,shift)=(case (lookup i dict)      fun mapId (i, dict, shift) = (case IMap.find(dict, i)
22          of NONE =>i+shift          of NONE =>i+shift
23          | SOME j=>j          | SOME j=>j
24          (*end case*))          (*end case*))
25    
26      fun mapIndex(v ,dict,shift)= (case (lookup v dict)      fun mapIndex (ix, dict, shift) = (case IMap.find(dict, ix)
27          of NONE =>let val E.V(i)=v in E.V(i+shift) end             of NONE => E.V(ix + shift)
28          | SOME j=>j          | SOME j=>j
29          (*end case*))          (*end case*))
30    
31      fun mapId2(i ,dict,shift)= (case (lookup i dict)      fun mapId2 (i, dict, shift) = (case IMap.find(dict, i)
32          of NONE =>(print "Err out of range";i+shift)             of NONE => (
33                    print(concat["Error: ", Int.toString i, " is out of range\n"]);
34                    i+shift)
35          | SOME j=>j          | SOME j=>j
36          (*end case*))          (*end case*))
37    
38      fun rewriteSubst(e,subId,mx,paramShift,sumShift,newArgs,done)=let      fun rewriteSubst (e, subId, mx, paramShift, sumShift) = let
39          fun insertIndex([],_,dict,shift)=(dict,shift)          fun insertIndex([],_,dict,shift)=(dict,shift)
40          | insertIndex(e1::es, n,dict,_)=(case e1              | insertIndex (e::es, n, dict, _) = let
41              of E.V e=> insertIndex(es, n+1, insert(E.V n ,E.V e) dict,e-n)                  val shift = (case e of E.V ix => ix - n | E.C i => i - n)
42               | E.C e => insertIndex(es, n+1, insert(E.V n ,E.C e) dict,e-n)                  in
43              (*end case*))                    insertIndex(es, n+1, IMap.insert(dict, n, e), shift)
44                    end
45          val (subMu,shift)=insertIndex(mx,0,empty,0)            val (subMu, shift) = insertIndex(mx, 0, IMap.empty, 0)
46          val shift'=Int.max(sumShift, shift)          val shift'=Int.max(sumShift, shift)
47          fun  mapMu(E.V i)= mapIndex((E.V i), subMu, shift')            fun mapMu (E.V i) = mapIndex(i, subMu, shift')
48              | mapMu c = c              | mapMu c = c
49          fun mapAlpha mx=List.map mapMu mx          fun mapAlpha mx=List.map mapMu mx
50          fun mapSingle(i)=let            fun mapSingle i = let
51              val E.V v=mapIndex(E.V i,subMu, shift')                  val E.V v = mapIndex(i, subMu, shift')
52              in              in
53                  v                  v
54              end              end
55          fun mapSum []=[]            fun mapSum l = List.map (fn (a, b, c) => (mapMu a, b, c)) l
56              | mapSum ((a,b,c)::e)=[((mapMu a),b,c)]@mapSum(e)            fun mapParam id = mapId2(id, subId, 0)
         fun mapParam(id)= mapId2(id, subId, 0)  
57          fun apply e=(case e          fun apply e=(case e
58              of E.B _          => e                   of E.Const _ => e
59                      | E.ConstR _ => e
60              | E.Tensor(id, mx)          => E.Tensor(mapParam id,mapAlpha mx)              | E.Tensor(id, mx)          => E.Tensor(mapParam id,mapAlpha mx)
61              | E.G(E.Delta (i,j))        => E.G(E.Delta(mapMu i,mapMu j))                    | E.Delta(i, j) => E.Delta(mapMu i,mapMu j)
62              | E.G(E.Epsilon(i, j, k))   => E.G(E.Epsilon(mapSingle i, mapSingle j, mapSingle k))                    | E.Epsilon(i, j, k) => E.Epsilon(mapSingle i, mapSingle j, mapSingle k)
63              | E.G(E.Eps2(i, j))         => E.G(E.Eps2(mapSingle i, mapSingle j))                    | E.Eps2(i, j) => E.Eps2(mapSingle i, mapSingle j)
64              | E.Field(id, mx)           => E.Field(mapParam id,mapAlpha mx)              | E.Field(id, mx)           => E.Field(mapParam id,mapAlpha mx)
65              | E.Lift e1                 => E.Lift(apply e1)              | E.Lift e1                 => E.Lift(apply e1)
66              | E.Conv (v,mx,h,ux)        => E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)              | E.Conv (v,mx,h,ux)        => E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)
# Line 86  Line 85 
85          val beg=List.take(params,place)          val beg=List.take(params,place)
86          val next=List.drop(params,place+1)          val next=List.drop(params,place+1)
87          val params'=beg@params2@next          val params'=beg@params2@next
88          val n= length(params)            val n= length params
89          val n2=length(params2)            val n2 = length params2
90          val nbeg=length(beg)            val nbeg = length beg
91          val nnext=length(next)            val nnext = length next
92          fun createDict(0,shift1, shift2,dict)= dict          fun createDict(0,shift1, shift2,dict)= dict
93            | createDict(n,shift1, shift2,dict)=createDict(n-1,shift1,shift2, insert(n+shift1,n+shift2) dict)              | createDict (n, shift1, shift2, dict) =
94          val origId=createDict(nnext,place,place+n2-1,empty)                  createDict (n-1, shift1, shift2, IMap.insert (dict, n+shift1, n+shift2))
95          val subId=createDict(n2,~1,place-1,empty)            val origId = createDict (nnext, place, place+n2-1, IMap.empty)
96              val subId = createDict (n2, ~1, place-1, IMap.empty)
97          in          in
98              (params',origId,subId,nbeg)              (params',origId,subId,nbeg)
99          end          end
100    
   
101    (*Looks for params id that match substitution*)    (*Looks for params id that match substitution*)
102      fun apply (E.EIN{params, index, body}, place, e2, newArgs, done) = let      fun apply (E.EIN{params, index, body}, place, e2) = let
103          val changed = ref 0            val E.EIN{params=params2, index=index2, body=body2} = e2
104          val params2=E.params e2            val changed = ref false
         val index2=E.index e2  
         val body2=E.body e2  
105          val (params',origId,substId,paramShift)=rewriteParams(params,params2,place)          val (params',origId,substId,paramShift)=rewriteParams(params,params2,place)
         val err=String.concat["Wrong size for Subst:",  
                 P.printbody body,"-with-",P.printbody body2,"@",Int.toString place]  
   
106          val sumIndex=ref (length index)          val sumIndex=ref (length index)
107          fun rewrite(id,mx ,e)=let          fun rewrite(id,mx ,e)=let
108              val ref x=sumIndex                  val x = !sumIndex
109              in              in
110              if(id=place) then                    if (id = place)
111                  if(length(mx)=length(index2)) then                      then if (length mx = length index2)
112                      (changed:=1; rewriteSubst(body2,substId,mx,paramShift,x,newArgs,done))                        then (
113                  else ( raise Fail(err);E.B(E.Const 0))                          changed := true;
114                            rewriteSubst (body2, substId, mx, paramShift, x))
115                          else raise Fail "argument/parameter mismatch"
116              else (case e              else (case e
117                  of E.Tensor(id,mx) => E.Tensor(mapId(id,origId,0), mx)                  of E.Tensor(id,mx) => E.Tensor(mapId(id,origId,0), mx)
118                  | E.Field(id,mx) =>    E.Field(mapId(id,origId,0), mx)                  | E.Field(id,mx) =>    E.Field(mapId(id,origId,0), mx)
119                  |  _ => raise Fail"Id error:Term to be replaced is not a Tensor or Fields"                          |  _ => raise Fail "term to be replaced is not a Tensor or Fields"
120                  (*end case*))                  (*end case*))
121              end              end
122          fun sumI(e)=let            fun sumI e = let val (E.V v,_,_) = List.last e in v end
             val (E.V v,_,_)=List.nth(e, length(e)-1)  
             in v end  
   
123          fun apply b=(case b          fun apply b=(case b
124              of E.B _                    => b                   of E.Tensor(id, mx) => rewrite (id, mx, b)
             | E.Tensor(id, mx)          => rewrite (id,mx,b)  
             | E.G _                     => b  
125              | E.Field(id, mx)           => rewrite (id,mx,b)              | E.Field(id, mx)           => rewrite (id,mx,b)
126              | E.Lift e1                 => E.Lift(apply e1)              | E.Lift e1                 => E.Lift(apply e1)
127              | E.Conv (v,mx,h,ux)        => E.Conv(mapId(v, origId,0), mx, mapId(h,origId,0), ux)              | E.Conv (v,mx,h,ux)        => E.Conv(mapId(v, origId,0), mx, mapId(h,origId,0), ux)
             | E.Partial mx              => b  
128              | E.Apply(e1, e2)           => E.Apply(apply e1, apply e2)              | E.Apply(e1, e2)           => E.Apply(apply e1, apply e2)
129              | E.Probe(f, pos)           => E.Probe(apply f, apply pos)              | E.Probe(f, pos)           => E.Probe(apply f, apply pos)
130              | E.Value _                 => raise Fail "expression before expand"              | E.Value _                 => raise Fail "expression before expand"
131              | E.Img  _                  => raise Fail "expression before expand"              | E.Img  _                  => raise Fail "expression before expand"
132              | E.Krn _                   => raise Fail "expression before expand"              | E.Krn _                   => raise Fail "expression before expand"
133              | E.Sum(c,esum)             => (sumIndex:=sumI(c); E.Sum(c, apply esum))                    | E.Sum(c, esum) => (
134              | E.Op1(E.PowEmb(sx,n),e1)    => (sumIndex:=sumI(sx);E.Op1(E.PowEmb(sx,n),apply e1))  (* QUESTION: should we flag a change here? *)
135                          sumIndex := sumI c;
136                          E.Sum(c, apply esum))
137                      | E.Op1(E.PowEmb(sx, n), e1) => (
138    (* QUESTION: should we flag a change here? *)
139                          sumIndex := sumI sx;
140                          E.Op1(E.PowEmb(sx, n), apply e1))
141              | E.Op1(op1, e1)            => E.Op1(op1,apply e1)              | E.Op1(op1, e1)            => E.Op1(op1,apply e1)
142              | E.Op2(op2, e1,e2)         => E.Op2(op2,apply e1,apply e2)              | E.Op2(op2, e1,e2)         => E.Op2(op2,apply e1,apply e2)
143              | E.Opn(opn, es)            => E.Opn(opn,List.map apply es)              | E.Opn(opn, es)            => E.Opn(opn,List.map apply es)
144                      | _ => b
145          (*end case*))          (*end case*))
146          val body''=apply body          val body''=apply body
         val ref g=changed  
147          in          in
148              ( g,E.EIN{params=params', index=index, body=body''})  (* QUESTION: can we do the following?
149                if (! changed) then SOME(E.EIN{params=params', index=index, body=body''}) else NONE
150    *)
151                (!changed, E.EIN{params=params', index=index, body=body''})
152          end          end
153    
154      end      end

Legend:
Removed from v.3515  
changed lines
  Added in v.3521

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0