Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee_dev / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml

 (* split.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * NOTE: this module is still under construction.
 *)

(*
During the transition from high-IL to mid-IL, complicated EIN expressions are split into
simpler ones in order to better identify methods for code generation and common
subexpressions.  Combining EIN operators in the optimization phase can lead to large and
complicated EIN operators.  A general code generator would need to expand every operation to
work on scalars, which could miss opportunities for vectorization and lead to poor code.
Instead, every EIN operator is split into a set of simple EIN operators, so that each
resulting EIN expression has only one operation working on constants, tensors, deltas,
epsilons, images and kernels.

(1) When the outer EIN operator is one of negation, sqrt, integer/real power, +, -, *, / or
    \sum, each operand is analyzed to see whether it needs to be rewritten.

(1a) An operand that is a field expression (e.g. $\circledast$ or $\nabla$) becomes 0.  An
     operand that is itself an operation (negation, sqrt, power, +, -, *, /, \sum, probe)
     is lifted into a new EIN operator, and the operand is replaced by a tensor expression
     whose shape matches the lifted result.

(1b) cleanIndex.sml is called to clean the indices in the lifted subexpression and to get
     the shape of the replacement tensor.

(1c) cleanParams.sml is called to clean the params in the subexpression.
*)

structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure handleE = handleEin

  in

  (* Debugging switch: any value other than 0 enables the trace output produced by
   * testp/testpLazy and the extra reporting done by gettest.
   *)
    val testing = 0

  (* setEin: params * index * body -> Ein.ein *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

  (* assignEinApp: pairs y with the EIN application of a freshly assembled operator *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

  (* An argument-less EIN application whose body is the constant 0 *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])
    fun setEinZero y = (y, einappzero)

  (* thin wrappers around helpers defined in other modules *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun isZero e = handleE.isZero e
    fun sweep e = handleE.sweep e
    fun itos i = Int.toString i
    fun filterSca e = Filter.filterSca e
    fun err str = raise Fail str

  (* counter used to give every lifted variable a unique name *)
    val cnt = ref 0
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

  (* testp: string list -> int
   * When testing is enabled, prints the concatenation of strs.  Always returns 1.
   *)
    fun testp strs = (case testing
           of 0 => 1
            | _ => (print(String.concat strs); 1)
          (* end case *))

  (* testpLazy: (unit -> string list) -> int
   * Like testp, but the message is only constructed when testing is enabled, so the
   * pretty printing of EIN bodies is skipped entirely in normal runs.
   *)
    fun testpLazy mkStrs = (case testing
           of 0 => 1
            | _ => testp (mkStrs ())
          (* end case *))

  (* lift: string * ein_exp * params * index * sum_id list * args
   *         -> ein_exp * params * args * code
   * Lifts e into its own EIN operator bound to a fresh mid-IL variable M:
   *   - cleanIndex cleans the indices of e and yields the replacement tensor's shape
   *     and sizes together with the body of the new operator;
   *   - a new tensor parameter (appended to params) and M (appended to args) stand in
   *     for e in the original operator;
   *   - cleanParams builds the EIN application that defines M.
   * Returns the replacement tensor expression, the extended params and args, and a
   * one-element code list holding the definition of M.
   *)
    fun lift (name, e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

  (* isOp: ein_exp -> int
   * Classifies an operand of an outer operator (see rewriteOp):
   *   0 - field-level expression; it is replaced by the constant 0
   *   1 - an operation or probe; it is lifted into its own EIN operator
   *   2 - anything else (tensors, constants, deltas, epsilons, ...); left in place
   * Raises Fail for expressions that must not appear at this phase.
   *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

  (* rewriteOp: string * ein_exp * params * index * sum_id list * args
   *              -> ein_exp * params * args * code
   * Rewrites a single operand according to isOp: field-level expressions become 0,
   * operations are lifted (see lift), and everything else is returned unchanged.
   *)
    fun rewriteOp (name, e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(name, e1, params, index, sx, args)
          (* end case *))

  (* rewriteOps: string * ein_exp list * params * index * sum_id list * args
   *               -> ein_exp list * params * args * code
   * Applies rewriteOp to each operand from left to right, threading params and args
   * through so every lifted operand gets its own parameter and argument.  The
   * rewritten operands and the generated code are returned in operand order.
   *)
    fun rewriteOps (name, list1, params, index, sx, args) = let
        (* operands and code are accumulated in reverse to avoid quadratic appends *)
          fun walk ([], revEs, params, args, revCode) =
                (List.rev revEs, params, args, List.rev revCode)
            | walk (e1::es, revEs, params, args, revCode) = let
                val (e1', params', args', code') = rewriteOp(name, e1, params, index, sx, args)
                in
                  walk (es, e1'::revEs, params', args', List.revAppend(code', revCode))
                end
          in
            walk (list1, [], params, args, [])
          end

  (* rewriteOrig: var * ein_exp * params * index * sum_id list * args -> var * rhs
   * Builds the EIN application for the rewritten outer operator.  When isZero reports
   * that body is zero (returns 1), y is bound to the constant-0 application instead;
   * otherwise the application is produced by cleanParams.  sx is currently unused.
   *)
    fun rewriteOrig (y, body, params, index, sx, args) = (case isZero body
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

  (* handleUnary: string * (ein_exp -> ein_exp) * var * ein_exp * params * index * args
   *                -> (var * rhs) * code
   * Shared implementation of the unary handlers: rewrites operand e1 with rewriteOp
   * (name is the prefix for lifted variables), rebuilds the operator with mk, and binds
   * y to the result through rewriteOrig.
   *)
    fun handleUnary (name, mk, y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(name, e1, params, index, [], args)
          val einapp = rewriteOrig(y, mk e1', params', index, [], args')
          in
            (einapp, code)
          end

  (* handleNeg: var * ein_exp * params * index * args -> (var * rhs) * code *)
    fun handleNeg (y, e1, params, index, args) =
          handleUnary ("neg", E.Neg, y, e1, params, index, args)

  (* handleSqrt: var * ein_exp * params * index * args -> (var * rhs) * code *)
    fun handleSqrt (y, e1, params, index, args) =
          handleUnary ("sqrt", E.Sqrt, y, e1, params, index, args)

  (* handlePowInt: var * (ein_exp * exponent) * params * index * args -> (var * rhs) * code *)
    fun handlePowInt (y, (e1, n1), params, index, args) =
          handleUnary ("powint", fn e => E.PowInt(e, n1), y, e1, params, index, args)

  (* handlePowReal: var * (ein_exp * exponent) * params * index * args -> (var * rhs) * code *)
    fun handlePowReal (y, (e1, n1), params, index, args) =
          handleUnary ("powreal", fn e => E.PowReal(e, n1), y, e1, params, index, args)

  (* handleSub: var * ein_exp * ein_exp * params * index * args -> (var * rhs) * code
   * Rewrites both operands with rewriteOps.
   *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val ([e1', e2'], params', args', code) =
                rewriteOps("subt", [e1, e2], params, index, [], args)
          val einapp = rewriteOrig(y, E.Sub(e1', e2'), params', index, [], args')
          in
            (einapp, code)
          end

  (* handleDiv: var * ein_exp * ein_exp * params * index * args -> (var * rhs) * code
   * Rewrites the numerator and then the denominator, threading params/args through.
   *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1, args1, code1) = rewriteOp("div-num", e1, params, index, [], args)
          val (e2', params2, args2, code2) = rewriteOp("div-denom", e2, params1, index, [], args1)
          val einapp = rewriteOrig(y, E.Div(e1', e2'), params2, index, [], args2)
          in
            (einapp, code1 @ code2)
          end

  (* handleAdd: var * ein_exp list * params * index * args -> (var * rhs) * code *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps("add", e1, params, index, [], args)
          val einapp = rewriteOrig(y, E.Add e1', params', index, [], args')
          in
            (einapp, code)
          end

  (* handleProd: var * ein_exp list * params * index * args -> (var * rhs) * code *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps("prod", e1, params, index, [], args)
          val einapp = rewriteOrig(y, E.Prod e1', params', index, [], args')
          in
            (einapp, code)
          end

  (* handleSumProd: var * ein_exp list * params * index * sum_id list * args
   *                  -> (var * rhs) * code
   * Handles Sum(sx, Prod e1): the factors are rewritten with the summation indices sx
   * in scope, and the summation is kept around the rebuilt product.
   *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps("sumprod", e1, params, index, sx, args)
          val einapp = rewriteOrig(y, E.Sum(sx, E.Prod e1'), params', index, sx, args')
          in
            (einapp, code)
          end

  (* split: var * rhs -> (var * rhs) * code
   * Splits the EIN application bound to y into a simpler operator plus the list of new
   * EIN applications defining the lifted operands.  Probes of convolutions (also under a
   * summation) are left as they are, field-level expressions become 0, and right-hand
   * sides that are not EIN applications are returned unchanged.
   *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
        (* the error message (which pretty prints body) is only built when needed *)
          fun malformed () = raise Fail (
                "Poorly formed EIN operator. Argument needs to be applied in High-IL"
                ^ P.printbody body)
          val _ = testpLazy (fn () => ["\n\nStarting split", P.printbody body])
        (* NOTE(review): the Sum-over-Sqrt and Sum-over-PowReal cases move the summation
         * inside a non-linear operator; sqrt and real powers do not distribute over sums
         * in general, so confirm these patterns only arise where that is intended.
         *)
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe _ => malformed ()
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.ConstR _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sqrt e1 => handleSqrt(y, e1, params, index, args)
                  | E.PowInt e1 => handlePowInt(y, e1, params, index, args)
                  | E.PowReal e1 => handlePowReal(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(_, E.Tensor(id, [])) => rewrite (E.Tensor(id, []))
                  | E.Sum(_, E.Const c) => rewrite (E.Const c)
                  | E.Sum(_, E.ConstR r) => rewrite (E.ConstR r)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(sx, E.Div(e1, e2)) =>
                      rewrite (E.Sum(sx, E.Prod[e1, E.Div(E.Const 1, e2)]))
                  | E.Sum(sx, E.Lift e) => rewrite (E.Lift(E.Sum(sx, e)))
                  | E.Sum(sx, E.PowReal(e, n1)) => rewrite (E.PowReal(E.Sum(sx, e), n1))
                  | E.Sum(sx, E.Sqrt e) => rewrite (E.Sqrt(E.Sum(sx, e)))
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1@c2, e))
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  | E.Sum(_, _) => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

  (* distributeSummation: var * rhs -> (var * rhs) * code
   * Pushes summations inward where the body has a recognized form, cleans the resulting
   * summations with SummationEin.cleanSummation, and then calls split.  Right-hand sides
   * that are not EIN applications are returned unchanged.
   *)
    fun distributeSummation (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
        (* NOTE(review): a summation directly around a scalar tensor or constant is dropped
         * below, which discards the size of the summation range; confirm this is intended.
         * The Sqrt/PowReal cases move the sum inside a non-linear operator (see split).
         *)
          fun rewrite b = (case b
                 of E.Sum(_, E.Tensor(id, [])) => E.Tensor(id, [])
                  | E.Sum(_, E.Const c) => E.Const c
                  | E.Sum(_, E.ConstR r) => E.ConstR r
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(sx, E.Lift e) => rewrite (E.Lift(E.Sum(sx, e)))
                  | E.Sum(sx, E.PowReal(e, n1)) => rewrite (E.PowReal(E.Sum(sx, e), n1))
                  | E.Sum(sx, E.Sqrt e) => rewrite (E.Sqrt(E.Sum(sx, e)))
                  | E.Sum(sx, E.Sum(c2, e)) => rewrite (E.Sum(sx@c2, e))
                  | E.Sum(sx, E.Prod p) => let
                      val (_, e) = filterSca(sx, p)
                      in
                        e
                      end
                  | E.Div(e1, e2) => E.Div(rewrite e1, rewrite e2)
                  | E.Sub(e1, e2) => E.Sub(rewrite e1, rewrite e2)
                  | E.Add es => E.Add(List.map rewrite es)
                  | E.Prod es => E.Prod(List.map rewrite es)
                  | E.Neg e => E.Neg(rewrite e)
                  | E.Sqrt e => E.Sqrt(rewrite e)
                  | _ => b
                (* end case *))
          val body = rewrite body
          val _ = testpLazy (fn () => ["\nAfter distributeSummation \n", P.printbody body])
          val ein = SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=body})
          in
            split (y, DstIL.EINAPP(ein, args))
          end
      | distributeSummation (y, app) = ((y, app), [])

  (* iterMultiple: (var * rhs) * code -> (var * rhs) * code
   * Recursively splits each newly generated EIN application (and the pieces they
   * produce in turn).  The order of the returned code list is significant and is
   * preserved exactly.
   *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, cnt) = let
                val _ = testpLazy (fn () => [
                        "\n\n******* split term **", Int.toString cnt, " *****", "\n \n",
                        printEINAPP e1, "\n=>\n"
                      ])
                val (einapp3, code3) = distributeSummation e1
                val _ = testpLazy (fn () => [
                        "\n\t===>\n", printEINAPP einapp3, "\nand\n",
                        String.concatWith ",\n\t" (List.map printEINAPP code3)
                      ])
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                in
                  itercode(newbies, rest@[einapp3], code4@rest4@code, cnt+2)
                end
          val (rest, code) = itercode(newbies2, [], [], 1)
          in
            (einapp2, code@rest)
          end

  (* iterSplit: var * rhs -> (var * rhs) * code
   * Sweeps the body, cleans its summations, splits the result, and then recursively
   * splits every generated piece.  Right-hand sides that are not EIN applications are
   * returned unchanged (consistent with split and distributeSummation).
   *)
    fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val bodysweep = sweep body
          val _ = testpLazy (fn () => [
                  "\nPresweep\n", P.printbody body, "\n\n Sweep\n", P.printbody bodysweep, "\n"
                ])
          val ein = SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=bodysweep})
          val _ = testpLazy (fn () => ["\n\n Clean Summation\n", P.printbody(Ein.body ein), "\n"])
          val einapp2 = (y, DstIL.EINAPP(ein, args))
          val _ = testpLazy (fn () => [
                  "\n\n******* split term **", Int.toString 0, " ***** \n \t==>\n",
                  printEINAPP einapp2
                ])
          val (einapp3, newbies2) = distributeSummation einapp2
          val _ = testpLazy (fn () => [
                  "\n\t===>\n", printEINAPP einapp3, "\nand\n",
                  String.concatWith ",\n\t" (List.map printEINAPP newbies2)
                ])
          in
            iterMultiple (einapp3, newbies2)
          end
      | iterSplit (y, app) = ((y, app), [])

  (* gettest: var * rhs -> (var * rhs) * code
   * Entry point: runs iterSplit, additionally printing the input and the results when
   * testing is enabled.
   *)
    fun gettest einapp = (case testing
           of 0 => iterSplit einapp
            | _ => let
                val star = "\n************* SPLIT INITIAL********\n"
                val _ = testp [star, "\n", "start get test", printEINAPP einapp]
                val (einapp2, newbies) = iterSplit einapp
                val _ = testp [
                        "\n\n Returning \n\n =>", printEINAPP einapp2,
                        " newbies\n\t", String.concatWith ",\n\t" (List.map printEINAPP newbies),
                        "\n", star
                      ]
                in
                  (einapp2, newbies)
                end
          (* end case *))

  end (* local *)

end (* struct *)