Home My Page Projects Code Snippets Project Openings diderot

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

 (* Currently under construction
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(*
 * During the transition from high-IL to mid-IL, complicated EIN expressions are
 * split into simpler ones in order to better identify methods for code generation
 * and common subexpressions. Combining EIN operators in the optimization phase can
 * lead to large and complicated EIN operators. A general code generator would need
 * to expand every operation to work on scalars, which could miss the opportunity
 * for vectorization and lead to poor code generation. Instead, every EIN operator
 * is split into a set of simple EIN operators. Each EIN expression then has only
 * one operation working on constants, tensors, deltas, epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is $\in \{--, +, -, *, /, \sum\}$, analyze each
 *     subexpression to see whether it needs to be rewritten.
 *
 * (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it
 *      becomes 0. When it is another operation $\in \{--, +, -, *, /, \sum\}$, we
 *      lift that subexpression into a new EIN operator and replace it with a
 *      tensor expression that represents its size.
 *
 * (1b) Call cleanIndex.sml to clean the indices in the subexpression and get the
 *      shape for the tensor replacement.
 *
 * (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure handleE = handleEin

    in

    (* nonzero enables the debug printing done by testp and gettest *)
    val testing = 1

    (* build an EIN operator from its components *)
    fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

    (* bind y to the EIN application of (params, index, body) to args *)
    fun assignEinApp (y, params, index, body, args) =
          (y, DstIL.EINAPP(setEin(params, index, body), args))

    (* the EIN application of the constant 0, with no params, indices or args *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

    (* bind y to the constant-zero EIN application *)
    fun setEinZero y = (y, einappzero)

    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun printEINAPP e = MidToString.printEINAPP e
    fun isZero e = handleE.isZero e
    fun itos i = Int.toString i
    fun err str = raise Fail str

    (* counter used by genName; it is never reset, so generated names are unique
     * for the lifetime of this structure *)
    val cnt = ref 0

    (* genName: string -> string
     * returns prefix ^ "_" ^ n, where n is the current counter value, and bumps the counter
     *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n+1;
            String.concat[prefix, "_", Int.toString n]
          end

    (* testp: string list -> int
     * when testing is nonzero, prints the concatenation of the strings; always returns 1
     *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

    (* lift: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * Lifts the sub-expression e out of the enclosing operator:
     *  - cleanIndex yields the replacement tensor's shape, its sizes, and the cleaned body;
     *  - a new tensor param is appended at position id = length params;
     *  - a fresh mid-IL variable M of type TensorTy sizes (named <name>_l_<id>_<n>)
     *    is appended to args;
     *  - the cleaned body is turned into M's EIN application via cleanParams.
     * Returns the replacement tensor expression, the extended params and args, and
     * the new application as a singleton code list.
     *)
    fun lift (name, e, params, index, sx, args) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length(params)
          val Rparams = params@[E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
          val Rargs = args@[M]
          val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
          in
            (Re, Rparams, Rargs, [einapp])
          end

    (* isOp: ein_exp -> int
     * classifies a sub-expression for rewriteOp:
     *   0 - field-level expression (Field, Conv, Apply, Lift): it is replaced by zero
     *   1 - operator (Neg, Add, Sub, Prod, Div, Sum, Probe): it is lifted out
     *   2 - anything else (e.g. Tensor, Const, Delta, Epsilon): it remains the same
     * raises Fail on Partial, Krn, Value and Img, which should not occur at this stage
     *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Neg _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

    (* rewriteOp: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
     * If e1 is an operator, lift() replaces it with a tensor; a field-level
     * expression becomes zero; anything else is returned unchanged.
     *)
    fun rewriteOp (name, e1, params, index, sx, args) = (case (isOp e1)
           of 0 => (E.Const 0, params, args, [])
            | 2 => (e1, params, args, [])
            | _ => lift(name, e1, params, index, sx, args)
          (* end case *))

    (* rewriteOps: string*ein_exp list*params*index*sum_id list*args
     *     -> ein_exp list*params*args*code
     * applies rewriteOp to each expression left to right, threading params and args
     * and concatenating the generated code in order
     *)
    fun rewriteOps (name, list1, params, index, sx, args) = let
          fun m ([], rest, params, args, code) = (rest, params, args, code)
            | m (e1::es, rest, params, args, code) = let
                val (e1', params', args', code') = rewriteOp(name, e1, params, index, sx, args)
                in
                  m(es, rest@[e1'], params', args', code@code')
                end
          in
            m(list1, [], params, args, [])
          end

    (* rewriteOrig: var*ein_exp*params*index*sum_id list*args -> var*einapp
     * Rebuilds y's own application after its sub-expressions were rewritten.
     * If isZero reports the body as zero (returns 1), y is bound to the constant 0;
     * otherwise the params are cleaned. The sx argument is not used.
     *)
    fun rewriteOrig (y, body, params, index, sx, args) = (case (isZero body)
           of 1 => setEinZero y
            | _ => cleanParams(y, body, params, index, args)
          (* end case *))

    (* handleNeg: var*ein_exp*params*index*args -> (var*einapp)*code
     * rewrites the negated expression with rewriteOp
     *)
    fun handleNeg (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOp(DstV.name y, e1, params, index, [], args)
          val body = E.Neg e1'
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSub: var*ein_exp*ein_exp*params*index*args -> (var*einapp)*code
     * rewrites both operands with rewriteOps
     *)
    fun handleSub (y, e1, e2, params, index, args) = let
          val ([e1', e2'], params', args', code) =
                rewriteOps(DstV.name y, [e1, e2], params, index, [], args)
          val body = E.Sub(e1', e2')
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleDiv: var*ein_exp*ein_exp*params*index*args -> (var*einapp)*code
     * rewrites numerator and denominator with rewriteOp
     *)
    fun handleDiv (y, e1, e2, params, index, args) = let
          val (e1', params1', args1', code1') = rewriteOp(DstV.name y, e1, params, index, [], args)
          (* NOTE(review): the denominator is rewritten with an empty index list,
           * presumably because it is expected to be a scalar -- confirm *)
          val (e2', params2', args2', code2') = rewriteOp(DstV.name y, e2, params1', [], [], args1')
          val body = E.Div(e1', e2')
          val einapp = rewriteOrig(y, body, params2', index, [], args2')
          in
            (einapp, code1'@code2')
          end

    (* handleAdd: var*ein_exp list*params*index*args -> (var*einapp)*code
     * rewrites every summand with rewriteOps
     *)
    fun handleAdd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, [], args)
          val body = E.Add e1'
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleProd: var*ein_exp list*params*index*args -> (var*einapp)*code
     * rewrites every factor with rewriteOps
     *)
    fun handleProd (y, e1, params, index, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, [], args)
          val body = E.Prod e1'
          val einapp = rewriteOrig(y, body, params', index, [], args')
          in
            (einapp, code)
          end

    (* handleSumProd: var*ein_exp list*params*index*sum_id list*args -> (var*einapp)*code
     * rewrites the factors of a summed product with rewriteOps, passing the
     * summation indices sx so lifted factors keep them
     *)
    fun handleSumProd (y, e1, params, index, sx, args) = let
          val (e1', params', args', code) = rewriteOps(DstV.name y, e1, params, index, sx, args)
          val body = E.Sum(sx, E.Prod e1')
          val einapp = rewriteOrig(y, body, params', index, sx, args')
          in
            (einapp, code)
          end

    (* split: var*rhs -> (var*einapp)*code
     * Splits one EIN application into smaller pieces. Returns y's rewritten
     * application and the lifted applications it refers to. Sums are pushed
     * through Neg/Add/Sub/Div and nested sums are merged. Note that summation
     * around a probe expression is left in place. Non-EINAPP right-hand sides
     * are returned unchanged.
     *)
    fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val zero = (setEinZero y, [])
          val default = ((y, einapp), [])
          fun rewrite b = (case b
                 of E.Probe _ => default
                  | E.Conv _ => zero
                  | E.Field _ => zero
                  | E.Apply _ => zero
                  | E.Lift _ => zero
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Tensor _ => default
                  | E.Const _ => default
                  | E.Neg e1 => handleNeg(y, e1, params, index, args)
                  | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                  | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                  | E.Sum(_, E.Prod[E.Eps2 _, E.Probe _]) => default
                  | E.Sum(_, E.Prod[E.Epsilon _, E.Probe _]) => default
                  | E.Sum(_, E.Probe _) => default
                  | E.Sum(_, E.Conv _) => zero
                  | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                  | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                  | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                  | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(sx, E.Div(e1, e2)) => rewrite (E.Div(E.Sum(sx, e1), E.Sum(sx, e2)))
                  | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1@c2, e))
                  | E.Sum(sx, _) => default
                  | E.Add e1 => handleAdd(y, e1, params, index, args)
                  | E.Prod e1 => handleProd(y, e1, params, index, args)
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                (* end case *))
          in
            rewrite body
          end
      | split (y, app) = ((y, app), [])

    (* iterMultiple: (var*einapp)*code -> (var*einapp)*code
     * Recursively splits every lifted application in newbies2. In the returned
     * list, the applications produced by splitting a lifted application are
     * placed before the split lifted applications themselves (presumably so
     * that definitions precede their uses).
     *)
    fun iterMultiple (einapp2, newbies2) = let
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val (einapp3, code3) = split(e1)
                val (rest4, code4) = itercode(code3, [], [])
                in
                  itercode(newbies, rest@[einapp3], code4@rest4@code)
                end
          val (rest, code) = itercode(newbies2, [], [])
          in
            (einapp2, code@rest)
          end

    (* iterSplit: var*rhs -> (var*einapp)*code
     * Cleans the indices of y's EIN body, prints it when testing is on, and
     * splits it repeatedly until the lifted pieces are simple. Non-EINAPP
     * right-hand sides are returned unchanged, matching split.
     *)
    fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val (_, _, body') = cleanIndex(body, index, [])
          val einapp1 = assignEinApp(y, params, index, body', args)
          val _ = testp["\n rewriten einapp\n \t", printEINAPP einapp1]
          val (einapp2, newbies2) = split einapp1
          in
            iterMultiple(einapp2, newbies2)
          end
      | iterSplit (y, app) = ((y, app), [])

    (* gettest: var*rhs -> (var*einapp)*code
     * Entry point: runs iterSplit. When testing is nonzero, it also prints the
     * input application, the result, and the generated applications.
     *)
    fun gettest einapp = (case testing
           of 0 => iterSplit(einapp)
            | _ => let
                val star = "\n************* SPLIT********\n"
                val _ = print(String.concat[star, "\n", "start get test", printEINAPP einapp])
                val (einapp2, newbies) = iterSplit(einapp)
                val a = printEINAPP einapp2
                val b = String.concatWith ",\n\t" (List.map printEINAPP newbies)
                val _ = print(String.concat[printEINAPP einapp, "=>", a, " newbies\n\t", b, "\n", a, star])
                in
                  (einapp2, newbies)
                end
          (* end case *))

    end; (* local *)

end (* structure Split *)