Home | My Page | Projects | Code Snippets | Project Openings | diderot

# SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee_dev / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee_dev/src/compiler/high-to-mid/split.sml

 (* Currently under construction
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(*
During the transition from high-IL to mid-IL, complicated EIN expressions are split into
simpler ones in order to better identify methods for code generation and common
subexpressions. Combining EIN operators in the optimization phase can lead to large and
complicated EIN operators. A general code generator would need to expand every operation to
work on scalars, which could miss the opportunity for vectorization and lead to poor code
generation. Instead, every EIN operator is split into a set of simple EIN operators. Each EIN
expression then only has one operation working on constants, tensors, deltas, epsilons,
images and kernels.

(1) When the outer EIN operator is $\in \{--, +, -, *, /, \sum\}$ then analyze each
    subexpression to see whether it needs to be rewritten.

(1a) When a subexpression is a field expression ($\circledast, \nabla$) it becomes 0. When it
     is another operation $\{--, +, -, *, /, \sum\}$ we lift that subexpression out and create
     a new EIN operator for it. The subexpression is replaced with a tensor expression that
     represents its size.

(1b) Call cleanIndex.sml to clean the indices in the subexpression and to get the shape of the
     replacement tensor.

(1c) Call cleanParams.sml to clean the params in the subexpression.
*)

structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure handleE = handleEin

  in

    (* debugging switch: 0 disables tracing output, any other value enables it *)
    val testing = 0

    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* the EIN application of the constant 0, shared by every zero rewrite *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

    fun setEinZero y = (y, einappzero)

    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun isZero e = handleE.isZero e
    fun sweep e = handleE.sweep e
    fun itos i = Int.toString i
    fun err str = raise Fail str

    (* counter used to give every lifted variable a distinct name *)
    val cnt = ref 0

    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* testp: string list -> int
     * prints the concatenated strings when testing is enabled; always returns 1
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* lift: name * ein_exp * params * index * sum_id list * args
     *         -> ein_exp * params * args * code
     * Lifts e out into its own EIN application bound to a fresh variable M:
     *   - cleanIndex computes the replacement shape, the sizes and the rewritten body;
     *   - a new tensor param (and M as its argument) is appended to params/args;
     *   - cleanParams builds the EIN application that defines M.
     * Returns the replacement tensor expression, the extended params and args, and
     * the new definition as a one-element code list.
     *)
    fun lift (name, e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

    (* isOp: ein_exp -> int
     * classifies how a sub-expression is treated when it is split from the original:
     *   0 - field-level expression, rewritten to zero
     *   1 - operator, lifted out into its own EIN application (see rewriteOp)
     *   2 - anything else, left unchanged
     * Raises Fail for constructors that should no longer occur at this stage.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Sqrt _ => 1
            | E.PowInt _ => 1
            | E.PowReal _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            (* NOTE(review): this message says "Probe" although the case is Img -- confirm intended *)
            | E.Img _ => err("Probe used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp: name * ein_exp * params * index * sum_id list * args
     *              -> ein_exp * params * args * code
     * If e1 is an operator, lift it out (see lift); if it is a field-level
     * expression, replace it with 0; otherwise keep it as it is.
     *)
    fun rewriteOp (name, e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(name, e1, params, index, sx, args)
          (* end case *))

    (* rewriteOps: name * ein_exp list * params * index * sum_id list * args
     *               -> ein_exp list * params * args * code
     * applies rewriteOp to each expression from left to right, threading the
     * growing params/args through and concatenating the generated code in order
     *)
    fun rewriteOps (name, list1, params, index, sx, args) = let
          fun m ([], rest, params, args, code) = (rest, params, args, code)
            | m (e1::es, rest, params, args, code) = let
                val (e1', params', args', code') = rewriteOp(name, e1, params, index, sx, args)
                in
                  m (es, rest @ [e1'], params', args', code @ code')
                end
          in
            m (list1, [], params, args, [])
          end

    (* rewriteOrig: var * ein_exp * params * index * sum_id list * args -> var * einapp
     * Rebuilds the (already rewritten) top-level operator. When isZero reports the
     * body as zero (result 1) the constant-0 EIN application is used instead.
     * NOTE(review): sx is currently unused here.
     *)
    fun rewriteOrig (y, body, params, index, sx, args) = (case isZero body
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

    (* handleUnary: (ein_exp -> ein_exp) * var * ein_exp * params * index * args
     *                -> (var * einapp) * code
     * shared driver for single-argument operators: rewrites the argument with
     * rewriteOp and rebuilds the operator around it with mk
     *)
    fun handleUnary (mk, y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(DstV.name y, e1, params, index, [], args)
          val einapp = rewriteOrig(y, mk e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleNeg: var * ein_exp * params * index * args -> (var * einapp) * code *)
    fun handleNeg (y, e1, params, index, args) =
          handleUnary(E.Neg, y, e1, params, index, args)

    (* handleSqrt: var * ein_exp * params * index * args -> (var * einapp) * code *)
    fun handleSqrt (y, e1, params, index, args) =
          handleUnary(E.Sqrt, y, e1, params, index, args)

    (* handlePowInt: var * (ein_exp * exponent) * params * index * args -> (var * einapp) * code *)
    fun handlePowInt (y, (e1, n1), params, index, args) =
          handleUnary(fn e => E.PowInt(e, n1), y, e1, params, index, args)

    (* handlePowReal: var * (ein_exp * exponent) * params * index * args -> (var * einapp) * code *)
    fun handlePowReal (y, (e1, n1), params, index, args) =
          handleUnary(fn e => E.PowReal(e, n1), y, e1, params, index, args)

    (* handleSub: var * ein_exp * ein_exp * params * index * args -> (var * einapp) * code
     * calls rewriteOps on both operands
     *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val ([e1', e2'], params', args', code) =
                rewriteOps(DstV.name y, [e1, e2], params, index, [], args)
          val einapp = rewriteOrig(y, E.Sub(e1', e2'), params', index, [], args')
          in
            (einapp, code)
          end

    (* handleDiv: var * ein_exp * ein_exp * params * index * args -> (var * einapp) * code
     * rewrites the numerator with the outer index and the denominator with an
     * empty index (presumably because the divisor is a scalar -- confirm)
     *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(DstV.name y, e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp(DstV.name y, e2, params1', [], [], args1')
          val einapp = rewriteOrig(y, E.Div(e1', e2'), params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

    (* handleAdd: var * ein_exp list * params * index * args -> (var * einapp) * code
     * calls rewriteOps on the summands
     *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, [], args)
          val einapp = rewriteOrig(y, E.Add e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleProd: var * ein_exp list * params * index * args -> (var * einapp) * code
     * calls rewriteOps on the factors
     *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, [], args)
          val einapp = rewriteOrig(y, E.Prod e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSumProd: var * ein_exp list * params * index * sum_id list * args
     *                  -> (var * einapp) * code
     * rewrites the factors of a summed product, passing the summation indices sx
     * to rewriteOps so lifted factors keep their summation context
     *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, sx, args)
          val einapp = rewriteOrig(y, E.Sum(sx, E.Prod e1'), params', index, sx, args')
          in
            (einapp, code)
          end

    (* split: var * rhs -> (var * einapp) * code
     * splits one EIN application into a simpler one plus the code for the
     * sub-expressions lifted out of it. Summations around probes are left intact.
     * Non-EIN right-hand sides are returned unchanged.
     *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
          (* the message is only built on the error path, since printing the body is not free *)
          fun malformed () = raise Fail ("Poorly formed EIN operator. Argument needs to be applied in High-IL"
                ^ P.printbody body)
          fun rewrite b = (case b
                 of E.Probe(E.Conv _, _) => default
                  | E.Probe _ => malformed ()
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.ConstR _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sqrt e1 => handleSqrt(y, e1, params, index, args)
                  | E.PowInt e1 => handlePowInt(y, e1, params, index, args)
                  | E.PowReal e1 => handlePowReal(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  (* push the summation inward, then rewrite the result *)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(sx, E.Div(e1, e2)) => rewrite (E.Div(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1 @ c2, e))
                  | E.Sum(sx, _) => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Probe used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

    (* iterMultiple: (var * einapp) * code -> (var * einapp) * code
     * recursively splits every newly generated EIN application. The returned code
     * lists the deeper lifted definitions first and the direct newbies last, so each
     * definition precedes the ones that reference it.
     *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val (einapp3, code3) = split e1
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            (einapp2, code @ rest)
          end

    (* iterSplit: var * rhs -> (var * einapp) * code
     * sweeps the body, splits it, and then recursively splits the generated code.
     * Non-EIN right-hand sides are returned unchanged, matching split.
     *)
    fun iterSplit (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val b = handleE.sweep body
          val einapp2 = assignEinApp(y, params, index, b, args)
          val (einapp3, newbies2) = split einapp2
          in
            iterMultiple(einapp3, newbies2)
          end
      | iterSplit (y, app) = ((y, app), [])

    (* gettest: var * rhs -> (var * einapp) * code
     * entry point: runs iterSplit, printing the input first when testing is enabled
     *)
    fun gettest einapp = (case testing
           of 0 => iterSplit einapp
            | _ => let
                val star = "\n************* SPLIT********\n"
                val _ = print(String.concat[star, "\n", "start get test", printEINAPP einapp])
                val (einapp2, newbies) = iterSplit einapp
                (* printing of the results is currently disabled:
                val a = printEINAPP einapp2
                val b = String.concatWith ",\n\t" (List.map printEINAPP newbies)
                val _ = print(String.concat[printEINAPP einapp, "=>", a, " newbies\n\t", b, "\n", a, star])
                *)
                in
                  (einapp2, newbies)
                end
          (* end case *))

  end (* local *)

end (* Split *)