Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

 1 : cchiw 2843 (* Currently under construction 2 : cchiw 2838 * 3 : * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu) 4 : * All rights reserved. 5 : *) 6 : cchiw 2843 7 : (* 8 : During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN expression then only has one operation working on constants, tensors, deltas, epsilons, images and kernels. 9 : 10 : (1) When the outer EIN operator is $\in {--, +, -, *, /, \sum}$ then for each subexpression analyze to see if they need to be rewritten. 11 : 12 : (1a.) When a subexpression is a field expression $\circledast,\nabla$ then it becomes 0. When it is another operation ${@ --, +, -, *, /, \sum}$ then we lift that subexpression and create a new EIN operator. We replace the subexpression with a tensor expression that represent it's size. 13 : 14 : (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement. 15 : 16 : (1c) Call cleanParams.sml to clean the params in the subexpression.\\ 17 : *) 18 : cchiw 2838 19 : structure Split = struct 20 : 21 : local 22 : 23 : structure E = Ein 24 : structure DstIL = MidIL 25 : structure DstTy = MidILTypes 26 : structure DstV = DstIL.Var 27 : structure P=Printer 28 : structure cleanP=cleanParams 29 : structure cleanI=cleanIndex 30 : 31 : cchiw 3033 32 : cchiw 2838 in 33 : 34 : cchiw 3033 val testing=0 35 : cchiw 2838 fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body} 36 : fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args)) 37 : cchiw 2843 val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[]) 38 : fun setEinZero y= (y,einappzero) 39 : cchiw 2838 fun cleanParams e =cleanP.cleanParams e 40 : fun cleanIndex e =cleanI.cleanIndex e 41 : cchiw 2845 fun printEINAPP e=MidToString.printEINAPP e 42 : cchiw 2838 fun itos i =Int.toString i 43 : cchiw 2923 fun filterSca e=Filter.filterSca e 44 : cchiw 2838 fun err str=raise Fail str 45 : val cnt = ref 0 46 : fun genName prefix = let 47 : val n = !cnt 48 : in 49 : cnt := n+1; 50 : String.concat[prefix, "_", Int.toString n] 51 : end 52 : fun testp n=(case testing 53 : of 0=> 1 54 : | _ =>(print(String.concat n);1) 55 : (*end case*)) 56 : 57 : 58 : (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code) 59 : *lifts expression and returns replacement tensor 60 : cchiw 2843 * cleans the index and params of subexpression 61 : *creates new param and replacement tensor for the original ein_exp 62 : cchiw 2838 *) 63 : cchiw 2845 fun lift(name,e,params,index,sx,args)=let 64 : cchiw 2867 65 : cchiw 2838 val (tshape,sizes,body)=cleanIndex(e,index,sx) 66 : cchiw 2843 val id=length(params) 67 : val Rparams=params@[E.TEN(1,sizes)] 68 : val Re=E.Tensor(id,tshape) 69 : cchiw 2845 val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes) 70 : cchiw 2843 val Rargs=args@[M] 71 : val einapp=cleanParams(M,body,Rparams,sizes,Rargs) 72 : 73 : cchiw 2838 in 74 : cchiw 2843 (Re,Rparams,Rargs,[einapp]) 75 : cchiw 2838 end 76 : 77 : (* isOp: ein->int 78 : * checks to see if this sub-expression is pulled out or split form original 79 : * 0-becomes zero,1-remains the same, 2-operator 80 : *) 81 : fun isOp e =(case e 82 : of E.Field _ => 0 83 : | E.Conv _ => 0 84 : | E.Apply _ => 0 85 : | E.Lift _ => 0 86 : | E.Neg _ => 1 87 : cchiw 2870 | E.Sqrt _ => 1 88 : cchiw 3138 | E.Cosine _ => 1 89 : | E.ArcCosine _ => 1 90 : | E.Sine _ => 1 91 : | E.ArcSine _ => 1 92 : | E.PowInt _ => 1 93 : | E.PowReal _ => 1 94 : cchiw 2838 | E.Add _ => 1 95 : | E.Sub _ => 1 96 : | E.Prod _ => 1 97 : | E.Div _ => 1 98 : | E.Sum _ => 1 99 : | E.Probe _ => 1 100 : | E.Partial _ => err(" Partial used after normalize") 101 : | E.Krn _ => err("Krn used before expand") 102 : | E.Value _ => err("Value used before expand") 103 : | E.Img _ => err("Probe used before expand") 104 : | _ => 2 105 : (*end case*)) 106 : 107 : cchiw 2843 (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code 108 : * If e1 an op then call lift() to replace it 109 : cchiw 2838 * Otherwise rewrite to 0 or it remains the same 110 : *) 111 : cchiw 2845 fun rewriteOp(name,e1,params,index,sx,args)=(case (isOp e1) 112 : cchiw 2838 of 0 => (E.Const 0,params,args,[]) 113 : | 2 => (e1,params,args,[]) 114 : cchiw 2867 | _ => lift(name,e1,params,index,sx,args) 115 : cchiw 2838 (*end*)) 116 : 117 : cchiw 2843 (* rewriteOps:ein_exp list*params*index*sum_id list*mid-il vars 118 : -> ein_exp list*params*args*code 119 : * calls rewriteOp on ein_exp list 120 : cchiw 2838 *) 121 : cchiw 2845 fun rewriteOps(name,list1,params,index,sx,args)=let 122 : cchiw 2838 fun m([],rest,params,args,code)=(rest,params,args,code) 123 : | m(e1::es,rest,params,args,code)=let 124 : cchiw 3017 125 : cchiw 2845 val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args) 126 : cchiw 3033 (*val _ =testp["rewriteOP:\n",P.printbody e1,"\n\t=>", P.printbody e1']*) 127 : cchiw 2838 in 128 : m(es,rest@[e1'],params',args',code@code') 129 : end 130 : in 131 : m(list1,[],params,args,[]) 132 : end 133 : cchiw 2843 134 : cchiw 2845 (*rewriteOrig: var* ein_exp* params*index list*mid-il vars 135 : cchiw 2843 When the operation is zero then we return a real. 136 : cchiw 2845 -Moved is Zero to before split. 137 : cchiw 2843 *) 138 : cchiw 3033 fun rewriteOrig(y,body,params,index,sx,args) =cleanParams(y,body,params,index,args) 139 : cchiw 2838 140 : (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code 141 : cchiw 2843 * calls rewriteOp() lift on ein_exp 142 : cchiw 2838 *) 143 : fun handleNeg(y,e1,params,index,args)=let 144 : cchiw 2922 val (e1',params',args',code)= rewriteOp("neg", e1,params,index,[],args) 145 : cchiw 2838 val body =E.Neg e1' 146 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args') 147 : in 148 : (einapp,code) 149 : end 150 : cchiw 2838 151 : cchiw 2867 (* handleSqrt:var*ein_exp *params*index*args-> (var*einap)*code 152 : * calls rewriteOp() lift on ein_exp 153 : *) 154 : fun handleSqrt(y,e1,params,index,args)=let 155 : cchiw 2922 val (e1',params',args',code)= rewriteOp("sqrt", e1,params,index,[],args) 156 : cchiw 2867 val body =E.Sqrt e1' 157 : val einapp= rewriteOrig(y,body,params',index,[],args') 158 : in 159 : (einapp,code) 160 : end 161 : 162 : 163 : cchiw 3138 (* handleCosine:var*ein_exp *params*index*args-> (var*einap)*code 164 : * calls rewriteOp() lift on ein_exp 165 : *) 166 : fun handleCosine(y,e1,params,index,args)=let 167 : val (e1',params',args',code)= rewriteOp("cosine", e1,params,index,[],args) 168 : val body =E.Cosine e1' 169 : val einapp= rewriteOrig(y,body,params',index,[],args') 170 : in 171 : (einapp,code) 172 : end 173 : 174 : (* handleArcCosine:var*ein_exp *params*index*args-> (var*einap)*code 175 : * calls rewriteOp() lift on ein_exp 176 : *) 177 : fun handleArcCosine(y,e1,params,index,args)=let 178 : val (e1',params',args',code)= rewriteOp("ArcCosine", e1,params,index,[],args) 179 : val body =E.ArcCosine e1' 180 : val einapp= rewriteOrig(y,body,params',index,[],args') 181 : in 182 : (einapp,code) 183 : end 184 : 185 : (* handleSine:var*ein_exp *params*index*args-> (var*einap)*code 186 : * calls rewriteOp() lift on ein_exp 187 : *) 188 : fun handleSine(y,e1,params,index,args)=let 189 : val (e1',params',args',code)= rewriteOp("sine", e1,params,index,[],args) 190 : val body =E.Sine e1' 191 : val einapp= rewriteOrig(y,body,params',index,[],args') 192 : in 193 : (einapp,code) 194 : end 195 : 196 : (* handleSine:var*ein_exp *params*index*args-> (var*einap)*code 197 : * calls rewriteOp() lift on ein_exp 198 : *) 199 : fun handleArcSine(y,e1,params,index,args)=let 200 : val (e1',params',args',code)= rewriteOp("ArcSine", e1,params,index,[],args) 201 : val body =E.ArcSine e1' 202 : val einapp= rewriteOrig(y,body,params',index,[],args') 203 : in 204 : (einapp,code) 205 : end 206 : 207 : 208 : cchiw 3033 (* handlePowInt:var*ein_exp *params*index*args-> (var*einap)*code 209 : * calls rewriteOp() lift on ein_exp 210 : *) 211 : fun handlePowInt(y,(e1,n1),params,index,args)=let 212 : val (e1',params',args',code)= rewriteOp("powint", e1,params,index,[],args) 213 : val body =E.PowInt(e1',n1) 214 : val einapp= rewriteOrig(y,body,params',index,[],args') 215 : in 216 : (einapp,code) 217 : end 218 : cchiw 2870 219 : 220 : cchiw 2922 (* handlePowReal:var*ein_exp *params*index*args-> (var*einap)*code 221 : * calls rewriteOp() lift on ein_exp 222 : *) 223 : fun handlePowReal(y,(e1,n1),params,index,args)=let 224 : val (e1',params',args',code)= rewriteOp("powreal", e1,params,index,[],args) 225 : val body =E.PowReal(e1',n1) 226 : val einapp= rewriteOrig(y,body,params',index,[],args') 227 : in 228 : (einapp,code) 229 : end 230 : cchiw 2870 231 : 232 : cchiw 2838 (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code 233 : cchiw 2843 * calls rewriteOps() lift on ein_exp 234 : cchiw 2838 *) 235 : fun handleSub(y,e1,e2,params,index,args)=let 236 : cchiw 2922 val ([e1',e2'],params',args',code)= rewriteOps("subt",[e1,e2],params,index,[],args) 237 : cchiw 2838 val body =E.Sub(e1',e2') 238 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args') 239 : in 240 : (einapp,code) 241 : end 242 : cchiw 2838 243 : (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code 244 : cchiw 2843 * calls rewriteOp() lift on ein_exp 245 : cchiw 2838 *) 246 : fun handleDiv(y,e1,e2,params,index,args)=let 247 : cchiw 2923 val (e1',params1',args1',code1')=rewriteOp("div-num",e1,params,index,[],args) 248 : val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',index,[],args1') 249 : (*val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',[],[],args1')*) 250 : cchiw 2838 val body =E.Div(e1',e2') 251 : cchiw 2845 val einapp= rewriteOrig(y,body,params2',index,[],args2') 252 : in 253 : (einapp,code1'@code2') 254 : end 255 : cchiw 2838 256 : (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code 257 : cchiw 2843 * calls rewriteOps() lift on ein_exp 258 : cchiw 2838 *) 259 : fun handleAdd(y,e1,params,index,args)=let 260 : cchiw 3030 261 : cchiw 2922 val (e1',params',args',code)= rewriteOps("add",e1,params,index,[],args) 262 : cchiw 2838 val body =E.Add e1' 263 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args') 264 : in 265 : (einapp,code) 266 : end 267 : cchiw 2838 268 : (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code 269 : cchiw 2843 * calls rewriteOps() lift on ein_exp 270 : cchiw 2838 *) 271 : fun handleProd(y,e1,params,index,args)=let 272 : cchiw 2922 val (e1',params',args',code)= rewriteOps("prod",e1,params,index,[],args) 273 : cchiw 2838 val body =E.Prod e1' 274 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,[],args') 275 : in 276 : (einapp,code) 277 : end 278 : cchiw 2838 279 : (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code 280 : cchiw 2843 * calls rewriteOps() lift on ein_exp 281 : cchiw 2838 *) 282 : fun handleSumProd(y,e1,params,index,sx,args)=let 283 : cchiw 2922 val (e1',params',args',code)= rewriteOps("sumprod",e1,params,index,sx,args) 284 : cchiw 2838 val body= E.Sum(sx,E.Prod e1') 285 : cchiw 2845 val einapp= rewriteOrig(y,body,params',index,sx,args') 286 : in 287 : (einapp,code) 288 : end 289 : cchiw 2838 290 : (* split:var*ein_app-> (var*einap)*code 291 : * split ein expression into smaller pieces 292 : cchiw 2843 note we leave summation around probe exp 293 : cchiw 2838 *) 294 : fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let 295 : cchiw 2843 val zero= (setEinZero y,[]) 296 : cchiw 2838 val default=((y,einapp),[]) 297 : val sumIndex=ref [] 298 : cchiw 2867 val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body) 299 : cchiw 3017 val _=testp["\n\nStarting split",P.printbody body] 300 : fun rewrite b=(case b 301 : cchiw 2847 of E.Probe (E.Conv _,_) => default 302 : cchiw 2870 | E.Probe(E.Field _,_) => raise Fail str 303 : cchiw 2847 | E.Probe _ => raise Fail str 304 : cchiw 2838 | E.Conv _ => zero 305 : | E.Field _ => zero 306 : | E.Apply _ => zero 307 : | E.Lift e => zero 308 : | E.Delta _ => default 309 : | E.Epsilon _ => default 310 : cchiw 2843 | E.Eps2 _ => default 311 : cchiw 2838 | E.Tensor _ => default 312 : | E.Const _ => default 313 : cchiw 2923 | E.ConstR _ => default 314 : cchiw 2838 | E.Neg e1 => handleNeg(y,e1,params,index,args) 315 : cchiw 2867 | E.Sqrt e1 => handleSqrt(y,e1,params,index,args) 316 : cchiw 3138 | E.Cosine e1 => handleCosine(y,e1,params,index,args) 317 : | E.ArcCosine e1 => handleArcCosine(y,e1,params,index,args) 318 : | E.Sine e1 => handleSine(y,e1,params,index,args) 319 : | E.ArcSine e1 => handleArcSine(y,e1,params,index,args) 320 : cchiw 2923 | E.PowInt e1 => handlePowInt(y,e1,params,index,args) 321 : | E.PowReal e1 => handlePowReal(y,e1,params,index,args) 322 : cchiw 2838 | E.Sub (e1,e2) => handleSub(y,e1,e2,params,index,args) 323 : | E.Div (e1,e2) => handleDiv(y,e1,e2,params,index,args) 324 : cchiw 2847 | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_) ]) => default 325 : | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_) ]) => default 326 : | E.Sum(_,E.Probe(E.Conv _,_)) => default 327 : cchiw 2838 | E.Sum(sx,E.Prod e1) => handleSumProd(y,e1,params,index,sx,args) 328 : cchiw 3138 | E.Sum(sx,E.Delta d) => handleSumProd(y,[E.Delta d],params,index,sx,args) 329 : cchiw 3033 | E.Sum(sx,_) => err(" summation not distributed:"^str) 330 : cchiw 2838 | E.Add e1 => handleAdd(y,e1,params,index,args) 331 : | E.Prod e1 => handleProd(y,e1,params,index,args) 332 : | E.Partial _ => err(" Partial used after normalize") 333 : | E.Krn _ => err("Krn used before expand") 334 : | E.Value _ => err("Value used before expand") 335 : | E.Img _ => err("Probe used before expand") 336 : (*end case *)) 337 : val (einapp2,newbies) =rewrite body 338 : in 339 : (einapp2,newbies) 340 : end 341 : |split(y,app) =((y,app),[]) 342 : cchiw 2923 343 : cchiw 2838 (* iterMultiple:code*code=> (code*code) 344 : * recursively split ein expression into smaller pieces 345 : *) 346 : fun iterMultiple(einapp2,newbies2)=let 347 : cchiw 3017 fun itercode([],rest,code,_)=(rest,code) 348 : | itercode(e1::newbies,rest,code,cnt)=let 349 : val _ =testp["\n\n******* split term **",Int.toString cnt," *****","\n \n",printEINAPP(e1),"\n=>\n"] 350 : cchiw 3033 val (einapp3,code3) = split e1 351 : cchiw 3017 val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP code3))] 352 : val (rest4,code4)=itercode(code3,[],[],cnt+1) 353 : in itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2) 354 : cchiw 2838 end 355 : cchiw 3017 val(rest,code)= itercode(newbies2,[],[],1) 356 : cchiw 2838 in 357 : (einapp2,code@rest) 358 : end 359 : 360 : cchiw 3017 361 : cchiw 2838 end; (* local *) 362 : 363 : end (* local *)