SCM Repository
View of /branches/charisee/src/compiler/tree-il/low-to-tree-fn.sml
Parent Directory
|
Revision Log
Revision 3168 -
(download)
(annotate)
Sun Mar 29 20:46:53 2015 UTC (6 years ago) by jhr
File size: 48347 byte(s)
Sun Mar 29 20:46:53 2015 UTC (6 years ago) by jhr
File size: 48347 byte(s)
formatting changes
(* low-to-tree-fn.sml
 *
 * COPYRIGHT (c) 2011 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *
 * This module translates the LowIL representation of a program (i.e., a pure CFG) to
 * a block-structured AST with nested expressions.
 *
 * NOTE: this translation is pretty dumb about variable coalescing (i.e., it doesn't do any).
 *
 * In addition to building the TreeIL program, the translation threads two sets through
 * an environment (the TreeIL types and operators that the generated code uses).  These
 * sets are returned in the TreeIL program as its "types" and "operations" lists.
 *)

functor LowToTreeFn (Target : sig
    val supportsPrinting : unit -> bool (* does the target support the Print op? *)
  (* tests for whether various expression forms can appear inline *)
    val inlineCons : int -> bool        (* can n'th-order tensor construction appear inline *)
    val inlineMatrixExp : bool          (* can matrix-valued expressions appear inline? *)
    val isHwVec : int -> bool
    val isVecTy : int -> bool
    val getPieces : int -> int list
  (* getVecTy n returns (isFill, nSize, pieces) describing how a length-n vector is
   * split into target vector pieces; consumed by LowOpToTreeOp below.
   *)
    val getVecTy : int -> bool * int * int list
  end) : sig

    val translate : LowIL.program -> TreeIL.program

  end = struct

    structure Src = LowIL
    structure SrcOp = LowOps
    structure SrcV = LowIL.Var
    structure SrcSV = LowIL.StateVar
    structure VA = VarAnalysis
    structure InP = Inputs
    structure Ty = LowILTypes
    structure Nd = LowIL.Node
    structure CFG = LowIL.CFG
    structure LowOpToTreeOp = LowOpToTreeOp
    structure Dst = TreeIL
    structure DstOp = TreeOps
    structure DstSV = Dst.StateVar
    structure SrcTy = LowILTypes
    structure DstTy = TreeILTypes
    structure DstV = Dst.Var
    structure TreeToOpr = TreeToOpr
    structure Fnc = TreeFunc
    structure TySet = Fnc.TySet
    structure OprSet = Fnc.OprSet
    structure LowToS = LowToString

  (* create new tree IL variables; every generated name gets a unique numeric suffix *)
    local
      val newVar = Dst.Var.new
      val cnt = ref 0
      fun genName prefix = let
            val n = !cnt
            in
              cnt := n+1;
              String.concat[prefix, "_", Int.toString n]
            end
    in
  (* debugging switch: testp prints the concatenation of its argument only when
   * testing is 1; it always returns 1.  Note that the argument list is still
   * evaluated (and so its strings built) when testing is 0.
   *)
    val testing=0
    fun testp str=(case testing
           of 1=> (print(String.concat str);1)
            | _ =>1
          (*end case*))
    fun iTos n =Int.toString n
    fun newGlobal x = newVar (genName("G_" ^ SrcV.name x), Dst.VK_Global, SrcV.ty x)
    fun newParam x = newVar (genName("p_" ^ SrcV.name x), Dst.VK_Local, SrcV.ty x)
    fun newLocal x = newVar (genName("l_" ^ SrcV.name x), Dst.VK_Local, SrcV.ty x)
    fun newIter x = newVar (genName("i_" ^ SrcV.name x), Dst.VK_Local, SrcV.ty x)
    fun newTmp (x,n) = newVar (genName("l_" ^ iTos n^SrcV.name x), Dst.VK_Local, SrcV.ty x)
  (* a local whose type is given explicitly (used for image address pointers) *)
    fun newLocalPtrTy (name,ty)= newVar(genName("l_rp_"^name), Dst.VK_Local,ty)
  (* a local/global of type TensorTy[n] *)
    fun newLocalWithTy (name,n)= newVar(genName("l_"^iTos n^name), Dst.VK_Local,Ty.TensorTy [n])
    fun newGlobalWithTy (name,n)= newVar(genName("G_"^iTos n^name), Dst.VK_Global,Ty.TensorTy [n])
    end

  (* associate Tree IL state variables with Low IL variables using properties *)
    local
      fun mkStateVar x = Dst.SV{
              name = SrcSV.name x,
              id = Stamp.new(),
              ty = SrcSV.ty x,
              varying = VA.isVarying x,
              output = SrcSV.isOutput x
            }
    in
    val {getFn = getStateVar, ...} = SrcSV.newProp mkStateVar
    end

    fun mkBlock stms = Dst.Block{locals=[], body=stms}
    fun mkIf (x, stms, []) = Dst.S_IfThen(x, mkBlock stms)
      | mkIf (x, stms1, stms2) = Dst.S_IfThenElse(x, mkBlock stms1, mkBlock stms2)

  (* an environment that tracks bindings of variables to target expressions and the list
   * of locals that have been defined.  It also carries the sets of TreeIL types and
   * operators used so far.
   *)
    local
      structure VT = SrcV.Tbl
      fun decCount ( Src.V{useCnt, ...}) = let
            val n = !useCnt - 1
            in
              useCnt := n; (n <= 0)
            end
      datatype target_binding
        = GLOB of Dst.var       (* variable is global *)
        | TREE of Dst.exp       (* variable bound to target expression tree *)
        | DEF of Dst.exp        (* either a target variable or constant for a defined variable *)
      structure ListSetOfInts = ListSetFn (struct
          type ord_key = int
          val compare = Int.compare
        end)
    (* the environment also carries the sets of types (types) and operators (functs) *)
      datatype env = E of {
          tbl : target_binding VT.hash_table,
        (* FIXME: perhaps the types and functs should be set refs, since they are global? *)
          types : TySet.set,
          functs : OprSet.set,
          locals : Dst.var list
        }
    in

  (* project out the (types, operators) pair of an environment *)
    fun peelEnv (E{tbl, types, functs, locals}) = (types,functs)
  (* project out the list of locals (most recent first) *)
    fun peelEnvLoc (E{tbl, types, functs, locals}) = locals
  (* replace the type and operator sets, keeping the table and locals *)
    fun setEnv (E{tbl, types,functs,locals}, types1, functs1) =
          E{tbl=tbl, types=types1, functs= functs1 ,locals=locals}

  (* addOprFromExp : env * TreeIL.exp -> env
   * add the types and operators used by exp to the environment's sets
   *)
    fun addOprFromExp(env, exp)=let
          val t1 = peelEnv env
          val (ty2, opr2) = TreeToOpr.expToOpr (t1,exp)
          in
            setEnv(env, ty2,opr2)
          end

  (* addOprFromStmt : env * TreeIL.stm list -> env
   * add the types and operators used by the statements to the environment's sets
   *)
    fun addOprFromStmt (env, stms)=let
          val t1=peelEnv(env)
          val (ty2,opr2)= TreeToOpr.stmtsToOpr ( t1 ,stms)
          in
            setEnv(env, ty2,opr2)
          end

(* DEBUG *)
    fun bindToString binding = (case binding
           of GLOB y => "GLOB " ^ Dst.Var.name y
            | TREE e => "TREE"
            | DEF(Dst.E_Var y) => "DEFVar " ^ Dst.Var.name y
            | DEF e => "DEF"^Dst.toString e
          (* end case *))
  (* the table dump is currently disabled; this only prints an empty string *)
    fun dumpEnv (E{tbl, ...}) = let
          fun prEntry (x, binding) =
                testp[" ", Src.Var.toString x, " --> ", bindToString binding, "\n"]
          in
            (* print "\n *** dump environment\n"; VT.appi prEntry tbl; print "***\n"*)
            print ""
          end
(* DEBUG *)

    fun newEnv () = E{
            tbl = VT.mkTable (512, Fail "tbl"),
            types=TySet.empty,
            functs=OprSet.empty,
            locals=[]
          }

  (* describe the binding of x (debugging) *)
    fun peek (env as E{tbl, ...}) x = (case (VT.find tbl x)
           of NONE=>"none"
            | SOME e=> bindToString e
          (*end case *))

  (* use a variable.  If it is a pending expression, we remove it from the table
   * (pending expressions are only created for variables with a single use).
   * Raises Fail if the variable has no binding.
   *)
    fun useVar (env as E{tbl, ...}) x = (case VT.find tbl x
           of SOME(GLOB x') => Dst.E_Var x'
            | SOME(TREE e) => (
                ignore(VT.remove tbl x);
                e)
            | SOME(DEF e) => (
              (* if this is the last use of x, then remove it from the table *)
              (*if (decCount x) then ignore(VT.remove tbl x) else ();*)
                e)
            | NONE => (
                dumpEnv env;
                raise Fail(concat ["useVar(", SrcV.toString x, ")"]))
          (* end case *))

  (* record a local variable (addLocals records several) *)
    fun addLocal (E{tbl, types,functs,locals}, x) =
          E{tbl=tbl,types=types, functs=functs,locals=x::locals}
    fun addLocals (E{tbl, types,functs,locals}, x) =
          E{tbl=tbl,types=types, functs=functs,locals=x@locals}

    fun global (E{tbl, ...}, x, x') = (VT.insert tbl (x, GLOB x'))

  (* insert a pending expression into the table.  Note that x should only be used once! *)
    fun insert (env as E{tbl, ...}, x, exp) = (
          VT.insert tbl (x, TREE exp);
          env)

    fun rename (env as E{tbl, ...}, x, x') = (
          VT.insert tbl (x, DEF(Dst.E_Var x'));
          env)
    fun renameGlob (env as E{tbl, ...}, x, x') = (
          VT.insert tbl (x, GLOB( x'));
          env)
    fun renameExp (env as E{tbl, ...}, x, x') = (
          VT.insert tbl (x, DEF( x'));
          env)

  (* return SOME x' if x is bound to the global x', otherwise NONE *)
    fun peekGlobal (E{tbl, ...}, x) = (case VT.find tbl x
           of SOME(GLOB x') => SOME x'
            | SOME e => NONE
            | NONE => NONE
          (* end case *))

  (* bindLocal : env * SrcV.var * Dst.exp -> env * Dst.stm list
   *   - use count 0: nothing is emitted
   *   - use count 1: rhs becomes a pending expression (its ops are recorded now)
   *   - rhs is a Mux: one fresh local per piece is assigned, and lhs is bound to a
   *     Mux over those locals
   *   - otherwise: a fresh local is assigned rhs and lhs is renamed to it
   *)
    fun bindLocal (env, lhs, rhs) =let
          val n=SrcV.useCount lhs
          val _=testp ["\n In BindLocal: \n \t LHS: ",SrcV.name lhs, " Count \t",Int.toString n," rhs:", Dst.toString rhs ,"\n"]
          in
            (case (n,rhs)
             of (0,_) => (env,[])
              | (1,_) => (insert(addOprFromExp(env,rhs), lhs, rhs), [])
              | (_,Dst.E_Mux(A,isFill, nOrig,Tys as Ty.vectorLength tys,exps))=> let
                  val name=SrcV.name lhs
                  val vs=List.map (fn n=> newLocalWithTy(name,n) ) tys
                  val rhs=Dst.E_Mux(A, isFill,nOrig,Tys,List.map (fn v=>Dst.E_Var v) vs)
                  val stmts=ListPair.map (fn(x,e)=>Dst.S_Assign([x],e)) (vs,exps)
                  in
                    (renameExp(addLocals(env,vs),lhs,rhs),stmts)
                  end
              | (_,_)=> let
                  val t = newLocal lhs
                  in
                    (rename(addLocal(env, t), lhs, t), [Dst.S_Assign([t], rhs)])
                  end
            (*end case*))
          end

  (* bind lhs to rhs: a direct assignment when lhs is a global, else bindLocal *)
    fun bind (env, lhs, rhs) =(case peekGlobal (env, lhs)
           of SOME x =>((env, [Dst.S_Assign([x], rhs)]))
            | NONE => (bindLocal (env, lhs, rhs))
          (* end case *))

  (* set the definition of a variable, where the RHS is either a literal constant or a variable *)
    fun bindSimple (env as E{tbl, ...}, lhs, rhs) =(case peekGlobal (env, lhs)
           of SOME x => (env, [Dst.S_Assign([x], rhs)])
            | NONE => (VT.insert tbl (lhs, DEF rhs); (env, []))
          (* end case *))

  (* at the end of a block, we need to assign any pending expressions to locals.  The
   * blkStms list and the resulting statement list are in reverse order.
   *)
    fun flushPending (E{tbl,types, functs,locals}, blkStms) = let
          fun doVar (x, TREE e, (locals, stms)) = let
                val t = newLocal x
                in
                  VT.insert tbl (x, DEF(Dst.E_Var t));
                  (t::locals, Dst.S_Assign([t], e)::stms)
                end
            | doVar (_, _, acc) = acc
          val (locals, stms) = VT.foldi doVar (locals, blkStms) tbl
          in
            (E{tbl=tbl, types=types,functs=functs,locals=locals}, stms)
          end

  (* translate one phi node: each predecessor block gets an assignment of its
   * argument to a fresh local t, and lhs is renamed to t.
   *)
    fun doPhi ((lhs, rhs), (env, predBlks : Dst.stm list list)) = let
        (* t will be the variable in the continuation of the JOIN *)
          val t = newLocal lhs
          val predBlks = ListPair.map
                (fn (x, stms) => Dst.S_Assign([t], useVar env x)::stms)
                  (rhs, predBlks)
          in
            (rename (addLocal(env, t), lhs, t), predBlks)
          end

(* previous version, without type/operator sets:
    fun endScope (E{locals, ...}, stms) = Dst.Block{
            locals = List.rev locals,
            body = stms
          }
*)
  (* close a scope: record the types/operators used by stms and package the locals
   * (in definition order) and the sets with the body.
   *)
    fun endScope (env, stms) = let
          val env'=addOprFromStmt(env, stms)
          val (types,opr)=peelEnv(env')
          in
            Dst.BlockWithOpr{
                locals= List.rev(peelEnvLoc env),
                types= types,
                opr=opr,
                body = stms
              }
          end
    end

  (* Certain IL operators cannot be compiled to inline expressions.  Return
   * false for those and true for all others.
   *)
    fun isInlineOp rator = let
          fun chkTensorTy (Ty.TensorTy[]) = true
            | chkTensorTy (Ty.TensorTy[_]) = true
            | chkTensorTy (Ty.TensorTy[_, _]) = Target.inlineMatrixExp
            | chkTensorTy _ = false
          in
            case rator
             of SrcOp.LoadVoxels(_, 1) => true
              | SrcOp.LoadVoxels _ => false
              | SrcOp.EigenVecs2x2 => false
              | SrcOp.EigenVecs3x3 => false
              | SrcOp.EigenVals2x2 => false
              | SrcOp.EigenVals3x3 => false
(*            | SrcOp.Zero _ => Target.inlineMatrixExp*)
              | _ => true (* when true, the caller uses bind *)
            (* end case *)
          end

  (* is a CONS inline?
   * HERE- since we are using arrays, nothing can be inline.  Fix later if it needs
   * to be fixed.  The previous implementation is kept below for reference.
   *)
    fun isInlineCons ty = (*(case ty
           of Ty.SeqTy(Ty.IntTy, _) => true
            | Ty.TensorTy dd => Target.inlineCons(List.length dd)
            | Ty.SeqTy _ => false (*CCCC-? DO we have this type*)
(*          | Ty.DynSeqTy ty => false*)
            | _ => raise Fail(concat["invalid CONS<", Ty.toString ty, ">"])
          (* end case *))*)
          false

  (* translate input-variable initialization to a TreeIL expression; the first
   * component is a list of setup statements (currently always empty).
   *)
    fun trInitialization (InP.String s) = ([], Dst.E_Lit(Literal.String s))
      | trInitialization (InP.Int n) = ([], Dst.E_Lit(Literal.Int n))
      | trInitialization (InP.Real f) = ([], Dst.E_Lit(Literal.Float f))
      | trInitialization (InP.Bool b) = ([], Dst.E_Lit(Literal.Bool b))
      | trInitialization (InP.Tensor([d], vs)) = raise Fail "trInitialization: Tensor"
(*
      | trInitialization (InP.Tensor(shp, vs)) = let
          fun mk i = Dst.E_Lit(Literal.Float(Vector.sub(vs, i)))
          fun mkCons (i, [d]) =
                (Dst.E_Cons(Ty.TensorTy[d], List.tabulate(d, fn j => mk(i+j))), i+d)
            | mkCons (i, d::dd) = let
                fun f (i, j, args) = if (j < d)
                      then let
                        val (arg, i) = mkCons(i, dd)
                        in
                          f (i, j+1, arg::args)
                        end
                      else (List.rev args, i)
                val (args, i) = f (i, 0, [])
                val cons = Dst.E_Cons(Ty.TensorTy(d::dd), args)
                in
                  if Target.inlineCons(List.length dd + 1)
                    then (cons, i)
                    else raise Fail "non-inline initialization not supported yet"
                end
          val (exp, _) = mkCons(0, shp)
          in
            ([], exp)
          end
*)
      | trInitialization (InP.Tensor _) = raise Fail "trInitialization: Tensor"
      | trInitialization (InP.Seq vs) = raise Fail "trInitialization: Seq"
      | trInitialization _ = raise Fail "trInitialization: impossible"

  (* translate a LowIL assignment to a list of zero or more target statements in reverse
   * order (i.e., the statement that executes last comes first in the list).
   *)
    fun doAssign (env, (lhs, rhs)) = let
        (* the target for lhs: its global, or a fresh local that lhs is renamed to *)
          fun doLHS () = (case peekGlobal(env, lhs)
                 of SOME lhs' => (env, lhs')
                  | NONE => let
                      val t = newLocal lhs
                      in
                        (rename (addLocal(env, t), lhs, t), t)
                      end
                (* end case *))
        (* for expressions that are going to be compiled to a call statement
         * (operations that return matrices may not be supported inline).
         * NOTE: the env argument is not used; doLHS works on the enclosing env.
         *)
          fun assignExp (env, exp) = let
                val (env, t) = doLHS()
                in
                  (env, [Dst.S_Assign([t], exp)])
                end
        (* force an argument to be stored in something that will be mapped to an l-value *)
          fun bindVar (env, x) = (case useVar env x
                 of x' as Dst.E_State _ =>(env, x', [])
                  | x' as Dst.E_Var _ => (env, x', [])
                  | e => let
                      val x' = newLocal x
                      in
                        (addLocal(env, x'), Dst.E_Var x', [Dst.S_Assign([x'], e)])
                      end
                (* end case *))
          val _= LowToS.ASSIGNtoString(lhs,rhs)
        (* opToString : LowIL.ops * Dst.exp list -> int
         * just used to print information about the op
         *)
          fun opToString (rator,arg)= let
                val r=SrcOp.toString rator
                val a= String.concatWith " , " (List.map (fn e=> Dst.toString e) arg)
                in
                  testp[ "\n ***** New Op**** \n ",r,"\n Args(",a,")"]
                end
          in
            case rhs
             of Src.STATE x =>let
                (* Hmm, what to do with State nodes.
                 * They get represented like globals but their other operations register
                 * their kind as local, which leads to trouble in their representation.
                 * Fix me, once we change representation of state and local vars.
                 * Currently we load 1 piece when it is a local var.
                 *   fun iter([],_)=[]
                 *     | iter(e1::es,counter)=[Dst.E_LoadArr(isFill ,e1,oSize, t ,
                 *            Dst.E_Lit(Literal.Int counter))]@ iter(es,counter+ IntInf.fromInt e1)
                 *   val ops=iter(Pieces,0)
                 *)
                  val (env, vt) = doLHS()
                  val t=Dst.E_State(getStateVar x)
                  val exp=(case (DstV.kind vt, DstV.ty vt)
                         of (Dst.VK_Local,DstTy.TensorTy [oSize])=>let
                              val (isFill,nSize,Pieces)=Target.getVecTy oSize
                              val op1= Dst.E_LoadArr(false ,nSize,oSize, t , Dst.E_Lit(Literal.Int 0))
                              val splitTy=DstTy.vectorLength Pieces
                              in
                                Dst.E_Mux(false,isFill,oSize,splitTy,[op1])
                              end
                          | _ => t
                        (*end case *))
                  in
                    bindSimple (env, lhs,exp)
                  end
              | Src.VAR x => bindSimple (env, lhs, useVar env x)
              | Src.LIT lit => bindSimple (env, lhs, Dst.E_Lit lit)
              | Src.OP(SrcOp.Kernel _, _) => (env, [])
              | Src.OP(SrcOp.LoadImage(ty, nrrd, info), []) => let
                  val (env, t) = doLHS()
                  in
                    (env, [Dst.S_LoadNrrd(t, ty, nrrd)])
                  end
              | Src.OP(SrcOp.Input(InP.INP{ty=Ty.ImageTy _, name, desc, init}), []) => let
                  val (env, t) = doLHS()
                  in
                    case init
                     of SOME(InP.Proxy(nrrd, _)) => (env, [Dst.S_InputNrrd(t, name, desc, SOME nrrd)])
                      | SOME(InP.Image _) => (env, [Dst.S_InputNrrd(t, name, desc, NONE)])
                      | _ => raise Fail "bogus initialization for image"
                    (* end case *)
                  end
              | Src.OP(SrcOp.Input(InP.INP{ty, name, desc, init=NONE}), []) => let
                  val (env, t) = doLHS()
                  in
                    (env, [Dst.S_Input(t, name, desc, NONE)])
                  end
              | Src.OP(SrcOp.Input(InP.INP{ty, name, desc, init=SOME init}), []) => let
                  val (env, t) = doLHS()
                  val (stms, exp) = trInitialization init
                  in
                    (env, stms@[Dst.S_Input(t, name, desc, SOME exp)])
                  end
              | Src.OP(SrcOp.Inside(info, s), args) => let
                  val [a,b]=List.map (useVar env) args
                  val size=(case (ImageInfo.dim info)
                         of 1 => raise Fail"Inside of 1-D dimension"
                          | 2 => 2
                          | 3 => 4
                        (*end case*))
                (* separate the position into its own local to make it look cleaner *)
                  val k=newLocalWithTy ("Pos_"^SrcV.name lhs ,size)
                  val s1= Dst.S_Assign([k],a)
                  val rhs=Dst.E_Op(DstOp.Inside(info,s),[Dst.E_Var k,b])
                  val (env,bindStms)=bind(addLocal(env,k),lhs,rhs)
                  in
                  (* the result is in reverse order, so the assignment to k must come
                   * last in the list; it then executes before any statement produced
                   * by bind, which reads k.
                   *)
                    (env, bindStms @ [s1])
                  end
              | Src.OP(SrcOp.Translate v, [a])=> let
                (* result is a vector so we have to use Mux *)
                  val dim = ImageInfo.dim v
                  val splitTy=DstTy.vectorLength [dim]
                  val op1= Dst.E_Op(DstOp.Translate v,[(useVar env a)])
                  val exp= (case dim
                         of 1=> op1
                          | 2 => Dst.E_Mux(true,false,dim,splitTy,[op1])
                          | 3=> Dst.E_Mux(false,false,dim,splitTy,[op1])
                        (*end case*))
                  in
                    bind(env,lhs,exp)
                  end
              | Src.OP(SrcOp.Transform v,args) => let
                (* result is an array so we have to use Store *)
                  val (env2, t) = doLHS()
                  val V=Dst.E_Var t
                  val dim = ImageInfo.dim v
                  val ty=DstTy.TensorTy [dim,dim]
                  val args'=List.map (useVar env) args
                  val a=List.tabulate(dim,(fn n=> Dst.E_Op(DstOp.Transform(v,n),args')))
                (* NOTE(review): the 1-D case uses row index 1 while the tabulate above is
                 * 0-based -- confirm what DstOp.Transform expects here.
                 *)
                  val (env2,stmt)= (case dim
                         of 1=> bind(env2,lhs,Dst.E_Op(DstOp.Transform (v,1),args'))
                          | 2 =>(env2,[Dst.S_StoreVec(V,0,true,false,dim,ty,DstTy.vectorLength [2,2],a)])
                          | 3 =>(env2,[Dst.S_StoreVec(V,0,false,true,dim,ty,DstTy.vectorLength [4,4,4] ,a)])
                        (*end case*))
                  in
                    (env2,stmt)
                  end
              | Src.OP(SrcOp.imgLoad(info,dim,oSize),[a])=>let
                (* create ptr variable vp and index it with stride *)
                  val vp= newLocalPtrTy(SrcV.name lhs,DstTy.AddrTy info)
                  val stmt=Dst.S_Assign([vp], useVar env a)
                  val stride = ImageInfo.stride info
                  val IndexArgs= List.tabulate(oSize, fn n=>
                        Dst.E_Op(DstOp.IndexTensor(false,Ty.indexTy [n*stride],Ty.TensorTy [oSize]),[Dst.E_Var vp]))
                (* create cons expressions *)
                  val (isFill,nSize,Pieces)=Target.getVecTy oSize
                  val exp=LowOpToTreeOp.consVecToTree(nSize,oSize,Pieces,IndexArgs,isFill)
                (* increase use count so that lhs gets its own local (easier to read c-file) *)
                  fun incUse (Src.V{useCnt, ...}) = (useCnt := !useCnt + 1)
                  val _ = incUse lhs
                  val (env2,stmt2)=bind (addLocal(env, vp), lhs,exp)
                  in
                  (* reverse order: the vp assignment is last so that it executes first *)
                    (env2, List.rev stmt2@[stmt])
                  end
              | Src.OP(SrcOp.IndexTensor e,[a])=> let
                (* IndexTensor operation is int*ty*ty.
                 * The first ty is the list of indexed positions and the second ty is the
                 * type of the argument.  When the argument is a Mux (matrix or larger arg)
                 * we look for the right piece to index.  Otherwise we just pass the
                 * variable kind to the tree-il op and let c-util decide how/where to index.
                 * The kind of variable decides if there is a cast in the c-code.
                 *)
                  val a'=(useVar env a)
                  val exp=(case ((SrcOp.IndexTensor e),a')
                         of (SrcOp.IndexTensor(_ ,Ty.indexTy [i],_),Dst.E_Mux(_,_,_,DstTy.vectorLength pieces,ops))=> let
                            (* walk the pieces until the one containing position i *)
                              fun findLocal(c,i,indexAt,v::vs,a1::args)=let
                                    val newsize=c+v
                                    in
                                      if(newsize>i)
                                        then Dst.E_Op(DstOp.IndexTensor(true,Ty.indexTy [indexAt],Ty.TensorTy [v]), [a1])
                                        else findLocal(newsize,i,indexAt-v,vs,args)
                                    end
                                | findLocal _ = raise Fail "IndexTensor: index beyond the pieces of the Mux"
                              val exp =findLocal(0,i,i,pieces,ops)
                              in
                                exp
                              end
                          | (SrcOp.IndexTensor( _ , indTy,argTy),Dst.E_Var v) => (case (DstV.kind v)
                                 of TreeIL.VK_Local=>Dst.E_Op(DstOp.IndexTensor(true ,indTy,argTy),[a'])
                                  | _ =>Dst.E_Op(DstOp.IndexTensor(false ,indTy,argTy),[a'])
                                (*end case*))
                          | (SrcOp.IndexTensor( _ , indTy,argTy),_) =>
                              Dst.E_Op(DstOp.IndexTensor(true ,indTy,argTy),[a'])
                        (*end case*))
                  in
                    bind (env, lhs, exp)
                  end
              | Src.OP(rator,args) => let
                  val args'=List.map (useVar env) args
                  val _ =testp[ "\n ***** New Op \n \t\t",SrcV.name lhs,"-",SrcOp.toString rator,Int.toString(length(args)) , " Args(\n\t",String.concatWith"\n\t\t," (List.map (fn e=> Dst.toString e) args'),")"]
                (* foundVec : SrcOp.rator * DstOp.rator * int * Dst.exp list * Dst.exp list
                 * Found a vector operation: rewrite it to correctly-sized vector operations.
                 * argsS are the scalar arguments and argsV the vector arguments.
                 *)
                  fun foundVec(origrator,rator,oSize,argsS,argsV)= let
                        val (isFill,nSize,Pieces)=Target.getVecTy oSize
                        val (env, t) = doLHS()
                        val stmt = LowOpToTreeOp.vecToTree(t,origrator,rator,nSize,oSize,Pieces,argsS,argsV,isFill)
                        val (envv,stmts)=(case stmt
                               of Dst.S_Assign(_,exp)=> bind (env, lhs, exp)
                                | stmt=> (env,[stmt])
                              (*end case*))
                        val _ = testp["\n \n\t",Dst.toStringS stmt]
                        in
                          (envv,stmts)
                        end
                  in
                    case (rator,args')
                     of (SrcOp.addVec n,_) => foundVec(rator,DstOp.addVec,n,[],args')
                      | (SrcOp.subVec n,_) => foundVec(rator,DstOp.subVec,n,[],args')
                      | (SrcOp.prodScaV n,e1::es) => foundVec(rator,DstOp.prodScaV ,n, [e1], es)
                      | (SrcOp.prodVec n,_) => foundVec(rator,DstOp.prodVec,n,[],args')
                      | (SrcOp.sumVec n ,_) => foundVec(rator,DstOp.addVec ,n,[],args')
                      | (SrcOp.Floor n ,_) => foundVec(rator,DstOp.Floor ,n,[],args')
                      | (SrcOp.ProjectTensor(_,n,_,_),_) => foundVec(rator,DstOp.addVec ,n,[],args')
                      | (SrcOp.Clamp (Ty.TensorTy[n]) ,_) => foundVec(rator,DstOp.clampVec ,n,[],args')
                      | (SrcOp.Lerp (Ty.TensorTy[n]) ,[a,b,c]) => foundVec(rator,DstOp.lerpVec ,n,[c],[a,b])
                      | (SrcOp.Normalize n,_) => foundVec(rator,DstOp.Normalize ,n,[],args')
                    (* x+0, 0+x and x-0 are just copies of x *)
                      | (SrcOp.addSca ,[a,Dst.E_Lit (Literal.Int 0)]) => assignExp (env,a)
                      | (SrcOp.addSca ,[Dst.E_Lit (Literal.Int 0),a]) => assignExp (env,a)
                      | (SrcOp.subSca ,[a,Dst.E_Lit (Literal.Int 0)]) => assignExp (env,a)
                      | _ => let
                          val Trator = LowOpToTreeOp.expandOp rator
                          val exp = Dst.E_Op(Trator, args')
                          in
                            if isInlineOp rator
                              then (bind (env, lhs, exp))
                              else (assignExp (env, exp))
                          end
                    (*end case *)
                  end
              | Src.APPLY(f, args) =>
                  bind (env, lhs, Dst.E_Apply(f, List.map (useVar env) args))
              | Src.CONS(ty as Ty.TensorTy[oSize], args) => let
                (* CONS of a vector with real arguments.
                 * If lhs is a local var then we assume lhs will be represented with vectors
                 * and we use Mux and E_ConsVec on pieces, much like a vector op.
                 * Otherwise, assume it's an array and use S_Cons.
                 *)
                  val args' = List.map (useVar env) args
                  val (env2, t) = doLHS()
                  in
                    case DstV.kind t
                     of TreeIL.VK_Local=> let
                          val (isFill,nSize,Pieces)=Target.getVecTy oSize
                          val exp = LowOpToTreeOp.consVecToTree(nSize,oSize,Pieces,args',isFill)
                          val _ =testp["\nExp\n",Dst.toString exp]
                          in
                            bind (env2, lhs, exp)
                          end
                      | _ => (env2, [Dst.S_Cons(t, oSize, args')])
                    (*end case*)
                  end
(*            | Src.CONS(ty as Ty.TensorTy [_,2], args)=>let
                  val args'=List.map (useVar env) args
                  val _=case args'
                  val _=(case )
                  val (env2, t) = doLHS()
                  in
                    (env2,[Dst.S_Cons(t,4,args')])
                  end*)
              | Src.CONS(ty as Ty.TensorTy [_,j], args) =>let
                (* CONS is a matrix with vector arguments.
                 * Each vector arg could be a global, local or state variable,
                 * which means their representation could be different:
                 * when it is a global or state then we copy it with S_Copy,
                 * when it is a local or other then we use S_StoreVec.
                 *)
                  val args' = List.map (useVar env) args
                  val _ =testp["******************************\n CONS_Matrix \n With Args", Dst.toStrings args']
                  val (env2, t) = doLHS()
                (* vector params for last matrix index; retrieved in case we use S_StoreVec *)
                  val (isFill,nSize,Pieces)=Target.getVecTy j
                  val splitTy=LowILTypes.vectorLength Pieces
                  val n=length Pieces
                  val A =LowOpToTreeOp.isAlignedStore(isFill,n)
                (* one statement per row; count is the offset of the row in t *)
                  fun f ([], _ ) = []
                    | f (e1::es,count)=let
                        val stm=(case e1
                               of Dst.E_State v=> Dst.S_Copy(Dst.E_Var t, e1, count,j)
                                | Dst.E_Var v=>(case (DstV.kind v)
                                       of TreeIL.VK_Global => Dst.S_Copy(Dst.E_Var t, e1, count,j)
                                        | _ => Dst.S_StoreVec(Dst.E_Var t,count,A,isFill,j,ty,splitTy, [e1])
                                      (*end case*))
                                | _ => Dst.S_StoreVec(Dst.E_Var t,count,A,isFill,j,ty,splitTy, [e1])
                              (*end case*))
                        in
                          [stm]@f(es,count+j)
                        end
                  val stmts=f (args',0)
                  val _ =testp["\n returning statements \n",Dst.toStringSs stmts,"\n end ******************************\n"]
                  in
                    (env2, List.rev stmts)
                  end
              | Src.CONS(ty as Ty.TensorTy [_,i,j], args) =>let
                (* CONS is a larger tensor with non-vector arguments.
                 * Hooray! We can assume everything is an array and S_Copy everything.
                 *)
                  val args' = List.map (useVar env) args
                  val _ =testp["******************************\n CONS_Matrix \n ", "Number of args",Int.toString (length args),"---\n",Dst.toStrings args']
                  val (env2, t) = doLHS()
                  val shift=j*i (* new row index shift *)
                  fun f ([], _ ) = []
                    | f (e1::es,count)= [Dst.S_Copy(Dst.E_Var t, e1, count,shift)]@ f(es,count+shift)
                  val stmts=f (args',0)
                  val _ =testp["\n returning statements \n"^Dst.toStringSs stmts,"\n end ******************************\n"]
                  in
                    (env2, List.rev stmts)
                  end
              | Src.EINAPP _=> raise Fail "EINAPP in Low-IL to Tree-IL"
            (* end case *)
          end

  (* In order to reconstruct the block-structure from the CFG, we keep a stack of open ifs.
   * The items on this stack distinguish between when we are processing the then and else
   * branches of the if.
   *)
    datatype open_if
    (* working on the "then" branch.  The fields are statements that precede the if, the
     * condition, and the else-branch node.
     *)
      = THEN_BR of Dst.stm list * Dst.exp * Src.node
    (* working on the "else" branch.  The fields are statements that precede the if, the
     * condition, the "then" branch statements, and the node that terminated the "then"
     * branch (will be a JOIN, DIE, or STABILIZE).
     *)
      | ELSE_BR of Dst.stm list * Dst.exp * Dst.stm list * Src.node_kind

  (* drop the type/operator sets of a block *)
    fun mkBlockOrig(Dst.BlockWithOpr{ locals ,types,opr,body})= Dst.Block{locals=locals ,body=body}
  (* move the type/operator sets of a block into env and return the plain block *)
    fun peelBlockOrig(env,Dst.BlockWithOpr{ locals ,types,opr,body})=let
          val env= setEnv(env,types,opr)
          in
            (env,Dst.Block{locals=locals ,body=body})
          end

  (* decrement the use count of a variable; returns true when it drops to zero or below *)
    fun decCount ( Src.V{useCnt, ...}) = let
          val n = !useCnt - 1
          in
            useCnt := n; (0 >= n)
          end

  (* translate a CFG to a TreeIL block (with its type/operator sets).  prefix is placed
   * in front of the body; finish env supplies statements that are appended before the
   * exit of FRAGMENT, SINIT, RETURN and ACTIVE exits.
   *)
    fun trCFG (env, prefix, finish, cfg) = let
          fun join (env, [], _, Src.JOIN _) = raise Fail "JOIN with no open if"
            | join (env, [], stms, _) = let
                val env'=addOprFromStmt(env, stms)
                in
                  endScope (env', prefix @ List.rev stms)
                end
            | join (env, THEN_BR(stms1, cond, elseBr)::stk, thenBlk, k) = let
                val (env, thenBlk) = flushPending (env, thenBlk)
                val env'=addOprFromStmt(env, stms1)
                in
                  doNode (env', ELSE_BR(stms1, cond, thenBlk, k)::stk, [], elseBr)
                end
            | join (env, ELSE_BR(stms, cond, thenBlk, k1)::stk, elseBlk, k2) = let
                val (env, elseBlk) = flushPending (env, elseBlk)
                in
                  case (k1, k2)
                   of (Src.JOIN{phis, succ, ...}, Src.JOIN _) => let
                        val (env, [thenBlk, elseBlk]) =
                              List.foldl doPhi (env, [thenBlk, elseBlk]) (!phis)
                        val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
                        val env'=addOprFromStmt(env, stm::stms)
                        in
                          doNode (env', stk, stm::stms, !succ)
                        end
                    | (Src.JOIN{phis, succ, ...}, _) => let
                        val (env, [thenBlk]) = List.foldl doPhi (env, [thenBlk]) (!phis)
                        val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
                        in
                          doNode (addOprFromStmt(env, [stm]), stk, stm::stms, !succ)
                        end
                    | (_, Src.JOIN{phis, succ, ...}) => let
                        val (env, [elseBlk]) = List.foldl doPhi (env, [elseBlk]) (!phis)
                        val stm = mkIf(cond, List.rev thenBlk, List.rev elseBlk)
                        in
                          doNode (addOprFromStmt(env, [stm]), stk, stm::stms, !succ)
                        end
                    | (_, _) => raise Fail "no path to exit unimplemented" (* FIXME *)
                  (* end case *)
                end
        (* stms is the statement list of the current block, in reverse order *)
          and doNode (env, ifStk : open_if list, stms, nd) = (case Nd.kind nd
                 of Src.NULL => raise Fail "unexpected NULL"
                  | Src.ENTRY{succ} => doNode (env, ifStk, stms, !succ)
                  | k as Src.JOIN{phis, succ, ...} => (join (env, ifStk, stms, k))
                  | Src.COND{cond, trueBranch, falseBranch, ...} => let
                      val cond = useVar env cond
                      val (env, stms) = flushPending (env, stms)
                      in
                        doNode (env, THEN_BR(stms, cond, !falseBranch)::ifStk, [], !trueBranch)
                      end
                  | Src.COM {text, succ, ...} =>
                      doNode (env, ifStk, Dst.S_Comment text :: stms, !succ)
                  | Src.ASSIGN{stm, succ, ...} => let
                      val (env, stms') = doAssign (env, stm)
                      in
                        doNode (addOprFromStmt(env, stms') , ifStk, stms' @ stms, !succ)
                      end
                  | Src.MASSIGN{stm=(ys, rator, xs), succ, ...} => let
                      fun doit () = let
                            fun doLHSVar (y, (env, ys)) = (case peekGlobal(env, y)
                                   of SOME y' => ((env, y'::ys))
                                    | NONE => let
                                        val t = newLocal y
                                        in
                                          (rename (addLocal(env, t), y, t), t::ys)
                                        end
                                  (* end case *))
                            val (env, ys) = List.foldr doLHSVar (env, []) ys
                            val Trator = LowOpToTreeOp.expandOp rator
                            val exp = Dst.E_Op(Trator, List.map (useVar env) xs)
                            val stm = Dst.S_Assign(ys, exp)
                            in
                              doNode (env, ifStk, stm :: stms, !succ)
                            end
                      in
                        case rator
                         of SrcOp.Print _ => if Target.supportsPrinting()
                              then doit ()
                              else doNode (env, ifStk, stms, !succ)
                          | _ => doit()
                        (* end case *)
                      end
                  | Src.NEW{strand, args, succ, ...} => raise Fail "NEW unimplemented"
                  | Src.SAVE{lhs, rhs, succ, ...} => let
                    (* There is a Save and lhs is a state variable.
                     * The statement depends on how the rhs exp is stored:
                     * if rhs is stored as a vector then use S_StoreVec (Mux, or local vector),
                     * if rhs is an array use S_Copy (higher order tensors),
                     * otherwise a regular save.
                     *)
                      val x=getStateVar lhs
                      val rhs2=useVar env rhs
                      val _ =testp["\n *********** \n FOUND SAVE \n\t StateVar: ",Dst.stateVarToString x, ": Rest rhs: ",Dst.toString rhs2,"--end "]
                    (* number of elements of a tensor with the given shape *)
                      fun numElems n=foldl (fn (a,b) => b*a) 1 n
                      val stm=(case rhs2
                             of Dst.E_Mux(A,isFill, oSize,splitTy,args) => (
                                  ignore (decCount rhs);
                                  Dst.S_StoreVec( Dst.E_State x,0, A,isFill, oSize,Ty.TensorTy [oSize],splitTy,args))
                              | Dst.E_Var rhs3 => (case (DstV.kind rhs3,DstV.rTy rhs3)
                                     of ( _ ,Ty.TensorTy []) => Dst.S_Save([x], rhs2)
                                      | (Dst.VK_Local,Ty.TensorTy [oSize]) =>let
                                          val (isFill,nSize,Pieces)=Target.getVecTy oSize
                                          in
                                            Dst.S_StoreVec( Dst.E_State x,0, false,isFill, oSize,Ty.TensorTy [oSize],Ty.vectorLength Pieces,[rhs2])
                                          end
                                      | (_,Ty.TensorTy xs) => Dst.S_Copy( Dst.E_State x, rhs2,0,numElems xs)
                                      | _ => Dst.S_Save([x], rhs2)
                                    (*end case*))
                              | Dst.E_State rhs3 => (case (DstSV.ty rhs3)
                                     of Ty.TensorTy xs => Dst.S_Copy( Dst.E_State x, rhs2,0,numElems xs)
                                      | _ => Dst.S_Save([x], rhs2)
                                    (*end case*))
                              | _ => Dst.S_Save([x], rhs2)
                            (*end case*))
                      val _ = testp [" \nSrc.Save: ",LowToS.SAVEtoString(lhs,rhs),"\n New stmt --", Dst.toStringS stm,"\nend save **************\n"]
                      val stmts=stm::stms
                      in
                        doNode (addOprFromStmt(env, stmts), ifStk, stmts, !succ)
                      end
                  | k as Src.EXIT{kind, live, ...} => (case kind
                       of ExitKind.FRAGMENT =>
                            endScope (env, prefix @ List.revAppend(stms, finish env))
                        | ExitKind.SINIT => let
                          (* FIXME: we should probably call flushPending here! *)
                            val suffix = finish env @ [Dst.S_Exit[]]
                            in
                              endScope (env, prefix @ List.revAppend(stms, suffix))
                            end
                        | ExitKind.RETURN => let
                          (* FIXME: we should probably call flushPending here! *)
                            val suffix = finish env @ [Dst.S_Exit(List.map (useVar env) live)]
                            in
                              endScope (env, prefix @ List.revAppend(stms, suffix))
                            end
                        | ExitKind.ACTIVE => let
                          (* FIXME: we should probably call flushPending here! *)
                            val suffix = finish env @ [Dst.S_Active]
                            in
                              endScope (env, prefix @ List.revAppend(stms, suffix))
                            end
                        | ExitKind.STABILIZE => let
                          (* FIXME: we should probably call flushPending here! *)
                            val stms = Dst.S_Stabilize :: stms
                            in
                              (join (env, ifStk, stms, k))
                            end
                        | ExitKind.DIE => (join (env, ifStk, Dst.S_Die :: stms, k))
                      (* end case *))
                (* end case *))
          in
            doNode (env, [], [], CFG.entry cfg)
          end

  (* translate the initially section.  The environment (and so the type/operator sets)
   * is threaded through the range initialization, the iterators, and the creation code.
   *)
    fun trInitially (env, Src.Initially{isArray, rangeInit, iters, create=(createInit, strand, args)}) = let
          val (env2,iterPrefix) = peelBlockOrig(env,trCFG (env, [], fn _ => [], rangeInit))
          fun cvtIter ((param, lo, hi), (env, iters)) = let
                val param' = newIter param
                val env = rename (env, param, param')
                in
                  (env, (param', useVar env lo, useVar env hi)::iters)
                end
        (* start from env2 so that the types/operators used by rangeInit are kept *)
          val (env, iters) = List.foldr cvtIter (env2, []) iters
          val (env,createPrefix) = peelBlockOrig(env,trCFG (env, [], fn _ => [], createInit))
          in
            (env, {
                isArray = isArray,
                iterPrefix = iterPrefix,
                iters = iters,
                createPrefix = createPrefix,
                strand = strand,
                args = List.map (useVar env) args
              })
          end

    fun trMethod (env, Src.Method{name, body}) = let
          val (env, blk) = peelBlockOrig(env, trCFG (env, [], fn _ => [], body))
          in
            (env, Dst.Method{name = name, body = blk})
          end

  (* translate the strands, threading the environment; strands and their methods are
   * returned in source order.
   *)
    fun trStrands (env, strands) = let
          fun tr (Src.Strand{name, params, state, stateInit, methods}, (env, strands')) = let
                val params' = List.map newParam params
                val env = ListPair.foldlEq (fn (x, x', env) => rename(env, x, x')) env (params, params')
                val (env', sInit) = peelBlockOrig(env,trCFG (env, [], fn _ => [], stateInit))
              (* the accumulator is built in reverse, so restore source order at the end *)
                fun callmethod (env, [], Ms) = (env, List.rev Ms)
                  | callmethod (env, b::es, Ms) = let
                      val (env2,M1) = trMethod(env,b)
                      in
                        callmethod(env2, es, M1::Ms)
                      end
                val (env', methods) = callmethod(env',methods,[])
                val strand' = Dst.Strand{
                        name = name,
                        params = params',
                        state = List.map getStateVar state,
                        stateInit =sInit,
                        methods = methods
                      }
                in
                  (env', strand'::strands')
                end
          val (env', strands') = List.foldl tr (env, []) strands
          in
            (env', List.rev strands')
          end

  (* split the globalInit into the part that specifies the inputs and the rest of
   * the global initialization.
   *)
    fun splitGlobalInit globalInit = let
        (* FIXME: can split as soon as we see a non-Input statement! *)
          fun walk (nd, lastInput, live) = (case Nd.kind nd
                 of Src.ENTRY{succ} => walk (!succ, lastInput, live)
                  | Src.COM{succ, ...} => walk (!succ, lastInput, live)
                  | Src.ASSIGN{stm=(lhs, rhs), succ, ...} => (case rhs
                       of Src.OP(SrcOp.Input _, _) => walk (!succ, nd, lhs::live)
                        | _ => walk (!succ, lastInput, live)
                      (* end case *))
                  | _ => if Nd.isNULL lastInput
                      then let (* no inputs *)
                        val entry = Nd.mkENTRY()
                        val exit = Nd.mkEXIT(ExitKind.RETURN, [])
                        in
                          Nd.addEdge (entry, exit);
                          {inputInit = Src.CFG{entry=entry, exit=exit}, globalInit = globalInit}
                        end
                      else let (* split at lastInput *)
                        val inputExit = Nd.mkEXIT(ExitKind.RETURN, live)
                        val globalEntry = Nd.mkENTRY()
                        val [gFirst] = Nd.succs lastInput
                        in
                          Nd.replaceInEdge {src = lastInput, oldDst = gFirst, dst = inputExit};
                          Nd.replaceOutEdge {oldSrc = lastInput, src = globalEntry, dst = gFirst};
                          {
                            inputInit = Src.CFG{entry = Src.CFG.entry globalInit, exit = inputExit},
                            globalInit = Src.CFG{entry = globalEntry, exit = Src.CFG.exit globalInit}
                          }
                        end
                (* end case *))
          in
            walk (Src.CFG.entry globalInit, Nd.dummy, [])
          end

  (* translate a CFG and move its type/operator sets into env *)
    fun getInfo(env,Init)=let
          val inputInit' = trCFG (env, [], fn _ => [], Init)
          in
            peelBlockOrig(env,inputInit')
          end

    fun translate prog = let
        (* first we do a variable analysis pass on the Low IL *)
          val prog as Src.Program{props, globalInit, initially, strands} = VA.optimize prog
        (* FIXME: here we should do a contraction pass to eliminate unused variables that VA may have created *)
          val _ = (* DEBUG *)
                LowPP.output (Log.logFile(), "LowIL after variable analysis", prog)
          val envOrig = newEnv()
          val globals = List.map
                (fn x => let val x' = newGlobal x in global(envOrig, x, x'); x' end)
                  (Src.CFG.liveAtExit globalInit)
          val {inputInit, globalInit} = splitGlobalInit globalInit
          val (env, inputInit)=getInfo(envOrig,inputInit)
          val (env, globalInit)=getInfo(env, globalInit)
          val (env, strands) = trStrands (env, strands)
          val (env, initially) = trInitially (env, initially)
        (* the accumulated sets become the program's types and operations *)
          val (typs,opr)= peelEnv(env)
          val typsList=TySet.listItems(typs)
          val oprList=OprSet.listItems(opr)
          val _=testp[(Fnc.setListToString(typsList,oprList,"--FinalPostStrands--"))]
          in
            Dst.Program{
                props = props,
                types=typsList,
                operations = oprList,
                globals = globals,
                inputInit = inputInit,
                globalInit = globalInit,
                strands = strands,
                initially = initially
              }
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |