Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

 (* Currently under construction
  *
  * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
  * All rights reserved.
  *)

(*
 * During the transition from high-IL to mid-IL, complicated EIN expressions are
 * split into simpler ones in order to better identify methods for code generation
 * and common subexpressions.  Combining EIN operators in the optimization phase can
 * lead to large and complicated EIN operators.  A general code generator would need
 * to expand every operation to work on scalars, which could miss the opportunity for
 * vectorization and lead to poor code generation.  Instead, every EIN operator is
 * split into a set of simple EIN operators.  Each EIN expression then only has one
 * operation working on constants, tensors, deltas, epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is one of {--, +, -, *, /, \sum}, each
 *     subexpression is analyzed to see whether it needs to be rewritten.
 *
 * (1a) When a subexpression is a field expression (convolution, differentiation)
 *      it becomes 0.  When it is another operation {--, +, -, *, /, \sum} we lift
 *      that subexpression and create a new EIN operator.  We replace the
 *      subexpression with a tensor expression that represents its size.
 *
 * (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the
 *      shape for the tensor replacement.
 *
 * (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var

    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    in

    (* flag handed to `split` by the iteration drivers; 1 enables the
     * common-subexpression lookup in `lift` *)
    val numFlag = 1
    (* 0 silences `testp`; any other value makes it print *)
    val testing = 0

    (* build an EIN operator from its three components *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    (* pair a variable with the application of a freshly built EIN operator *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* EIN application with no params, no indices and the constant body 0 *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])
    fun setEinZero y = (y, einappzero)

    (* thin wrappers around helpers that live in other structures *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun itos i = Int.toString i
    fun filterSca e = Filter.filterSca e
    fun err str = raise Fail str

    (* counter used by genName to make generated names unique *)
    val cnt = ref 0

    (* increment the use count stored in a mid-IL variable *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

    (* genName prefix returns "<prefix>_<n>" and bumps the counter *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* debug tracing: prints the concatenated strings unless testing = 0.
     * Always returns 1 so it can be used in a `val _ = ...` binding.
     * NOTE(review): callers build the string list before the call, so the
     * strings are constructed even when tracing is off.
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* lift : name * ein_exp * params * index * sum_id * args * fieldset * flag
     *          -> ein_exp * params * args * code * fieldset
     *
     * Lifts the subexpression `e` into its own EIN application and returns the
     * tensor expression that replaces it in the original body.
     *   - cleanIndex yields the replacement's shape, the sizes and the cleaned body
     *   - a new parameter E.TEN(1, sizes) is appended; its position (the old
     *     length of `params`) is the id of the replacement tensor
     *   - a fresh variable M is created and cleanParams builds its definition
     *   - when flag = 1, einSet.rtnVar is consulted; if it answers SOME v the
     *     existing variable v is used as the argument (its use count is bumped)
     *     and no new code is emitted.  (presumably rtnVar finds an equivalent,
     *     previously recorded application -- confirm against einSet)
     *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val einapp = cleanParams(M, body, Rparams, sizes, args @ [M])
          val (_, einapp0) = einapp
          (* result when the lifted expression gets its own, new definition *)
          fun fresh fieldset' = (args @ [M], [einapp], fieldset')
          val (Rargs, newbies, fieldset') = (case flag
                 of 1 => let
                      val (fieldset', var) = einSet.rtnVar(fieldset, M, einapp0)
                      in
                        case var
                         of NONE => fresh fieldset'
                          | SOME v => (incUse v; (args @ [v], [], fieldset'))
                        (* end case *)
                      end
                  | _ => fresh fieldset
                (* end case *))
          in
            (Re, Rparams, Rargs, newbies, fieldset')
          end

    (* isOp : ein_exp -> int
     * Classifies a subexpression of the operator being split:
     *   0 - field-level expression, replaced by the constant 0
     *   1 - operator, lifted into its own EIN application
     *   2 - anything else, left in place
     * Raises Fail for forms that should not be present at this stage.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.Cosine _ => 1
            | E.ArcCosine _ => 1
            | E.Sine _ => 1
            | E.ArcSine _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp : name * ein_exp * params * index * sum_id * args * fieldset * flag
     *               -> ein_exp * params * args * code * fieldset
     * Dispatches on isOp: zero out, lift, or leave the subexpression as is.
     *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case isOp e1
           of 0 => (E.Const 0, params, args, [], fieldset)
            | 1 => lift(name, e1, params, index, sx, args, fieldset, flag)
            | _ => (e1, params, args, [], fieldset)  (* not lifted *)
          (* end case *))

    (* rewriteOp3 : same as rewriteOp, but params, index and args are taken from
     * the EIN application packed in x = ((y, EINAPP(ein, args)), fieldset, flag).
     *)
    fun rewriteOp3 (name, sx, e1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          in
            rewriteOp(name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
          end

    (* rewriteOps : applies rewriteOp to every expression of list1, left to right,
     * threading params, args and the field set through and concatenating the
     * generated code in order.  The rewritten expressions are returned in the
     * same order as list1.
     *)
    fun rewriteOps (name, list1, params, index, sx, args, fieldset0, flag) = let
          (* both lists are accumulated in reverse to avoid repeated appends *)
          fun step (e, (revExps, params, args, revCode, fieldset)) = let
                val (e', params', args', code', fieldset') =
                      rewriteOp(name, e, params, index, sx, args, fieldset, flag)
                in
                  (e' :: revExps, params', args', List.revAppend(code', revCode), fieldset')
                end
          val (revExps, params', args', revCode, fieldset') =
                List.foldl step ([], params, args, [], fieldset0) list1
          in
            (List.rev revExps, params', args', List.rev revCode, fieldset')
          end

    (* rewriteOrig : var * ein_exp * params * index * sum_id * args -> var * einapp
     * Rebuilds the original operator around the rewritten body via cleanParams.
     * The summation argument is accepted but not used.
     * (The zero check was moved to before split.)
     *)
    fun rewriteOrig (y, body, params, index, sx, args) =
          cleanParams(y, body, params, index, args)

    (* rewriteOrig3 : as rewriteOrig, with the variable and index taken from the
     * EIN application packed in x.
     *)
    fun rewriteOrig3 (sx, body, params, args, x) = let
          val ((y, DstIL.EINAPP(ein, _)), _, _) = x
          in
            cleanParams(y, body, params, Ein.index ein, args)
          end

    (* handleNeg : ein_exp * x -> (var * einapp) * code * fieldset
     * calls rewriteOp3 on the operand and rebuilds the negation
     *)
    fun handleNeg (e1, x) = let
          val (e1', params', args', code, fieldset) = rewriteOp3("neg", [], e1, x)
          val einapp = rewriteOrig3([], E.Neg e1', params', args', x)
          in
            (einapp, code, fieldset)
          end

    local
      (* common implementation of the unary handlers below: rewrite the operand
       * with rewriteOp under the given name, then rebuild the operator with the
       * constructor mk around the rewritten operand.
       *)
      fun handleUnary (name, mk) (y, e1, params, index, args, fieldset, flag) = let
            val (e1', params', args', code, fieldset') =
                  rewriteOp(name, e1, params, index, [], args, fieldset, flag)
            val einapp = rewriteOrig(y, mk e1', params', index, [], args')
            in
              (einapp, code, fieldset')
            end
    in

    (* handleSqrt : var * ein_exp * params * index * args * fieldset * flag
     *                -> (var * einapp) * code * fieldset
     *)
    fun handleSqrt (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("sqrt", E.Sqrt) (y, e1, params, index, args, fieldset, flag)

    (* handleCosine : same signature as handleSqrt *)
    fun handleCosine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("cosine", E.Cosine) (y, e1, params, index, args, fieldset, flag)

    (* handleArcCosine : same signature as handleSqrt *)
    fun handleArcCosine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("ArcCosine", E.ArcCosine) (y, e1, params, index, args, fieldset, flag)

    (* handleSine : same signature as handleSqrt *)
    fun handleSine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("sine", E.Sine) (y, e1, params, index, args, fieldset, flag)

    (* handleArcSine : same signature as handleSqrt *)
    fun handleArcSine (y, e1, params, index, args, fieldset, flag) =
          handleUnary ("ArcSine", E.ArcSine) (y, e1, params, index, args, fieldset, flag)

    end (* local *)

    (* handleSub : var * ein_exp * ein_exp * params * index * args * fieldset * flag
     *               -> (var * einapp) * code * fieldset
     * rewrites both operands in order (same name "subt" for each) and rebuilds
     * the subtraction
     *)
    fun handleSub (y, e1, e2, params, index, args, fieldset, flag) = let
          val (e1', params1, args1, code1, fieldset) =
                rewriteOp("subt", e1, params, index, [], args, fieldset, flag)
          val (e2', params2, args2, code2, fieldset) =
                rewriteOp("subt", e2, params1, index, [], args1, fieldset, flag)
          val einapp = rewriteOrig(y, E.Sub(e1', e2'), params2, index, [], args2)
          in
            (einapp, code1 @ code2, fieldset)
          end

    (* handleDiv : var * ein_exp * ein_exp * params * index * args * fieldset * flag
     *               -> (var * einapp) * code * fieldset
     * rewrites numerator then denominator and rebuilds the division
     *)
    fun handleDiv (y, e1, e2, params, index, args, fieldset, flag) = let
          val (e1', params1, args1, code1, fieldset) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
          val (e2', params2, args2, code2, fieldset) =
                rewriteOp("div-denom", e2, params1, index, [], args1, fieldset, flag)
          val einapp = rewriteOrig(y, E.Div(e1', e2'), params2, index, [], args2)
          in
            (einapp, code1 @ code2, fieldset)
          end

    (* handleAdd : var * ein_exp list * params * index * args * fieldset * flag
     *               -> (var * einapp) * code * fieldset
     * rewrites every summand with rewriteOps and rebuilds the addition
     *)
    fun handleAdd (y, e1, params, index, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset) =
                rewriteOps("add", e1, params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Add e1', params', index, [], args')
          in
            (einapp, code, fieldset)
          end

    (* handleProd : var * ein_exp list * params * index * args * fieldset * flag
     *                -> (var * einapp) * code * fieldset
     * rewrites every factor with rewriteOps and rebuilds the product
     *)
    fun handleProd (y, e1, params, index, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset) =
                rewriteOps("prod", e1, params, index, [], args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Prod e1', params', index, [], args')
          in
            (einapp, code, fieldset)
          end

    (* handleSumProd : var * ein_exp list * params * index * sum_id * args * fieldset * flag
     *                   -> (var * einapp) * code * fieldset
     * rewrites every factor with rewriteOps (passing the summation indices on)
     * and rebuilds the summation over the product
     *)
    fun handleSumProd (y, e1, params, index, sx, args, fieldset, flag) = let
          val (e1', params', args', code, fieldset) =
                rewriteOps("sumprod", e1, params, index, sx, args, fieldset, flag)
          val einapp = rewriteOrig(y, E.Sum(sx, E.Prod e1'), params', index, sx, args')
          in
            (einapp, code, fieldset)
          end

    (* split : (var * rhs) * fieldset * flag -> ((var * rhs) * code) * fieldset
     * Splits one EIN application into a simpler application plus the list of
     * newly created applications for the lifted subexpressions.  A right-hand
     * side that is not an EIN application is returned unchanged with no code.
     * Note: a summation around a probe expression is left in place.
     *)
    fun split ((y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val zero = (setEinZero y, [], fieldset)
          val default = ((y, einapp), [], fieldset)
          val str = "Poorly formed EIN operator. Argument needs to be applied in High-IL" ^ (P.printbody body)
          val _ = testp["\n\nStarting split", P.printbody body]
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe _ => raise Fail str
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.ConstR _ => default
                  | E.Neg e1 => handleNeg(e1, x)
                  | E.Sqrt e1 => handleSqrt(y, e1, params, index, args, fieldset, flag)
                  | E.Cosine e1 => handleCosine(y, e1, params, index, args, fieldset, flag)
                  | E.ArcCosine e1 => handleArcCosine(y, e1, params, index, args, fieldset, flag)
                  | E.Sine e1 => handleSine(y, e1, params, index, args, fieldset, flag)
                  | E.ArcSine e1 => handleArcSine(y, e1, params, index, args, fieldset, flag)
                  | E.PowInt _ => err(" PowInt unsupported")
                  | E.PowReal _ => err(" PowReal unsupported")
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args, fieldset, flag)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args, fieldset, flag)
                  (*
                  | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_) ]) => default
                  | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_) ]) => default
                  *)
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args, fieldset, flag)
                  | E.Sum(sx, E.Delta d) => handleSumProd(y, [E.Delta d], params, index, sx, args, fieldset, flag)
                  | E.Sum(_, E.Tensor _) => default
                  | E.Sum(_, _) => err(" summation not distributed:" ^ str)
                  | E.Add e1 => handleAdd(y, e1, params, index, args, fieldset, flag)
                  | E.Prod e1 => handleProd(y, e1, params, index, args, fieldset, flag)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          val (einapp2, newbies, fieldset') = rewrite body
          in
            ((einapp2, newbies), fieldset')
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

    (* iterMultiple : (var * rhs) * (var * rhs) list * fieldset -> (var * rhs) list
     * Repeatedly splits the applications in newbies2 (and whatever those splits
     * lift out in turn) and returns the resulting code followed by einapp2.
     * NOTE(review): every split call receives the same `fieldset`; the set
     * returned by split is discarded -- confirm this is intended.
     *)
    fun iterMultiple (einapp2, newbies2, fieldset) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1 :: newbies, rest, code) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                (* split the freshly lifted applications before moving on *)
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            code @ rest @ [einapp2]
          end

    (* iterAll : (var * rhs) list * fieldset -> (var * rhs) list
     * Splits every application in einapp2, recursively splitting whatever is
     * lifted out, and returns all resulting applications: the code produced for
     * lifted subexpressions first, then the rewritten inputs in order.
     * NOTE(review): as in iterMultiple, the field set returned by split is
     * discarded; each split starts from the `fieldset` passed in here.
     *)
    fun iterAll (einapp2, fieldset) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1 :: newbies, rest, code) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [])
                val _ = testp [
                        printEINAPP e1, "\n\t===>\n", printEINAPP einapp3, "\nand\n",
                        String.concatWith ",\n\t" (List.map printEINAPP (code4 @ rest4))
                      ]
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(einapp2, [], [])
          in
            code @ rest
          end

    (* splitEinApp : var * rhs -> (var * rhs) list
     * Entry point: splits one EIN application, starting from an empty field set.
     *)
    fun splitEinApp einapp3 = let
          val fieldset = einSet.EinSet.empty
          (* **** split in parts **** *)
          (*
          val ((einapp4,newbies4),fieldset)=split(einapp3,fieldset,0)
          val _ =testp["\n\t===>\n",printEINAPP(einapp4),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies4))]
          val (newbies5)= iterMultiple(einapp4,newbies4,fieldset)
          *)
          (* **** split all at once **** *)
          val newbies5 = iterAll([einapp3], fieldset)
          in
            newbies5
          end

    end (* local *)

end (* structure Split *)