Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-il/app-ein.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-il/app-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2603 - (view) (download)

1 : cchiw 2507 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure App = struct
8 :    
9 :     local
10 :    
11 :     structure E = Ein
12 :     structure P = Printer
13 :     in
14 :    
15 :    
16 :    
17 :    
18 :    
19 :     fun insert (key, value) d =fn s =>
20 : cchiw 2510 if s = key then SOME value
21 :     else d s
22 : cchiw 2507
23 :     fun lookup k d = d k
24 :     val empty =fn key =>NONE
25 :    
26 :     fun mapId(i ,dict,shift)= let
27 :     val l =lookup i dict
28 :     in (case l
29 :     of NONE =>i+shift
30 :     | SOME(j)=>j
31 :     (*end case*))
32 :     end
33 :    
34 : cchiw 2508 fun mapIndex((E.V i) ,dict,shift)= let
35 :     val l =lookup (E.V i) dict
36 :     in (case l
37 :     of NONE =>E.V(i+shift)
38 :     | SOME(j)=>j
39 :     (*end case*))
40 :     end
41 : cchiw 2507
42 : cchiw 2508
43 : cchiw 2507 fun mapId2(i ,dict,shift)= let
44 : cchiw 2515 val l =lookup i dict
45 :     in (case l
46 :     of NONE =>(print "Err out of range";i+shift)
47 :     | SOME(j)=>j
48 :     (*end case*))
49 :     end
50 : cchiw 2507
51 :    
52 :    
53 :    
54 :    
55 : cchiw 2515 fun rewriteSubst(e,subId,mx,paramShift,sumShift)=let
56 : cchiw 2507 fun insertIndex([],_,dict,shift)=(dict,shift)
57 : cchiw 2508 | insertIndex(E.V e::es, n,dict,_)= insertIndex(es, n+1, insert(E.V n ,E.V e) dict,e-n)
58 :     | insertIndex(E.C e::es, n,dict,_)= insertIndex(es, n+1, insert(E.V n ,E.C e) dict,e-n)
59 :    
60 : cchiw 2507 val (subMu,shift)=insertIndex(mx,0,empty,0)
61 :     val shift'=Int.max(sumShift, shift)
62 : cchiw 2508 fun mapMu(E.V i)= mapIndex((E.V i), subMu, shift')
63 :     | mapMu c = c
64 : cchiw 2507 fun mapAlpha mx=List.map mapMu mx
65 : cchiw 2508 fun mapSingle(i)=let val E.V v=mapIndex(E.V i,subMu, shift')
66 :     in v end
67 : cchiw 2507 fun mapSum []=[]
68 :     | mapSum ((a,b,c)::e)=[((mapMu a),b,c)]@mapSum(e)
69 : cchiw 2508 fun mapParam(id)= mapId2(id, subId, 0)
70 : cchiw 2507 fun apply e=(case e
71 :     of Ein.Const _ => e
72 :     | Ein.Tensor(id, mx) => Ein.Tensor(mapParam id,mapAlpha mx)
73 :     | Ein.Field(id, mx) => Ein.Field(mapParam id,mapAlpha mx)
74 :     | Ein.Krn(id,deltas,pos)=> Ein.Krn(mapParam id, deltas,apply pos)
75 :     | Ein.Delta (i,j) => Ein.Delta(mapMu i,mapMu j)
76 :     | Ein.Value _=> e
77 : cchiw 2508 | Ein.Epsilon(i, j, k) =>Ein.Epsilon(mapSingle i, mapSingle j, mapSingle k)
78 : cchiw 2507 | Ein.Sum(c,esum)=> Ein.Sum(mapSum c, apply esum)
79 :     | Ein.Neg e => Ein.Neg(apply e)
80 : cchiw 2603 | Ein.Lift e => Ein.Lift(apply e)
81 : cchiw 2507 | Ein.Add es => Ein.Add(List.map apply es)
82 :     | Ein.Sub(e1, e2) => Ein.Sub(apply e1, apply e2)
83 :     | Ein.Prod es => Ein.Prod(List.map apply es)
84 :     | Ein.Div(e1, e2)=> Ein.Div(apply e1, apply e2)
85 :     | Ein.Partial mx => E.Partial (mapAlpha mx)
86 :     | Ein.Apply(e1, e2)=> Ein.Apply(apply e1, apply e2)
87 : cchiw 2510 | Ein.Conv (v,mx,h,ux) =>(Ein.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux))
88 : cchiw 2507 | Ein.Probe(f, pos) => Ein.Probe(apply f, apply pos)
89 :     | Ein.Img(id,mx,pos)=> Ein.Img(mapParam id,mapAlpha mx, (List.map apply pos))
90 :     (*end case*))
91 :     in apply e end
92 :    
93 :    
94 :    
95 :    
96 :     (*params subst*)
97 :     fun rewriteParams(params, params2, place)=let
98 :     val beg=List.take(params,place)
99 :     val next=List.drop(params,place+1)
100 :     val params'=beg@params2@next
101 :     val n= length(params)
102 :     val n2=length(params2)
103 :     val nbeg=length(beg)
104 :     val nnext=length(next)
105 :    
106 :     fun createDict(0,shift1, shift2,dict)= dict
107 : cchiw 2515 | createDict(n,shift1, shift2,dict)=createDict(n-1,shift1,shift2, insert(n+shift1,n+shift2) dict)
108 : cchiw 2507
109 :     val origId=createDict(nnext,place,place+n2-1,empty)
110 :     val subId=createDict(n2,~1,place-1,empty)
111 :    
112 :     in (params',origId,subId,nbeg) end
113 :    
114 :    
115 :     fun splitEin(Ein.EIN{params, index, body})=(params,index,body)
116 :    
117 :     (*Looks for params id that match substitution*)
118 :     fun app(Ein.EIN{params, index, body},place,e2)=let
119 :    
120 : cchiw 2515 val changed = ref 0
121 :    
122 : cchiw 2507 val (params2,index2,body2)=splitEin(e2)
123 :     val (params',origId,substId,paramShift)=rewriteParams(params,params2,place)
124 :     val err="Wrong size for Subst"
125 :    
126 :     val sumIndex=ref (length index)
127 :     fun rewrite(id,mx ,e)=let
128 :     val ref x=sumIndex
129 :     in
130 :     if(id=place) then
131 : cchiw 2515 if(length(mx)=length(index2)) then
132 :     (changed:=1; rewriteSubst(body2,substId,mx,paramShift,x))
133 : cchiw 2553 else ( raise Fail(err);E.Const 0)
134 : cchiw 2507 else (case e
135 :     of E.Tensor(id,mx)=>E.Tensor(mapId(id,origId,0), mx)
136 :     | E.Field(id,mx)=> E.Field(mapId(id,origId,0), mx)
137 :     (*end case*))
138 :     end
139 :     fun sumI(e)=let
140 :     val (E.V v,_,_)=List.nth(e, length(e)-1)
141 :     in v end
142 :    
143 :     fun apply e=(case e
144 :     of Ein.Const _ => e
145 :     | Ein.Tensor(id, mx) =>rewrite (id,mx,e)
146 :     | Ein.Field(id, mx) => rewrite (id,mx,e)
147 :     | Ein.Krn(id,deltas,pos)=> Ein.Krn(mapId(id,origId,0), deltas,apply pos)
148 :     | Ein.Delta _ => e
149 :     | Ein.Value _=> e
150 :     | Ein.Epsilon(i, j, k) => e
151 :     | Ein.Sum(c,esum)=> (sumIndex:=sumI(c); Ein.Sum( c, apply esum))
152 : cchiw 2603 | Ein.Lift e => Ein.Lift(apply e)
153 : cchiw 2507 | Ein.Neg e => Ein.Neg(apply e)
154 :     | Ein.Add es => Ein.Add(List.map apply es)
155 :     | Ein.Sub(e1, e2) => Ein.Sub(apply e1, apply e2)
156 :     | Ein.Prod es => Ein.Prod(List.map apply es)
157 :     | Ein.Div(e1, e2)=> Ein.Div(apply e1, apply e2)
158 :     | Ein.Partial mx => e
159 :     | Ein.Apply(e1, e2)=> Ein.Apply(apply e1, apply e2)
160 :     | Ein.Conv (v,mx,h,ux) =>Ein.Conv(mapId(v, origId,0), mx, mapId(h,origId,0), ux)
161 : cchiw 2521 | Ein.Probe(f, pos) => Ein.Probe(apply f, apply pos)
162 : cchiw 2507 | Ein.Img(id,mx,pos)=> Ein.Img(mapId(id,origId,0),mx, (List.map apply pos))
163 :     (*end case*))
164 :     val body''=apply body
165 : cchiw 2515 val ref g=changed
166 : cchiw 2507 in
167 : cchiw 2515 ( g,Ein.EIN{params=params', index=index, body=body''})
168 : cchiw 2507 end
169 :    
170 :    
171 :     end; (* local *)
172 :    
173 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0