35 |
structure SrcV = SrcIL.Var |
structure SrcV = SrcIL.Var |
36 |
structure P=Printer |
structure P=Printer |
37 |
structure shiftHtM=shiftHtM |
structure shiftHtM=shiftHtM |
38 |
|
structure split=splitHtM |
39 |
|
|
40 |
|
|
41 |
datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int |
datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int |
44 |
|
|
45 |
|
|
46 |
fun assign (x, rator, args) = (x, DstIL.OP(rator, args)) |
fun assign (x, rator, args) = (x, DstIL.OP(rator, args)) |
47 |
fun assignEin (x, rator, args) = (print "assignment";(x, DstIL.EINAPP(rator, args))) |
fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args))) |
48 |
|
|
49 |
fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args)) |
fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args)) |
50 |
fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args)) |
fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args)) |
95 |
|
|
96 |
|
|
97 |
(*Create fractional, and integer position vectors*) |
(*Create fractional, and integer position vectors*) |
98 |
fun createArgs2 (dim,v,posx)=let |
fun createArgs (dim,v,posx)=let |
99 |
val zz=print(String.concat["\n XYZ DIM-createArgs2-",Int.toString(dim),"-"]) |
|
100 |
val translate=DstOp.Translate v |
val translate=DstOp.Translate v |
101 |
val transform=DstOp.Transform v |
val transform=DstOp.Transform v |
102 |
val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*) |
val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*) |
118 |
assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*) |
assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*) |
119 |
] |
] |
120 |
|
|
121 |
in ([f,n],code) |
in ([n,f],code) |
122 |
end |
end |
123 |
|
|
124 |
|
|
|
fun Position(img,t,newposArg,dim,pos1,ppos)=let |
|
|
val zz=print(String.concat["\n XYZ DIM-Position-",Int.toString(dim),"-"]) |
|
|
val (args',code')=createArgs2(dim,img,newposArg) |
|
|
val pos2=pos1+1 |
|
|
in (pos1,pos2, ppos,args',code') |
|
|
end |
|
125 |
|
|
126 |
|
|
127 |
|
|
133 |
| createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim) |
| createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim) |
134 |
|
|
135 |
|
|
136 |
fun replaceH(kvar, place,args)=let |
|
|
val l1=List.take(args, place) |
|
|
val l2=List.drop(args,place+1) |
|
|
in l1@[kvar]@l2 end |
|
137 |
|
|
138 |
(*Get Img, and Kern Args*) |
(*Get Img, and Kern Args*) |
139 |
fun getArgs(hid,hArg,V,imgArg,args)= |
fun getArgs(hid,hArg,V,imgArg,args)= |
141 |
of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let |
of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let |
142 |
val hvar=DstV.new ("KNL", DstTy.KernelTy) |
val hvar=DstV.new ("KNL", DstTy.KernelTy) |
143 |
val imgvar=DstV.new ("IMG", DstTy.ImageTy img) |
val imgvar=DstV.new ("IMG", DstTy.ImageTy img) |
144 |
|
fun replaceH(kvar, place,args)=let |
145 |
|
val l1=List.take(args, place) |
146 |
|
val l2=List.drop(args,place+1) |
147 |
|
in l1@[kvar]@l2 end |
148 |
|
|
149 |
val args1=replaceH(hvar, hid,args) |
val args1=replaceH(hvar, hid,args) |
150 |
val args2=replaceH(imgvar, V,args1) |
val args2=replaceH(imgvar, V,args1) |
151 |
in |
in |
155 |
| _ => raise Fail "Not a kernel argument" |
| _ => raise Fail "Not a kernel argument" |
156 |
|
|
157 |
|
|
158 |
|
(*Get Img, and Krn Args*) |
159 |
|
fun getArgs3(hid,hArg,V,imgArg,args)= |
160 |
|
case (getRHS2 hArg,getRHS2 imgArg) |
161 |
|
of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let |
162 |
|
val hvar=DstV.new ("KNL", DstTy.KernelTy) |
163 |
|
val imgvar=DstV.new ("IMG", DstTy.ImageTy img) |
164 |
|
in |
165 |
|
(Kernel.support h ,img, [assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])],[imgvar, hvar]) |
166 |
|
end |
167 |
|
| ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument" |
168 |
|
| _ => raise Fail "Not a kernel argument" |
169 |
|
|
170 |
|
|
|
fun expandEinProbe(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),index, sx,origargs)=let |
|
|
|
|
|
|
|
|
val E.IMG(dim)=List.nth(params,V) |
|
|
val zz=print(String.concat["\n XYZ DIM-expanEin-",Int.toString(dim),"-"]) |
|
|
|
|
|
val kArg=List.nth(origargs,h) |
|
|
val imgArg=List.nth(origargs,V) |
|
|
val (s,img,argcode,args2) =getArgs(h,kArg,V,imgArg,args) |
|
|
val newposArg=List.nth(args, t) |
|
|
|
|
|
val ppos=[E.TEN(1,[dim]),E.TEN(3,[dim])] |
|
171 |
|
|
|
val (fid,nid,params',args',code')=Position(img,t,newposArg,dim,(length params),ppos) |
|
|
val shift=(length index) |
|
172 |
|
|
173 |
(* val z=print(String.concat["\n SHIFt SET To",Int.toString(shift),"SX IS ", Int.toString(sx)])*) |
fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let |
174 |
|
|
175 |
(*sumIndex creating summaiton Index for body*) |
(*sumIndex creating summaiton Index for body*) |
176 |
fun sumIndex(0)=[] |
fun sumIndex(0)=[] |
183 |
val dim'=dim-1 |
val dim'=dim-1 |
184 |
val sum=sx+dim' |
val sum=sx+dim' |
185 |
val dels=createDels(deltas,dim') |
val dels=createDels(deltas,dim') |
|
(*val L=print "\n creatWith " |
|
|
val LL=print(Int.toString(dim'))*) |
|
186 |
val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]] |
val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]] |
187 |
val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum))) |
val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum))) |
|
|
|
188 |
in |
in |
189 |
createKRN(dim',pos@imgpos,[rest']@rest) |
createKRN(dim',pos@imgpos,[rest']@rest) |
190 |
end |
end |
191 |
|
|
192 |
val exp=createKRN(dim, [],[]) |
val exp=createKRN(dim, [],[]) |
193 |
val esum=sumIndex (dim) |
val esum=sumIndex (dim) |
194 |
in (params@params',E.Sum(esum, exp),args2@args',argcode@code') end |
in E.Sum(esum, exp) |
195 |
|
end |
196 |
|
|
197 |
|
(* Expand probe in place *) |
198 |
|
fun expandEinProbe4(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),index, sx,origargs)=let |
199 |
|
|
200 |
|
|
201 |
|
val E.IMG(dim)=List.nth(params,V) |
202 |
|
val kArg=List.nth(origargs,h) |
203 |
|
val imgArg=List.nth(origargs,V) |
204 |
|
val newposArg=List.nth(args, t) |
205 |
|
|
206 |
|
val (s,img,argcode,args2) =getArgs(h,kArg,V,imgArg,args) |
207 |
|
val params'=[E.TEN(3,[dim]),E.TEN(1,[dim])] |
208 |
|
val (args',code')=createArgs(dim,img,newposArg) |
209 |
|
val fid=length params |
210 |
|
val nid=fid+1 |
211 |
|
|
212 |
|
val body=createBody(dim, s,sx,shape,deltas,V, h, nid, fid) |
213 |
|
in (params@params',body,args2@args',argcode@code') end |
214 |
|
|
215 |
|
(*Lift Probe*) |
216 |
|
fun expandEinProbe3(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),index, sx,origargs)=let |
217 |
|
|
218 |
|
|
219 |
|
val E.IMG(dim)=List.nth(params,V) |
220 |
|
val kArg=List.nth(origargs,h) |
221 |
|
val imgArg=List.nth(origargs,V) |
222 |
|
val newposArg=List.nth(args, t) |
223 |
|
|
224 |
|
val (s,img,argcode,argsVH) =getArgs3(h,kArg,V,imgArg,args) |
225 |
|
val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])] |
226 |
|
val (argsT,code')=createArgs(dim,img,newposArg) |
227 |
|
|
228 |
|
val h=1 |
229 |
|
val fid=2 |
230 |
|
val nid=3 |
231 |
|
|
232 |
|
val body=createBody(dim, s,sx,shape,deltas,0, h, nid, fid) |
233 |
|
in (params',body,argsVH@argsT,argcode@code') end |
234 |
|
|
235 |
|
|
236 |
|
|
237 |
fun TS x = case SrcIL.Var.binding x |
fun TS x = case SrcIL.Var.binding x |
247 |
|
|
248 |
|
|
249 |
|
|
250 |
|
|
251 |
fun ShapeConv([],n)=[] |
fun ShapeConv([],n)=[] |
252 |
| ShapeConv(E.C c::es, n)=ShapeConv(es, n) |
| ShapeConv(E.C c::es, n)=ShapeConv(es, n) |
253 |
| ShapeConv(E.V v::es, n)= |
| ShapeConv(E.V v::es, n)= |
255 |
else ShapeConv(es,n) |
else ShapeConv(es,n) |
256 |
|
|
257 |
|
|
258 |
|
(*Lift probe*) |
259 |
fun foundProbe(b,(params,args),index, sumIndex,origargs)=let |
fun foundProbe3(b,(params,args),index, sumIndex,origargs)=let |
260 |
|
val _=print "\n Lift Probe\n" |
261 |
val E.Probe(E.Conv(_,alpha,_,dx),pos)=b |
val E.Probe(E.Conv(_,alpha,_,dx),pos)=b |
262 |
|
|
263 |
val newId=length(params) |
val newId=length(params) |
275 |
val newArg = DstV.new ("PC", DstTy.tensorTy shape') |
val newArg = DstV.new ("PC", DstTy.tensorTy shape') |
276 |
|
|
277 |
|
|
|
|
|
278 |
(*Expand Probe*) |
(*Expand Probe*) |
279 |
val ns=length sumIndex |
val ns=length sumIndex |
280 |
val (params',body',args',code) =(case ns |
val (params',body',args',code) =(case ns |
281 |
of 0=> expandEinProbe(b,(params,args),index, n,origargs) |
of 0=> expandEinProbe3(b,(params,args),index, n,origargs) |
282 |
|_=>let |
|_=>let |
283 |
val (E.V v,_,_)=List.nth(sumIndex, ns-1) |
val (E.V v,_,_)=List.nth(sumIndex, ns-1) |
284 |
val (p,body',a,c)= expandEinProbe(b,(params,args),index, v+1,origargs) |
val (p,body',a,c)= expandEinProbe3(b,(params,args),index, v+1,origargs) |
285 |
in (p,E.Sum(sumIndex ,body'),a,c) |
in (p,E.Sum(sumIndex ,body'),a,c) |
286 |
end |
end |
287 |
(* end case *)) |
(* end case *)) |
289 |
|
|
290 |
val (p',i',b',a')=shiftHtM.clean(params', index,body', args') |
val (p',i',b',a')=shiftHtM.clean(params', index,body', args') |
291 |
val newbie'=Ein.EIN{params=p', index=i', body=b'} |
val newbie'=Ein.EIN{params=p', index=i', body=b'} |
292 |
val zz=print(P.printerE(newbie')) |
val _ = print(String.concat["\n ", split.printA(newArg, newbie', a'),"\n"]) |
293 |
val data=assignEin (newArg, newbie', a') |
val data=assignEin (newArg, newbie', a') |
|
val gg=print(String.concat["\n\n ---",P.printerE(newbie'),"\n\n"]) |
|
294 |
in (newB, (params@[newP],args@[newArg]) ,code@[data]) |
in (newB, (params@[newP],args@[newArg]) ,code@[data]) |
295 |
end |
end |
296 |
|
|
297 |
|
|
298 |
|
(* Expand probe in place *) |
299 |
|
fun foundProbe4(b,(params,args),index, sumIndex,origargs)=let |
300 |
|
val _=print "\n Don't replace probe \n" |
301 |
|
val E.Probe(E.Conv(_,alpha,_,dx),pos)=b |
302 |
|
|
303 |
|
val newId=length(params) |
304 |
|
val n=length(index) |
305 |
|
|
306 |
(*copied from high-to-mid.sml*) |
(*Expand Probe*) |
307 |
|
val ns=length sumIndex |
308 |
|
val (params',body',args',code) =(case ns |
309 |
|
of 0=> expandEinProbe4(b,(params,args),index, n,origargs) |
310 |
|
|_=>let |
311 |
|
val (E.V v,_,_)=List.nth(sumIndex, ns-1) |
312 |
|
(* val _ =print(String.concat["\n Found sum sx, last element is ",Int.toString(v), "\n"])*) |
313 |
|
in expandEinProbe4(b,(params,args),index, v+1,origargs) |
314 |
|
end |
315 |
|
(* end case *)) |
316 |
|
|
317 |
|
val UUU=Ein.EIN{params=params', index=index, body=body'} |
318 |
|
val _ =print(String.concat["\n $$$ new sub-expression $$$ \n",P.printerE(UUU),"\n"]) |
319 |
|
|
320 |
|
in (body',(params',args') ,code) |
321 |
|
end |
322 |
|
|
323 |
|
|
324 |
|
fun flatten []=[] |
325 |
|
| flatten(e1::es)=e1@(flatten es) |
326 |
|
|
327 |
|
|
328 |
|
(* sx-[] then move out, otherwise keep in *) |
329 |
fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let |
fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let |
330 |
|
|
331 |
val dummy=E.Const 0 |
val dummy=E.Const 0 |
333 |
|
|
334 |
(*b-current body, info-original ein op, data-new assigments*) |
(*b-current body, info-original ein op, data-new assigments*) |
335 |
fun rewriteBody(b,info)= let |
fun rewriteBody(b,info)= let |
336 |
|
val t=print(String.concat["\n\n ***:",Printer.printbody(b),"$ \n"]) |
337 |
in (case b |
in (case b |
338 |
of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) =>let |
of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) =>let |
339 |
val ref sx=sumIndex |
val ref sx=sumIndex |
340 |
val (t',info',data)=foundProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, sx@c,origargs) |
in (case sx |
341 |
in (t',info',data) |
of [] => foundProbe3(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs) |
342 |
|
| _ => foundProbe4(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs) |
343 |
|
(* end case*)) |
344 |
end |
end |
345 |
| E.Probe(E.Conv _, E.Tensor _) =>let |
| E.Probe(E.Conv _, E.Tensor _) =>let |
346 |
val ref sx=sumIndex |
val ref sx=sumIndex |
347 |
val (t',info',data)=foundProbe(b, info,index, sx,origargs) |
in (case sx |
348 |
in (t',info',data) |
of []=> foundProbe3(b, info,index, [],origargs) |
349 |
|
| _=> foundProbe4(b, info,index, flatten sx,origargs) |
350 |
|
(* end case*)) |
351 |
end |
end |
352 |
| E.Probe _=> (dummy,info,[]) |
| E.Probe _=> (dummy,info,[]) |
353 |
| E.Conv _=> (dummy,info,[]) |
| E.Conv _=> (dummy,info,[]) |
359 |
(E.Neg(body'),info',data') |
(E.Neg(body'),info',data') |
360 |
end |
end |
361 |
| E.Sum (c,e)=> let |
| E.Sum (c,e)=> let |
362 |
val ref sx=sumIndex |
val ref x=sumIndex |
363 |
val s=sumIndex:=(sx@c) |
val c'=[c]@x |
364 |
val (body',info',data')=rewriteBody (e,info) |
val (body',info',data')=(sumIndex:=c';rewriteBody (e,info)) |
365 |
|
val ref s=sumIndex |
366 |
|
val z=hd(s) |
367 |
in |
in |
368 |
(E.Sum(c, body'),info',data') |
(sumIndex:=tl(s);(E.Sum(z, body'),info',data')) |
369 |
end |
end |
370 |
| E.Sub(a,b)=>let |
| E.Sub(a,b)=>let |
371 |
val (bodyA,infoA,dataA)= rewriteBody(a,info) |
val (bodyA,infoA,dataA)= rewriteBody(a,info) |
394 |
end |
end |
395 |
|
|
396 |
val empty =fn key =>NONE |
val empty =fn key =>NONE |
397 |
val mm=print "Starting Exapnd" |
val mm=print "\n ************************** \n Starting Exapnd" |
398 |
val (body',(params',args'),newbies)=rewriteBody(body,(params,args)) |
val (body',(params',args'),newbies)=rewriteBody(body,(params,args)) |
399 |
val e'=Ein.EIN{params=params', index=index, body=body'} |
val e'=Ein.EIN{params=params', index=index, body=body'} |
400 |
val rr=print (String.concat[P.printbody(body),"DONE expand"]) |
|
401 |
|
val rr=print (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "]) |
402 |
in ((e',args'),newbies) end |
in ((e',args'),newbies) end |
403 |
|
|
404 |
end; (* local *) |
end; (* local *) |