Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2847 - (view) (download)
Original Path: branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure ProbeEin = struct
8 :    
9 :     local
10 :    
11 :     structure E = Ein
12 :     structure DstIL = MidIL
13 :     structure DstOp = MidOps
14 :     structure P=Printer
15 : cchiw 2611 structure T=TransformEin
16 : cchiw 2845 structure MidToS=MidToString
17 : cchiw 2606 in
18 :    
19 : cchiw 2843 (* This file expands probed fields
20 : cchiw 2845 * Take a look at ProbeEin tex file for examples
21 : cchiw 2843 *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
22 :     * Param_ids are used to note the placement of the argument in the midIL.var list
23 : cchiw 2844 * Index_ids keep track of the shape of an Image or differentiation.
24 :     * Mu bind Index_id
25 :     * Generally, we will refer to the following
26 : cchiw 2843 *dim:dimension of field V
27 :     * s: support of kernel H
28 :     * alpha: The alpha in <V_alpha * H^(deltas)>
29 :     * deltas: The deltas in <V_alpha * H^(deltas)>
30 :     * Vid:param_id for V
31 :     * hid:param_id for H
32 :     * nid: integer position param_id
33 :     * fid :fractional position param_id
34 :     * img: image info about V
35 :     *)
36 :    
(* Debug switch read by testp: 0 keeps testp silent, any other value makes it print. *)
val testing : int = 0
(* Mutable integer cell, initialised to 0. *)
val cnt : int ref = ref 0
39 : cchiw 2606
(* Local aliases for the TransformEin (T) entry points used in this file. *)
val transformToIndexSpace = T.transformToIndexSpace
val transformToImgSpace = T.transformToImgSpace
(* testp: string list -> int
 * Debug printer. When `testing` is nonzero, prints the concatenation of the
 * given strings; when it is 0, prints nothing. Always returns 1.
 *)
fun testp msgs = let
      val () = if testing <> 0 then print (String.concat msgs) else ()
      in
        1
      end
(* getRHSDst: returns the (operator, arguments) pair bound to the Mid-IL variable x.
 * If x is bound to another variable, the lookup follows that variable recursively.
 * Raises Fail when the binding is neither an operator application nor a variable.
 *)
fun getRHSDst x = (case DstIL.Var.binding x
       of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[
              (* note the leading space in " but found ": without it the
               * variable name runs straight into the next word *)
              "expected rhs operator for ", DstIL.Var.toString x,
              " but found ", DstIL.vbToString vb
            ])
      (* end case *))
51 : cchiw 2838
52 : cchiw 2606
(* getArgsDst: (kernel var, image var, arg list) -> (support, image, dimension)
 * Looks up the right-hand sides bound to the kernel and image Mid-IL variables
 * and returns the kernel's support, the loaded image value, and the image's
 * dimension.
 * The third component of the argument tuple (args) is not used.
 * Raises Fail unless hArg is bound to a Kernel operator and imgArg to a
 * LoadImage operator.
 *)
fun getArgsDst (hArg, imgArg, args) = (
      case (getRHSDst hArg, getRHSDst imgArg)
       of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage img, _)) =>
            (Kernel.support h, img, ImageInfo.dim img)
        | _ => raise Fail "Expected Image and kernel arguments"
      (* end case *))
65 : cchiw 2606
66 :    
(* handleArgs: (Vid, hid, tid, args) -> (dim, args', code, s, P)
 * Vid, hid and tid are the param_ids (positions in args) of the image, the
 * kernel and the position tensor.
 * Fetches those three Mid-IL variables, reads the kernel support and image
 * dimension via getArgsDst, and passes the position to transformToImgSpace.
 * Returns:
 *   dim   - dimension of the image
 *   args' - the original args followed by the new arguments produced by
 *           transformToImgSpace
 *   code  - the code produced by transformToImgSpace
 *   s     - support of the kernel
 *   P     - second result of transformToImgSpace (per the original note, the
 *           mid-IL var for the transposed transformation matrix)
 * List.nth raises Subscript if any id is out of range for args.
 *)
fun handleArgs (Vid, hid, tid, args) = let
      fun argAt id = List.nth (args, id)
      val image = argAt Vid
      val kernel = argAt hid
      val position = argAt tid
      val (support, info, dim) = getArgsDst (kernel, image, args)
      val (posArgs, P, code) = transformToImgSpace (dim, info, position, image)
      in
        (dim, args @ posArgs, code, support, P)
      end
83 : cchiw 2838
(* createBody: (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) -> ein_exp
 * Builds the expanded body for a probed field:
 *   E.Sum (summation indices, E.Prod (image term :: kernel terms))
 * dim     - dimension of the field
 * s       - support of the kernel
 * sx      - first summation index id to use; axis i uses id sx+i
 * alpha   - indices applied to the image term
 * deltas  - derivative indices applied to each kernel term
 * Vid/hid - param_ids of the image and the kernel
 * nid/fid - param_ids of the two position tensors (used in the kernel term
 *           and the image-position term respectively)
 *)
fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
      (* Terms for one axis: the image-position expression and the kernel
       * factor. ix is the index list used to select the axis component of the
       * fid and nid tensors (empty in the 1-d case).
       *)
      fun axisTerms (axis, ix) = let
            val sumId = E.Value (sx + axis)
            val imgPos = E.Add [E.Tensor (fid, ix), sumId]
            val krnDels = List.map (fn d => (E.C axis, d)) deltas
            val krn = E.Krn (hid, krnDels, E.Sub (E.Tensor (nid, ix), sumId))
            in
              (imgPos, krn)
            end
      (* Walks the axes from k-1 down to 0, prepending each axis's terms, so
       * the finished lists are in ascending axis order.
       *)
      fun build (0, imgPos, krns) = E.Prod (E.Img (Vid, alpha, imgPos) :: krns)
        | build (k, imgPos, krns) = let
            val axis = k - 1
            val (p, h) = axisTerms (axis, [E.C axis])
            in
              build (axis, p :: imgPos, h :: krns)
            end
      (* 1-d fields index the position tensors with an empty index list *)
      val probe =
            if dim = 1
              then let
                val (p, h) = axisTerms (0, [])
                in
                  E.Prod [E.Img (Vid, alpha, [p]), h]
                end
              else build (dim, [], [])
      (* one summation index per axis, each ranging over 1-s .. s *)
      val lower = 1 - s
      val sumIndices = List.tabulate (dim, fn i => (E.V (i + sx), lower, s))
      in
        E.Sum (sumIndices, probe)
      end
118 :    
(* getsumshift: (summation index list, index list) -> int
 * Returns an index id to start numbering new indices from:
 *   - when sx is empty, the length of index;
 *   - otherwise, one more than the id of the last summation index in sx.
 * Also reports the ids in sx and the chosen value through testp (which only
 * prints when testing is nonzero).
 *)
fun getsumshift (sx, index) = let
      val fresh =
            if List.null sx
              then length index
              else let
                val (E.V lastId, _, _) = List.last sx
                in
                  lastId + 1
                end
      val ids = List.map (fn (E.V v, _, _) => Int.toString v) sx
      val _ = testp [
              "\n", "SumIndex", String.concatWith "," ids,
              "\nThink nshift is ", Int.toString fresh
            ]
      in
        fresh
      end
136 : cchiw 2611
(* formBody: ein_exp -> ein_exp
 * Quick clean-up rewrite:
 *   - a summation with no indices is replaced by its (rewritten) body;
 *   - a summation with indices keeps them and has its body rewritten;
 *   - a product of exactly one expression is replaced by that expression
 *     (which is not rewritten further);
 *   - anything else is returned unchanged.
 *)
fun formBody exp = (case exp
       of E.Sum (sx, inner) => let
            val inner' = formBody inner
            in
              if List.null sx then inner' else E.Sum (sx, inner')
            end
        | E.Prod [single] => single
        | _ => exp
      (* end case *))
144 : cchiw 2606
(* replaceProbe: (probe exp, params, args, index, sx)
 *                 -> (body', params', args', code)
 * b must have the form E.Probe (E.Conv (Vid, alpha, hid, dx), E.Tensor (tid, _));
 * any other shape makes the pattern binding below fail.
 * Steps:
 *   1. handleArgs fetches the image/kernel/position arguments and produces the
 *      transformed position arguments plus the code that computes them.
 *   2. getsumshift picks a fresh index id, and transformToIndexSpace rewrites
 *      the derivative indices dx, returning new summation indices and the
 *      factors Ps.
 *   3. createBody builds the expanded probe; its summation ids start after the
 *      fresh id plus the number of original derivative indices.
 *   4. The result is wrapped as Sum (new indices, Prod (Ps @ [probe])) and
 *      simplified by formBody.
 * Three parameters are appended to params, at positions fid, nid and Pid; the
 * argument list is the one returned by handleArgs followed by PArg.
 *)
fun replaceProbe (b, params, args, index, sx) = let
      val E.Probe (E.Conv (Vid, alpha, hid, dx), E.Tensor (tid, _)) = b
      (* param_ids of the three parameters appended below *)
      val fid = length params
      val nid = fid + 1
      val Pid = fid + 2
      val (dim, argsA, code, s, PArg) = handleArgs (Vid, hid, tid, args)
      val freshIndex = getsumshift (sx, index)
      val (dx', newsx, Ps) = transformToIndexSpace (freshIndex, dim, dx, Pid)
      (* the shift uses the length of the ORIGINAL derivative list dx *)
      val probe = createBody (dim, s, freshIndex + length dx, alpha, dx', Vid, hid, nid, fid)
      in
        ( formBody (E.Sum (newsx, E.Prod (Ps @ [probe]))),
          params @ [E.TEN (3, [dim]), E.TEN (1, [dim]), E.TEN (1, [dim, dim])],
          argsA @ [PArg],
          code )
      end
168 : cchiw 2606
(* expandEinOp: (var * rhs) -> (var * rhs) list
 * Looks at the body of an EIN application and, if it contains a probe in one
 * of the recognised positions, replaces the probe with its expanded form.
 * Recognised bodies:
 *   Probe p                       - expanded directly
 *   Sum (sx, Probe p)             - expanded, then re-wrapped in Sum sx
 *   Sum (sx, Prod [eps, Probe p]) - expanded, then re-wrapped with eps kept
 * Any other body is returned unchanged as the singleton list [e].
 * The result is the new code produced by replaceProbe followed by the
 * rewritten EIN application, still bound to y.
 * Raises Fail when the probed expression is still an E.Field.
 * The argument pattern only accepts EINAPP right-hand sides; any other
 * statement raises Match.
 *)
fun expandEinOp (e as (y, DstIL.EINAPP (Ein.EIN{params, index, body}, args))) = let
      (* Expands `probe`, applies `wrap` to the expanded body to restore the
       * surrounding context, and rebuilds the EIN application for y.
       *)
      fun expand (probe, sx, wrap) = let
            val (body', params', args', newbies) =
                  replaceProbe (probe, params, args, index, sx)
            val ein = Ein.EIN{params = params', index = index, body = wrap body'}
            in
              newbies @ [(y, DstIL.EINAPP (ein, args'))]
            end
      in
        case body
         of E.Probe (E.Field _, _) =>
              raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"
          | E.Probe _ => expand (body, [], fn b => b)
          | E.Sum (sx, probe as E.Probe _) =>
              expand (probe, sx, fn b => E.Sum (sx, b))
          | E.Sum (sx, E.Prod [eps, probe as E.Probe _]) =>
              expand (probe, sx, fn b => E.Sum (sx, E.Prod [eps, b]))
          | _ => [e]
        (* end case *)
      end
207 : cchiw 2843
208 : cchiw 2606 end; (* local *)
209 :    
210 : cchiw 2845 end (* structure ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0