Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2844 - (view) (download)
Original Path: branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

1 : cchiw 2608 (* Currently under construction
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
structure ProbeEin = struct

    local

    structure E = Ein
    structure mk = mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P = Printer
    structure T = TransformEin
    structure split = Split
    structure cleanI = cleanIndex

  (* Debugging switch: when nonzero, testp prints its arguments to stdout.
   * NOTE(review): currently 1, so getsumshift's diagnostics are always printed;
   * consider 0 for non-debug builds.
   *)
    val testing = 1

    in

  (* This file expands probed fields.
   * Note that the original field is an EIN operator of the form
   * <V_alpha * H^(deltas)>(midIL.var list).
   * Param_ids note the position of an argument in the midIL.var list.
   * Index_ids keep track of the shape of an image or of a differentiation.
   * Mu binds an Index_id.
   * Generally, we refer to the following:
   *   dim    : dimension of field V
   *   s      : support of kernel H
   *   alpha  : the alpha in <V_alpha * H^(deltas)>
   *   deltas : the deltas in <V_alpha * H^(deltas)>
   *   Vid    : param_id for V
   *   hid    : param_id for H
   *   nid    : param_id for the integer position
   *   fid    : param_id for the fractional position
   *   img    : image info about V
   *)

  (* counter used by genName to make generated names unique *)
    val cnt = ref 0

  (* genName prefix
   * returns prefix ^ "_" ^ n, where n is the current counter value, and
   * increments the counter.
   *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
          end

  (* thin wrappers around TransformEin *)
    fun transformToIndexSpace e = T.transformToIndexSpace e
    fun transformToImgSpace e = T.transformToImgSpace e

  (* build Mid-IL assignments: x = rator(args) and x = EIN-application(args) *)
    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
    fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))

  (* testp : string list -> int
   * debugging print: prints the concatenation of the strings unless
   * testing = 0.  Always returns 1.
   *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

  (* getRHSDst x
   * returns the (operator, arguments) pair of the operator application bound
   * to the Mid-IL variable x, following variable-to-variable copies.
   * Raises Fail if x is not (eventually) bound to an operator application.
   *)
    fun getRHSDst x = (case DstIL.Var.binding x
           of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
            | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
            | vb => raise Fail(concat[
                  "expected rhs operator for ", DstIL.Var.toString x,
                  " but found ", DstIL.vbToString vb
                ])
          (* end case *))

  (* getArgsDst (hArg, imgArg, _)
   * looks up the definitions of the kernel argument hArg and the image
   * argument imgArg and returns (support of the kernel, image info,
   * dimension of the image).  The third argument is ignored; it is kept only
   * so existing callers need not change.
   * Raises Fail unless hArg is bound to a Kernel and imgArg to a LoadImage.
   *)
    fun getArgsDst (hArg, imgArg, _) = (case (getRHSDst hArg, getRHSDst imgArg)
           of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage img, _)) =>
                (Kernel.support h, img, ImageInfo.dim img)
            | _ => raise Fail "Expected Image and kernel arguments"
          (* end case *))

  (* handleArgs (Vid, hid, tid, args)
   * uses the param_ids of the image (Vid), kernel (hid) and position tensor
   * (tid) to select their Mid-IL variables from args, then transforms the
   * position into image space with TransformEin.transformToImgSpace.
   * Returns (dim, args @ argsT, code, s, P) where
   *   dim   : dimension of the image
   *   argsT : new argument variables produced by the transformation
   *           (presumably the fractional and integer positions -- TODO confirm)
   *   code  : Mid-IL code that computes those new variables
   *   s     : support of the kernel
   *   P     : Mid-IL variable for the transformation matrix (the original notes
   *           say its transpose -- TODO confirm)
   *)
    fun handleArgs (Vid, hid, tid, args) = let
          val imgArg = List.nth(args, Vid)
          val hArg = List.nth(args, hid)
          val newposArg = List.nth(args, tid)
          val (s, img, dim) = getArgsDst(hArg, imgArg, args)
          val (argsT, P, code) = transformToImgSpace(dim, img, newposArg, imgArg)
          in
            (dim, args @ argsT, code, s, P)
          end

  (* createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid)
   * expands the body of the probed field.  For each axis i in 0 .. dim-1 a
   * fresh summation index sx+i ranging over 1-s .. s is introduced, and the
   * result is
   *   Sum[indices] Prod[ Img(Vid, alpha, [fid_i + v_i ...]),
   *                      Krn(hid, (i, deltas), nid_i - v_i) for each axis i ]
   * For dim = 1 the position tensors are scalars and are not indexed.
   *)
    fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
        (* 1-d fields: positions are scalars *)
          fun createKRND1 () = let
                val sum = sx
                val dels = List.map (fn e => (E.C 0, e)) deltas
                val pos = [E.Add[E.Tensor(fid, []), E.Value sum]]
                val rest = E.Krn(hid, dels, E.Sub(E.Tensor(nid, []), E.Value sum))
                in
                  E.Prod[E.Img(Vid, alpha, pos), rest]
                end
        (* createKRN: image field and kernels.  Walks the axes from dim-1 down to
         * 0, prepending each axis' position and kernel so that both lists end up
         * in axis order.
         *)
          fun createKRN (0, imgpos, rest) = E.Prod([E.Img(Vid, alpha, imgpos)] @ rest)
            | createKRN (dim, imgpos, rest) = let
                val dim' = dim - 1
                val sum = sx + dim'
                val dels = List.map (fn e => (E.C dim', e)) deltas
                val pos = [E.Add[E.Tensor(fid, [E.C dim']), E.Value sum]]
                val rest' = E.Krn(hid, dels, E.Sub(E.Tensor(nid, [E.C dim']), E.Value sum))
                in
                  createKRN(dim', pos @ imgpos, rest' :: rest)
                end
          val exp = (case dim
                 of 1 => createKRND1()
                  | _ => createKRN(dim, [], [])
                (* end case *))
        (* summation indices for the body: one per axis, each over 1-s .. s *)
          val slb = 1 - s
          val esum = List.tabulate(dim, fn i => (E.V(i + sx), slb, s))
          in
            E.Sum(esum, exp)
          end

  (* getsumshift (sx, index)
   * returns a fresh (unused) index_id: one more than the id of the last
   * summation index in sx, or length index when sx is empty.
   * Assumes every summation index is an E.V (otherwise Bind or Match is raised).
   * Prints a diagnostic through testp.
   *)
    fun getsumshift (sx, index) = let
          val nsumshift = (case sx
                 of [] => length index
                  | _ => let
                      val (E.V v, _, _) = List.last sx
                      in
                        v + 1
                      end
                (* end case *))
          val aa = List.map (fn (E.V v, _, _) => Int.toString v) sx
          val _ = testp["\n", "SumIndex", String.concatWith "," aa,
                  "\nThink nshift is ", Int.toString nsumshift]
          in
            nsumshift
          end

  (* formBody : ein_exp -> ein_exp
   * quick cleanup rewrite: drops empty summations (recursing through nested
   * summations) and unwraps a single-factor product found at the top or
   * directly under summations.  Does not descend any further.
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Prod[e]) = e
      | formBody e = e

  (* replaceProbe (b, params, args, index, sx)
   * b must be a probe of a convolution at a tensor position:
   *   Probe(Conv(Vid, alpha, hid, dx), Tensor(tid, _))
   * sx are the summation indices enclosing the probe in the original body.
   *  - transforms the position into image space (handleArgs),
   *  - transforms the derivative indices dx back to index space using the
   *    transformation matrix (param Pid),
   *  - builds the expanded body (createBody) and tidies it (formBody).
   * Three params are appended: fid, nid and Pid, numbered from length params.
   * This assumes params and args have the same length and that handleArgs
   * contributes exactly two new args -- TODO confirm.
   * Returns (body', params', args', code) where code computes the new args.
   * Raises Fail if b is not of the expected form.
   *)
    fun replaceProbe (b, params, args, index, sx) = let
          val (Vid, alpha, hid, dx, tid) = (case b
                 of E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) =>
                      (Vid, alpha, hid, dx, tid)
                  | _ => raise Fail "replaceProbe: expected a probe of a convolution at a tensor position"
                (* end case *))
          val fid = length params
          val nid = fid + 1
          val Pid = nid + 1
        (* number of derivative indices on the original probe *)
          val nshift = length dx
          val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, index)
          val (dx', newsx1, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
          val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
          val expanded = createBody(dim, s, freshIndex + nshift, alpha, dx', Vid, hid, nid, fid)
          val body' = formBody(E.Sum(newsx1, E.Prod(Ps @ [expanded])))
          val args' = argsA @ [PArg]
          in
            (body', params', args', code)
          end

  (* expandEinOp (y, EINAPP(EIN{params, index, body}, args))
   * If body is a probe, a summation of a probe, or a summation of a
   * two-factor product whose second factor is a probe, the probe is replaced
   * by its expansion.  The result is the code the expansion needs, followed
   * by the rewritten assignment to y.  The enclosing summation and first
   * factor (e.g., an epsilon) are kept as-is, so less work is left for the
   * mid-to-low-IL stage.  Any other body is returned unchanged as [e].
   * Only EINAPP assignments are accepted (other forms raise Match).
   *)
    fun expandEinOp (e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
        (* expand `probe` (enclosed by summation indices sx) and rebuild the
         * assignment, applying `wrap` to the expanded body
         *)
          fun expand (probe, sx, wrap) = let
                val (body', params', args', newbies) = replaceProbe(probe, params, args, index, sx)
                val einapp = (y, DstIL.EINAPP(Ein.EIN{params=params', index=index, body=wrap body'}, args'))
                in
                  newbies @ [einapp]
                end
          in
            case body
             of E.Probe _ => expand (body, [], fn b => b)
              | E.Sum(sx, probe as E.Probe _) =>
                  expand (probe, sx, fn b => E.Sum(sx, b))
              | E.Sum(sx, E.Prod[eps, probe as E.Probe _]) =>
                  expand (probe, sx, fn b => E.Sum(sx, E.Prod[eps, b]))
              | _ => [e]
            (* end case *)
          end

    end (* local *)

end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0