Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

# SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

 (* Currently under construction
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(*
During the transition from high-IL to mid-IL, complicated EIN expressions are split into
simpler ones in order to better identify methods for code generation and common
subexpressions. Combining EIN operators in the optimization phase can lead to large and
complicated EIN operators. A general code generator would need to expand every operation to
work on scalars, which could miss the opportunity for vectorization and lead to poor code
generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN
expression then only has one operation working on constants, tensors, deltas, epsilons,
images and kernels.

(1) When the outer EIN operator is one of negation, $+$, $-$, $*$, $/$ or $\sum$, each
    subexpression is analyzed to see whether it needs to be rewritten.

(1a) When a subexpression is a field expression ($\circledast$, $\nabla$) it becomes 0.
     When it is another operation (negation, $+$, $-$, $*$, $/$, $\sum$, ...) we lift that
     subexpression into a new EIN operator and replace it with a tensor expression of the
     same shape that refers to the lifted operator's result.

(1b) Call cleanIndex.sml to clean the indices in the subexpression and to get the shape of
     the replacement tensor.

(1c) Call cleanParams.sml to clean the params in the subexpression.
*)

structure Split = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    in

    (* Flag passed to split() by iterAll()/iterMultiple(). When it is 1, lift() consults
     * einSet.rtnVar to reuse an existing variable for an identical lifted operator
     * (common-subexpression removal); any other value disables that lookup.
     *)
    val numFlag = 0

    (* when non-zero, testp prints its debugging messages *)
    val testing = 0

    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* EIN application of the constant 0 with no params, no index and no args *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

    fun setEinZero y = (y, einappzero)

    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun itos i = Int.toString i
    fun filterSca e = Filter.filterSca e
    fun err str = raise Fail str

    (* counter used by genName to make fresh variable names unique *)
    val cnt = ref 0

    (* increment the use count of a mid-IL variable *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

    (* genName: string -> string
     * Returns prefix ^ "_" ^ n, where n is the current value of cnt (which is then bumped).
     *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* testp: string list -> int
     * Prints the concatenated strings when testing is non-zero; always returns 1.
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* lift: string * ein_exp * params * index * sum_id list * args * fieldset * int
     *         -> ein_exp * params * args * code * fieldset
     * Lifts the sub-expression e into a new EIN operator bound to a fresh variable M:
     *   - cleanIndex gives the replacement shape (tshape), the sizes and the cleaned body;
     *   - a new tensor param E.TEN(1,sizes) is appended to params, and the returned
     *     replacement expression E.Tensor(id,tshape) refers to it;
     *   - cleanParams builds the new EIN application for M.
     * When flag = 1, einSet.rtnVar is consulted. If it returns an existing variable v,
     * v is passed as the argument instead of M, v's use count is incremented, and no new
     * code is emitted.
     *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val einapp = cleanParams(M, body, Rparams, sizes, args @ [M])
          val (_, einapp0) = einapp
          val (Rargs, newbies, fieldset') = (case flag
                 of 1 => let
                      val (fieldset', var) = einSet.rtnVar(fieldset, M, einapp0)
                      in
                        (case var
                          of NONE => (args @ [M], [einapp], fieldset')
                           | SOME v => (incUse v; (args @ [v], [], fieldset'))
                         (* end case *))
                      end
                  | _ => (args @ [M], [einapp], fieldset)
                (* end case *))
          in
            (Re, Rparams, Rargs, newbies, fieldset')
          end

    (* isOp: ein_exp -> int
     * Classifies a sub-expression for rewriteOp():
     *   0 - field-level term (Field, Conv, Apply, Lift): replaced by zero
     *   1 - operator (Neg, Sqrt, Cosine, ArcCosine, Sine, ArcSine, PowInt, PowReal,
     *       Add, Sub, Prod, Div, Sum, Probe): lifted into its own EIN operator
     *   2 - anything else (e.g. Tensor, Const, Delta, Epsilon): left in place
     * Raises Fail for terms that should no longer appear at this stage.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.Cosine _ => 1
            | E.ArcCosine _ => 1
            | E.Sine _ => 1
            | E.ArcSine _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp: string * ein_exp * params * index * sum_id list * args * fieldset * int
     *              -> ein_exp * params * args * code * fieldset
     * Replaces e1 by zero, lifts it with lift(), or leaves it unchanged, according to
     * isOp e1. Only lift() adds params, args or code.
     *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case isOp e1
           of 0 => (E.Const 0, params, args, [], fieldset)
            | 1 => lift(name, e1, params, index, sx, args, fieldset, flag)
            | 2 => (e1, params, args, [], fieldset) (* not lifted *)
          (* end case *))

    (* rewriteOp3: string * sum_id list * ein_exp * ((var * einapp) * fieldset * int)
     *               -> ein_exp * params * args * code * fieldset
     * Same as rewriteOp(), but takes params, index, args, fieldset and flag from the
     * packed original application x.
     *)
    fun rewriteOp3 (name, sx, e1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          in
            rewriteOp(name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
          end

    (* rewriteOps: string * ein_exp list * params * index * sum_id list * args * fieldset * int
     *               -> ein_exp list * params * args * code * fieldset
     * Applies rewriteOp() to each expression of list1, left to right. Params, args and
     * fieldset are threaded through the calls, and the generated code is concatenated
     * in order.
     *)
    fun rewriteOps (name, list1, params, index, sx, args, fieldset0, flag) = let
          fun m ([], rest, params, args, code, fieldset) = (rest, params, args, code, fieldset)
            | m (e1::es, rest, params, args, code, fieldset) = let
                val (e1', params', args', code', fieldset') =
                      rewriteOp(name, e1, params, index, sx, args, fieldset, flag)
                in
                  m(es, rest @ [e1'], params', args', code @ code', fieldset')
                end
          in
            m(list1, [], params, args, [], fieldset0)
          end

    (* rewriteOrig: var * ein_exp * params * index * sum_id list * args -> var * einapp
     * Builds the rewritten outer operator for y via cleanParams; sx is not used.
     *)
    fun rewriteOrig (y, body, params, index, sx, args) = cleanParams(y, body, params, index, args)

    (* rewriteOrig3: like rewriteOrig, but y and index come from the packed application x *)
    fun rewriteOrig3 (sx, body, params, args, x) = let
          val ((y, DstIL.EINAPP(ein, _)), _, _) = x
          in
            cleanParams(y, body, params, Ein.index ein, args)
          end

    (* handleUnary: string * (ein_exp -> ein_exp) * var * ein_exp * params * index * args
     *                * fieldset * int -> (var * einapp) * code * fieldset
     * Shared implementation of the single-argument handlers. It lifts the argument with
     * rewriteOp() if needed, rebuilds the body with mkBody, and cleans the params of the
     * resulting operator for y.
     *)
    fun handleUnary (name, mkBody, y, e1, params, index, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset') =
                rewriteOp(name, e1, params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig(y, mkBody e1', params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleNeg: ein_exp * ((var * einapp) * fieldset * int) -> (var * einapp) * code * fieldset
     * Negation; y, params, index, args, fieldset and flag come from the packed application x.
     *)
    fun handleNeg (e1, x) = let
          val ((y, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          in
            handleUnary("neg", E.Neg, y, e1, Ein.params ein, Ein.index ein, args, fieldset, flag)
          end

    (* handleSqrt, handleCosine, handleArcCosine, handleSine, handleArcSine:
     *   var * ein_exp * params * index * args * fieldset * int
     *     -> (var * einapp) * code * fieldset
     * Each lifts its argument with rewriteOp() if needed (see handleUnary).
     *)
    fun handleSqrt (y, e1, params, index, args, fieldset, flag) =
          handleUnary("sqrt", E.Sqrt, y, e1, params, index, args, fieldset, flag)

    fun handleCosine (y, e1, params, index, args, fieldset, flag) =
          handleUnary("cosine", E.Cosine, y, e1, params, index, args, fieldset, flag)

    fun handleArcCosine (y, e1, params, index, args, fieldset, flag) =
          handleUnary("ArcCosine", E.ArcCosine, y, e1, params, index, args, fieldset, flag)

    fun handleSine (y, e1, params, index, args, fieldset, flag) =
          handleUnary("sine", E.Sine, y, e1, params, index, args, fieldset, flag)

    fun handleArcSine (y, e1, params, index, args, fieldset, flag) =
          handleUnary("ArcSine", E.ArcSine, y, e1, params, index, args, fieldset, flag)

    (* handleSub: var * ein_exp * ein_exp * params * index * args * fieldset * int
     *              -> (var * einapp) * code * fieldset
     * Calls rewriteOps() on both operands of the subtraction.
     *)
    fun handleSub (y, e1, e2, params, index, args, fieldset, flag) = let
          val ([e1', e2'], params', args', code, fieldset') =
                rewriteOps("subt", [e1, e2], params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Sub(e1', e2'), params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleDiv: var * ein_exp * ein_exp * params * index * args * fieldset * int
     *              -> (var * einapp) * code * fieldset
     * Calls rewriteOp() on the numerator, then on the denominator. Params, args and
     * fieldset are threaded from the first call into the second.
     *)
    fun handleDiv (y, e1, e2, params, index, args, fieldset, flag) = let
          val (e1', params1, args1, code1, fieldset1) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
          val (e2', params2, args2, code2, fieldset2) =
                rewriteOp("div-denom", e2, params1, index, [], args1, fieldset1, flag)
          val einapp = rewriteOrig(y, E.Div(e1', e2'), params2, index, [], args2)
          in
            (einapp, code1 @ code2, fieldset2)
          end

    (* handleAdd: var * ein_exp list * params * index * args * fieldset * int
     *              -> (var * einapp) * code * fieldset
     * Calls rewriteOps() on every summand.
     *)
    fun handleAdd (y, e1, params, index, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset') =
                rewriteOps("add", e1, params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Add e1', params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleProd: var * ein_exp list * params * index * args * fieldset * int
     *               -> (var * einapp) * code * fieldset
     * Calls rewriteOps() on every factor.
     *)
    fun handleProd (y, e1, params, index, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset') =
                rewriteOps("prod", e1, params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Prod e1', params', index, [], args')
          in
            (einapp, code, fieldset')
          end

    (* handleSumProd: var * ein_exp list * params * index * sum_id list * args * fieldset * int
     *                  -> (var * einapp) * code * fieldset
     * Calls rewriteOps() on every factor of Sum(sx, Prod e1). The summation indices sx are
     * passed to the lifting, and the result body is rebuilt as Sum(sx, Prod ...).
     *)
    fun handleSumProd (y, e1, params, index, sx, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset') =
                rewriteOps("sumprod", e1, params, index, sx, args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Sum(sx, E.Prod e1'), params', index, sx, args')
          in
            (einapp, code, fieldset')
          end

    (* split: ((var * einapp) * fieldset * int) -> ((var * einapp) * code) * fieldset
     * Splits one EIN application into a simpler application for y, plus the list of
     * newly lifted applications ("code"). Field-level bodies become zero, and leaves are
     * returned unchanged. A summation around a probe of a convolution (alone, or in a
     * product with an epsilon) is also left unchanged. Non-EINAPP right-hand sides are
     * returned as-is.
     *)
    fun split ((y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val zero = (setEinZero y, [], fieldset)
          val default = ((y, einapp), [], fieldset)
          val str = "Poorly formed EIN operator. Argument needs to be applied in High-IL" ^ (P.printbody body)
          val _ = testp["\n\nStarting split", P.printbody body]
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe(E.Field _, _) => raise Fail str
                  | E.Probe _ => raise Fail str
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.ConstR _ => default
                  | E.Neg e1 => handleNeg(e1, x)
                  | E.Sqrt e1 => handleSqrt(y, e1, params, index, args, fieldset, flag)
                  | E.Cosine e1 => handleCosine(y, e1, params, index, args, fieldset, flag)
                  | E.ArcCosine e1 => handleArcCosine(y, e1, params, index, args, fieldset, flag)
                  | E.Sine e1 => handleSine(y, e1, params, index, args, fieldset, flag)
                  | E.ArcSine e1 => handleArcSine(y, e1, params, index, args, fieldset, flag)
                  | E.PowInt _ => err(" PowInt unsupported")
                  | E.PowReal _ => err(" PowReal unsupported")
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args, fieldset, flag)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args, fieldset, flag)
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args, fieldset, flag)
                  | E.Sum(sx, E.Delta d) => handleSumProd(y, [E.Delta d], params, index, sx, args, fieldset, flag)
                  | E.Sum(sx, _) => err(" summation not distributed:" ^ str)
                  | E.Add e1 => handleAdd(y, e1, params, index, args, fieldset, flag)
                  | E.Prod e1 => handleProd(y, e1, params, index, args, fieldset, flag)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          val (einapp2, newbies, fieldset') = rewrite body
          in
            ((einapp2, newbies), fieldset')
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

    (* iterMultiple: (var * einapp) * (var * einapp) list * fieldset -> (var * einapp) list
     * Recursively splits every already-lifted application in newbies2 and returns all
     * generated code, followed by the split results, followed by einapp2.
     * NOTE(review): the fieldset returned by split is discarded, so einSet entries recorded
     * while splitting one piece are not visible to the next; this only matters when
     * numFlag = 1.
     *)
    fun iterMultiple (einapp2, newbies2, fieldset) = let
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, cnt) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, cnt+2)
                end
          val (rest, code) = itercode(newbies2, [], [], 1)
          in
            code @ rest @ [einapp2]
          end

    (* iterAll: (var * einapp) list * fieldset -> (var * einapp) list
     * Splits every application in einapp2, and recursively splits the pieces each split
     * produces. Returns all generated code followed by the split results.
     * NOTE(review): as in iterMultiple, the fieldset returned by split is discarded.
     *)
    fun iterAll (einapp2, fieldset) = let
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, cnt) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, cnt+2)
                end
          val (rest, code) = itercode(einapp2, [], [], 0)
          in
            code @ rest
          end

    (* splitEinApp: (var * einapp) -> (var * einapp) list
     * Entry point: splits one EIN application completely, starting from an empty einSet.
     *)
    fun splitEinApp einapp3 = let
          val fieldset = einSet.EinSet.empty

          (* **** split in parts (alternative strategy, currently disabled) **** *)
          (*
          val ((einapp4,newbies4),fieldset)=split(einapp3,fieldset,0)
          val _ =testp["\n\t===>\n",printEINAPP(einapp4),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies4))]
          val (newbies5)= iterMultiple(einapp4,newbies4,fieldset)
          *)

          (* **** split all at once **** *)
          val newbies5 = iterAll([einapp3], fieldset)
          in
            newbies5
          end

    end; (* local *)

end (* Split *)