Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

 (* split.sml
 *
 * Currently under construction.
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(*
 * During the transition from high-IL to mid-IL, complicated EIN expressions are split
 * into simpler ones in order to better identify methods for code generation and common
 * subexpressions. Combining EIN operators in the optimization phase can lead to large
 * and complicated EIN operators. A general code generator would need to expand every
 * operation to work on scalars, which could miss opportunities for vectorization and
 * lead to poor code generation. Instead, every EIN operator is split into a set of
 * simple EIN operators, so that each EIN expression has only one operation working on
 * constants, tensors, deltas, epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is one of {negation, +, -, *, /, \sum}, each
 *     subexpression is analyzed to decide whether it needs to be rewritten.
 *
 *   (1a) A field subexpression (e.g. a convolution $\circledast$ or a derivative
 *        $\nabla$) becomes 0. A subexpression that is itself an operation
 *        {@ (probe), negation, +, -, *, /, \sum} is lifted into a new EIN operator and
 *        replaced by a tensor expression with the shape of the lifted result.
 *   (1b) cleanIndex.sml cleans the indices of the lifted subexpression and computes
 *        the shape of the replacement tensor.
 *   (1c) cleanParams.sml cleans the params of the lifted subexpression.
 *
 * (2) Once all lifted subexpressions in the original EIN operator have been replaced
 *     with tensors (and non-probed fields with zeros), isZero() decides whether the
 *     remaining body is zero. If so, the result is the constant 0; otherwise the EIN
 *     operator's params are cleaned.
 *)

structure Split = struct

    local

    structure E = Ein
    structure mk = mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P = Printer
    structure F = Filter
    structure T = TransformEin
    structure Var = MidIL.Var
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    (* Debug switch: any value other than 0 makes testp and gettest print to stdout. *)
    val testing = 1

    in

    (* setEin: params * index * body -> EIN operator
     * Builds an EIN operator record from its three components.
     *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    (* assignEinApp: var * params * index * body * args -> var * mid-il rhs
     * Binds y to the application of the EIN operator built from (params, index, body).
     *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* The constant-zero EIN application: no params, no index, no arguments. *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

    (* setEinZero: var -> var * mid-il rhs; binds y to the constant zero. *)
    fun setEinZero y = (y, einappzero)

    (* Local aliases for the cleaning passes in cleanParams.sml and cleanIndex.sml. *)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e

    fun itos i = Int.toString i
    fun err str = raise Fail str

    (* Counter used by genName to keep generated variable names distinct. *)
    val cnt = ref 0

    (* genName: string -> string
     * Returns prefix ^ "_" ^ n for the current counter value n, then increments the counter.
     *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* testp: string list -> int
     * Prints the concatenated strings when the testing flag is on; always returns 1.
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* printEINAPP: var * mid-il rhs -> string
     * Debug rendering of one assignment. EINAPP and OP right-hand sides are shown with
     * their operator and arguments; anything else is reported as "non-einapp".
     *)
    fun printEINAPP (id, DstIL.EINAPP(rator, args)) = let
          val a = String.concatWith " , " (List.map Var.toString args)
          in
            String.concat([(DstTy.toString (Var.ty id)), "<", Var.toString id, "> ==", P.printerE rator, a, "\n"])
          end
      | printEINAPP (id, DstIL.OP(rator, args)) = let
          val a = String.concatWith " , " (List.map Var.toString args)
          in
            String.concat([(DstTy.toString (Var.ty id)), "<", Var.toString id, "> =", DstOp.toString rator, a, "\n"])
          end
      | printEINAPP (id, _) =
          String.concat([Var.toString id, "<", (DstTy.toString (Var.ty id)), "> non-einapp\n"])

    (* lift: ein_exp * params * index * sum_id list * args -> ein_exp * params * args * code
     * Lifts the subexpression e out of the enclosing EIN operator:
     *   - cleanIndex(e, index, sx) yields (tshape, sizes, body);
     *   - a new parameter E.TEN(1, sizes) is appended to params, and a fresh mid-IL
     *     variable M of type TensorTy sizes is appended to args;
     *   - e is replaced by E.Tensor(id, tshape), where id is the new parameter's position;
     *   - the result of cleanParams for M (presumably the definition of M) is returned
     *     as the single element of the code list.
     *)
    fun lift (e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName ("TLifted_" ^ itos id), DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          (* NOTE(review): the lifted definition is cleaned against Rparams/Rargs, which
           * already contain M itself; presumably cleanParams drops the unused entry;
           * confirm in cleanParams.sml.
           *)
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

    (* isOp: ein_exp -> int
     * Classifies an operand of an operator that is being split:
     *   0 - field-level expression (Field, Conv, Apply, Lift): replaced by the constant 0;
     *   1 - operation or probe (Neg, Add, Sub, Prod, Div, Sum, Probe): lifted out;
     *   2 - anything else (e.g. constants, tensors, deltas, epsilons): kept as is.
     * Raises Fail for expressions that earlier passes should already have removed.
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Probe used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp: ein_exp * params * index * sum_id list * args -> ein_exp * params * args * code
     * Rewrites one operand according to isOp: field expressions become 0, leaves are
     * kept unchanged, and operations/probes are lifted with lift().
     *)
    fun rewriteOp (e1, params, index, sx, args) = (case isOp e1
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(e1, params, index, sx, args)
          (* end case *))

    (* rewriteOps: ein_exp list * params * index * sum_id list * args
     *     -> ein_exp list * params * args * code
     * Applies rewriteOp to each operand from left to right. It threads the growing
     * params/args through the calls and concatenates the lifted code in operand order.
     *)
    fun rewriteOps (list1, params, index, sx, args) = let
          fun m ([], rest, params, args, code) = (rest, params, args, code)
            | m (e1::es, rest, params, args, code) = let
                val (e1', params', args', code') = rewriteOp(e1, params, index, sx, args)
                in
                  m(es, rest @ [e1'], params', args', code @ code')
                end
          in
            m(list1, [], params, args, [])
          end

    (* isZero: var * ein_exp * params * index * sum_id list * args -> var * mid-il rhs
     * If cleanParams.isZero reports the body as zero (result 1), y is bound to the
     * constant 0. Otherwise the operator's params are cleaned with cleanParams.
     * The sx argument is currently unused.
     *)
    fun isZero (y, body, params, index, sx, args) = (case cleanP.isZero body
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

    (* handleNeg: var * ein_exp * params * index * args -> (var * rhs) * code
     * Splits -e1: the operand is rewritten with rewriteOp, then the result is zero-checked.
     *)
    fun handleNeg (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(e1, params, index, [], args)
          val einapp = isZero(y, E.Neg e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSub: var * ein_exp * ein_exp * params * index * args -> (var * rhs) * code
     * Splits e1 - e2. Both operands are rewritten left to right with params/args
     * threaded between them, as rewriteOps would do. This avoids a non-exhaustive
     * list pattern in the binding.
     *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp(e2, params1', index, [], args1')
          val einapp = isZero(y, E.Sub(e1', e2'), params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

    (* handleDiv: var * ein_exp * ein_exp * params * index * args -> (var * rhs) * code
     * Splits e1 / e2. The denominator is rewritten with an empty index, i.e. it is
     * treated as index-free (scalar).
     *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(e1, params, index, [], args)
          val (e2', params2', args2', code2') = rewriteOp(e2, params1', [], [], args1')
          val einapp = isZero(y, E.Div(e1', e2'), params2', index, [], args2')
          in
            (einapp, code1' @ code2')
          end

    (* handleAdd: var * ein_exp list * params * index * args -> (var * rhs) * code
     * Splits an n-ary sum by rewriting each summand with rewriteOps.
     *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(e1, params, index, [], args)
          val einapp = isZero(y, E.Add e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleProd: var * ein_exp list * params * index * args -> (var * rhs) * code
     * Splits an n-ary product by rewriting each factor with rewriteOps.
     *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(e1, params, index, [], args)
          val einapp = isZero(y, E.Prod e1', params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSumProd: var * ein_exp list * params * index * sum_id list * args
     *     -> (var * rhs) * code
     * Splits Sum(sx, Prod e1). The factors are rewritten with the summation indices sx
     * in scope, and the summation stays around the resulting product.
     *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps(e1, params, index, sx, args)
          val einapp = isZero(y, E.Sum(sx, E.Prod e1'), params', index, sx, args')
          in
            (einapp, code)
          end

    (* split: var * mid-il rhs -> (var * mid-il rhs) * code
     * Splits one EIN application into a simpler top-level application plus the list
     * of lifted definitions it refers to. Summations directly around probes (optionally
     * multiplied by an epsilon) are left in place. Non-EINAPP right-hand sides are
     * returned unchanged with no extra code.
     *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
          fun rewrite b = (case b
                 of E.Probe _ => default
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe _]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe _]) => default
                  | E.Sum(_, E.Probe _) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  (* Push a summation through negation, addition and subtraction. *)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  (* sum_sx (e1/e2) = (sum_sx e1)/e2. Only the numerator is summed: the
                   * denominator is treated as index-free (handleDiv rewrites it with an
                   * empty index), so summing it as well would wrongly scale it by the
                   * size of the summation range.
                   *)
                  | E.Sum(sx, E.Div(e1, e2)) => rewrite (E.Div(E.Sum(sx, e1), e2))
                  (* Merge directly nested summations. *)
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1 @ c2, e))
                  | E.Sum(sx, _) => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Probe used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

    (* iterMultiple: (var * rhs) * code -> (var * rhs) * code
     * Recursively splits every lifted definition in newbies2, and whatever those lift
     * in turn. The returned code lists each split definition after the definitions it
     * refers to; einapp2 itself is returned unchanged.
     *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val (einapp3, code3) = split e1
                (* fully split everything that e1 lifted before moving on *)
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            (einapp2, code @ rest)
          end

    (* iterSplit: var * mid-il rhs -> (var * rhs) * code
     * Entry point. For an EIN application it cleans the indices of the body, splits
     * the result, and recursively splits all lifted definitions. Any other right-hand
     * side is returned unchanged with no extra code, like split, rather than raising
     * Match.
     *)
    fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val (_, _, body') = cleanIndex(body, index, [])
          val einapp1 = assignEinApp(y, params, index, body', args)
          val (einapp2, newbies2) = split einapp1
          in
            iterMultiple(einapp2, newbies2)
          end
      | iterSplit (y, app) = ((y, app), [])

    (* gettest: var * mid-il rhs -> (var * rhs) * code
     * Wrapper around iterSplit. When the testing flag is on, it also prints the
     * original application, the split result and the lifted definitions.
     *)
    fun gettest einapp = (case testing
           of 0 => iterSplit einapp
            | _ => let
                val star = "\n*************\n"
                val _ = print(String.concat[star])
                val (einapp2, newbies) = iterSplit einapp
                val a = printEINAPP einapp2
                val b = String.concatWith ",\n\t" (List.map printEINAPP newbies)
                val _ = print(String.concat[printEINAPP einapp, "=>", a, " newbies\n\t", b, "\n", a, star])
                in
                  (einapp2, newbies)
                end
          (* end case *))

    end (* local *)

end (* Split *)