Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

 (* Currently under construction
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

(*
 * During the transition from high-IL to mid-IL, complicated EIN expressions are split into
 * simpler ones in order to better identify methods for code generation and common
 * subexpressions. Combining EIN operators in the optimization phase can lead to large and
 * complicated EIN operators. A general code generator would need to expand every operation
 * to work on scalars, which could miss the opportunity for vectorization and lead to poor
 * code generation. Instead, every EIN operator is split into a set of simple EIN operators.
 * Each EIN expression then has only one operation working on constants, tensors, deltas,
 * epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is in $\{--, +, -, *, /, \sum\}$, analyze each
 *     subexpression to see whether it needs to be rewritten.
 *
 * (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it becomes 0.
 *      When it is another operation in $\{@, --, +, -, *, /, \sum\}$, we lift that
 *      subexpression and create a new EIN operator for it. We replace the subexpression
 *      with a tensor expression that represents its size.
 *
 * (1b) Call cleanIndex.sml to clean the indices in the subexpression and to get the shape
 *      of the replacement tensor.
 *
 * (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split = struct

  local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var

    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

  in

  (* When 1, lift consults einSet.rtnVar so that a lifted subexpression can reuse an
   * existing variable instead of emitting new code (common-subexpression removal).
   *)
    val numFlag = 1

  (* When nonzero, testp prints its debugging messages. *)
    val testing = 0

    fun mkEin e = E.mkEin e

  (* EIN application of the constant 0 with no params and no args. *)
    val einappzero = DstIL.EINAPP(mkEin([], [], E.B(E.Const 0)), [])

    fun setEinZero y = (y, einappzero)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e
    fun itos i = Int.toString i
    fun err str = raise Fail str

  (* counter used by genName to make fresh variable names *)
    val cnt = ref 0

  (* bump the use count of a mid-IL variable *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

  (* genName prefix returns prefix_n, where n is the next value of the global counter cnt *)
    fun genName prefix = let
          val n = !cnt
          in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
          end

  (* testp strs prints the concatenation of strs when testing is nonzero; always returns 1.
   * Note that the argument list is evaluated by the caller either way.
   *)
    fun testp n = (case testing
           of 0 => 1
            | _ => (print(String.concat n); 1)
          (* end case *))

  (* lift (name, e, params, index, sx, args, fieldset, flag)
   * Lifts the subexpression e out of its parent operator into a new EIN application bound
   * to a fresh variable M (named from name and the new param id).
   *   - cleanIndex cleans the indices of e (using the parent index and sum ids sx) and
   *     yields the shape/sizes of the replacement tensor;
   *   - a new tensor param E.TEN(1, sizes) is appended to params, and the returned
   *     replacement expression E.Tensor(id, tshape) refers to it (id = length params);
   *   - cleanParams cleans the params of the lifted body.
   * When flag = 1, einSet.rtnVar is asked about the lifted application; if it returns
   * SOME v, v is passed as the new argument (its use count is bumped) and no new binding
   * is emitted. Otherwise M is passed and its binding is returned as new code.
   * Returns (replacement expression, new params, new args, new bindings, fieldset).
   *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
          val einapp = cleanParams(M, body, Rparams, sizes, args @ [M])
          val (_, einapp0) = einapp
          val (Rargs, newbies, fieldset) = (case flag
                 of 1 => (case einSet.rtnVar(fieldset, M, einapp0)
                       of (fieldset', NONE) => (args @ [M], [einapp], fieldset')
                        | (fieldset', SOME v) => (incUse v; (args @ [v], [], fieldset'))
                      (* end case *))
                  | _ => (args @ [M], [einapp], fieldset)
                (* end case *))
          in
            (Re, Rparams, Rargs, newbies, fieldset)
          end

  (* isOp : classifies a subexpression of an operator that is being split
   *   0 - field-level expression (Field, Conv, Apply, Lift): it is replaced by zero
   *   1 - operator (Op1, Op2, Opn, Sum, Probe): it is lifted into a new EIN application
   *   2 - anything else: it is left in place
   * Raises Fail for Partial, Krn, Value and Img, which should not occur at this stage.
   *)
    fun isOp e = (case e
           of E.Field _ => 0
            | E.Conv _ => 0
            | E.Apply _ => 0
            | E.Lift _ => 0
            | E.Op1 _ => 1
            | E.Op2 _ => 1
            | E.Opn _ => 1
            | E.Sum _ => 1
            | E.Probe _ => 1
            | E.Partial _ => err(" Partial used after normalize")
            | E.Krn _ => err("Krn used before expand")
            | E.Value _ => err("Value used before expand")
            | E.Img _ => err("Img used before expand")
            | _ => 2
          (* end case *))

  (* *************************************** helpers ******************************** *)

  (* rewriteOp : rewrite one subexpression according to isOp (zero it, lift it, or keep it).
   * Returns (new subexpression, params, args, new bindings, fieldset).
   *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case isOp e1
           of 0 => (E.B(E.Const 0), params, args, [], fieldset)
            | 1 => lift(name, e1, params, index, sx, args, fieldset, flag)
            | _ => (e1, params, args, [], fieldset) (* not lifted *)
          (* end case *))

  (* unaryOp : rewrite the single operand e1 against the params/index/args of the
   * application held in x = ((y, EINAPP(ein, args)), fieldset, flag).
   *)
    fun unaryOp (name, sx, e1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          in
            rewriteOp(name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
          end

  (* multOp : rewrite each operand in list1 (left to right), threading params, args and
   * fieldset through the calls. Returns (rewritten operands in their original order,
   * params, args, accumulated new bindings, fieldset).
   *)
    fun multOp (name, sx, list1, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          fun m ([], rest, params, args, code, fieldset) = (rest, params, args, code, fieldset)
            | m (e1::es, rest, params, args, code, fieldset) = let
                val (e1', params', args', code', fieldset') =
                      rewriteOp(name, e1, params, index, sx, args, fieldset, flag)
                in
                  m (es, rest @ [e1'], params', args', code @ code', fieldset')
                end
          in
            m (list1, [], params, args, [], fieldset)
          end

  (* cleanOrig : rebuild the binding for the original variable y (from x) with the new
   * body, params and args, keeping the original index; params are cleaned by cleanParams.
   *)
    fun cleanOrig (body, params, args, x) = let
          val ((y, DstIL.EINAPP(ein, _)), _, _) = x
          in
            cleanParams(y, body, params, Ein.index ein, args)
          end

  (* *************************************** general handle Ops ******************************** *)

  (* Each handler returns (rewritten binding for y, new bindings, fieldset). *)

    fun handleUnaryOp (name, opp, x, e1) = let
          val (e1', params', args', code, fieldset) = unaryOp(name, [], e1, x)
          val einapp = cleanOrig(E.Op1(opp, e1'), params', args', x)
          in
            (einapp, code, fieldset)
          end

    fun handleBinaryOp (name, opp, x, es) = let
          (* multOp preserves the length of es, which is two here *)
          val ([e1', e2'], params', args', code, fieldset) = multOp(name, [], es, x)
          val einapp = cleanOrig(E.Op2(opp, e1', e2'), params', args', x)
          in
            (einapp, code, fieldset)
          end

    fun handleMultOp (name, opp, x, es) = let
          val (es', params', args', code, fieldset) = multOp(name, [], es, x)
          val einapp = cleanOrig(E.Opn(opp, es'), params', args', x)
          in
            (einapp, code, fieldset)
          end

  (* ***************************************specific handle Ops ******************************** *)

  (* handleDiv : rewrite the numerator, then the denominator (threading params/args/fieldset
   * from the first into the second), and rebuild the division.
   *)
    fun handleDiv (e1, e2, x) = let
          val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          val (e1', params1', args1', code1', fieldset) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
          val (e2', params2', args2', code2', fieldset) =
                rewriteOp("div-denom", e2, params1', index, [], args1', fieldset, flag)
          val einapp = cleanOrig(E.Op2(E.Div, e1', e2'), params2', args2', x)
          in
            (einapp, code1' @ code2', fieldset)
          end

  (* handleSumProd : rewrite the factors of a summed product; the sum ids sx are passed
   * down so that cleanIndex can account for them in lifted factors.
   *)
    fun handleSumProd (e1, sx, x) = let
          val (e1', params', args', code, fieldset) = multOp("sumprod", sx, e1, x)
          val einapp = cleanOrig(E.Sum(sx, E.Opn(E.Prod, e1')), params', args', x)
          in
            (einapp, code, fieldset)
          end

  (* *************************************** Split ******************************** *)

  (* split ((y, einapp), fieldset, flag)
   * Splits the EIN application bound to y into a simpler application plus a list of newly
   * lifted bindings; returns (((y, einapp'), newBindings), fieldset).
   * Constants, tensors, deltas/epsilons, probes of convolutions, and summations directly
   * around a probe of a convolution or around a tensor are left unchanged.
   * A right-hand side that is not an EINAPP is returned unchanged with no new bindings.
   *)
    fun split ((y, einapp as DstIL.EINAPP(Ein.EIN{body, ...}, _)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val default = ((y, einapp), [], fieldset)
          val str = "Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)
          val _ = testp["\n\nStarting split", P.printbody body]
          fun rewrite b = (case b
                 of E.B _ => default
                  | E.Tensor _ => default
                  | E.G _ => default
                  | E.Field _ => raise Fail "should have been swept"
                  | E.Lift _ => raise Fail "should have been swept"
                  | E.Conv _ => raise Fail "should have been swept"
                  | E.Partial _ => err(" Partial used after normalize")
                  | E.Apply _ => raise Fail "should have been swept"
                  | E.Probe(E.Conv _, _) => default
                  | E.Probe _ => raise Fail str
                  | E.Value _ => err("Value used before expand")
                  | E.Img _ => err("Img used before expand")
                  | E.Krn _ => err("Krn used before expand")
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Tensor _) => default
                  (* | E.Sum(_, E.Opn(E.Prod, [E.Eps2 _, E.Probe(E.Conv _, _)])) => default
                     | E.Sum(_, E.Opn(E.Prod, [E.Epsilon _, E.Probe(E.Conv _, _)])) => default *)
                  | E.Sum(sx, E.Opn(E.Prod, e1)) => handleSumProd(e1, sx, x)
                  | E.Sum(sx, E.G(E.Delta d)) => handleSumProd([E.G(E.Delta d)], sx, x)
                  | E.Sum(_, _) => err(" summation not distributed:"^str)
                  | E.Op1(op1, e1) => (case op1
                       of E.Neg => handleUnaryOp("neg", op1, x, e1)
                        | E.Sqrt => handleUnaryOp("sqrt", op1, x, e1)
                        | E.Exp => handleUnaryOp("exp", op1, x, e1)
                        | E.PowInt _ => handleUnaryOp("PowInt", op1, x, e1)
                        | _ => handleUnaryOp("Trig", op1, x, e1)
                      (* end case *))
                  | E.Op2(E.Sub, e1, e2) => handleBinaryOp("subtract", E.Sub, x, [e1, e2])
                  | E.Op2(E.Div, e1, e2) => handleDiv(e1, e2, x)
                  | E.Opn(E.Add, es) => handleMultOp("add", E.Add, x, es)
                  (* scalar * vector * scalar: group the two scalars into their own product
                   * first. The operator must be the E.Prod constructor; an unqualified Prod
                   * would be a variable pattern matching any n-ary operator.
                   *)
                  | E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])]) =>
                      rewrite (E.Opn(E.Prod, [
                          E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id2, [])]),
                          E.Tensor(id1, [i])
                        ]))
                  | E.Opn(E.Prod, es) => handleMultOp("prod", E.Prod, x, es)
                (* end case *))
          val (einapp2, newbies, fieldset) = rewrite body
          in
            ((einapp2, newbies), fieldset)
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

  (* *************************************** main ******************************** *)

  (* limitSplit (einapp2, fields2, splitlimit)
   * Like splitEinApp, but splitting at a recursion level stops once the number of
   * already-split bindings, pending sibling bindings and accumulated code at that level
   * exceeds splitlimit; the remaining pending bindings are then emitted unsplit after the
   * generated code. The count is local to each recursion level.
   * Each call to split starts from an empty einSet, and the returned set is discarded.
   * Returns fields2 @ code @ rest (generated bindings before the rewritten original).
   *)
    fun limitSplit (einapp2, fields2, splitlimit) = let
          val fieldset = einSet.EinSet.empty
          val _ = testp ["\nSPLit with limit", Int.toString splitlimit]
          fun itercode ([], rest, code, count) = let
                val _ = testp ["\n Empty-SplitCount: ", Int.toString count]
                in
                  (rest, code)
                end
            | itercode (e1::newbies, rest, code, count) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], count + 1)
                val _ = testp [toStringBind(e1), "\n\t===>\n", toStringBind(einapp3), "\nand\n",
                          (String.concatWith ",\n\t" (List.map toStringBind (code4 @ rest4)))]
                in
                  if length rest + length newbies + length code > splitlimit
                    then let
                      val _ = testp ["\n SplitCount: ", Int.toString count]
                      in
                        (rest @ [einapp3], code4 @ rest4 @ code @ newbies)
                      end
                    else itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, count + 2)
                end
          val (rest, code) = itercode([einapp2], [], [], 0)
          in
            fields2 @ code @ rest
          end

  (* splitEinApp einapp0
   * Repeatedly splits einapp0 and every binding lifted out of it until nothing more is
   * lifted. Each call to split starts from an empty einSet, and the returned set is
   * discarded. Returns code @ rest: the generated bindings, followed by the rewritten
   * form of einapp0.
   *)
    fun splitEinApp einapp0 = let
          val fieldset = einSet.EinSet.empty
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, count) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], count + 1)
                val _ = testp [toStringBind(e1), "\n\t===>\n", toStringBind(einapp3), "\nand\n",
                          (String.concatWith ",\n\t" (List.map toStringBind (code4 @ rest4)))]
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, count + 2)
                end
          val (rest, code) = itercode([einapp0], [], [], 0)
          in
            code @ rest
          end

  end (* local *)

end (* structure Split *)