(* util.sml
 *
 * Utility code for Simplification.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure Util : sig

  (* return information about a reduction operator *)
    val reductionInfo : Var.t -> {
            rator : Var.t,                      (* primitive operator *)
            init : Literal.t,                   (* identity element to use for initialization *)
            mvs : SimpleTypes.meta_arg list     (* meta-variable arguments for primitive application *)
          }

  (* convert a block into a function by closing over its free variables *)
    val makeFunction : string * Simple.block * SimpleTypes.ty -> Simple.func * Simple.var list

  end = struct

    structure S = Simple
    structure BV = BasisVars
    structure L = Literal
    structure R = RealLit
    structure VMap = SimpleVar.Map

  (* reductionInfo rator
   *
   * Map one of the BasisVars reduction operators (red_all, red_exists, red_max,
   * red_min, red_product, red_sum) to a record containing:
   *   rator -- the binary primitive used to combine values,
   *   init  -- the literal used to initialize the accumulator; this must be the
   *            identity element of `rator`,
   *   mvs   -- the meta-variable arguments needed to apply `rator`.
   *
   * Raises Fail for red_mean and red_variance, which are not supported yet, and
   * for any variable that is not a reduction operator.
   *)
    fun reductionInfo rator =
          if Var.same(BV.red_all, rator)
            then {rator = BV.op_and, init = L.Bool true, mvs = []}
          else if Var.same(BV.red_exists, rator)
            then {rator = BV.op_or, init = L.Bool false, mvs = []}
          else if Var.same(BV.red_max, rator)
            then {rator = BV.fn_max, init = L.Real R.negInf, mvs = []}
          else if Var.same(BV.red_mean, rator)
            then raise Fail "FIXME: 'mean' reduction not yet supported"
          else if Var.same(BV.red_min, rator)
            then {rator = BV.fn_min, init = L.Real R.posInf, mvs = []}
          else if Var.same(BV.red_product, rator)
            then {rator = BV.mul_rr, init = L.Real R.one, mvs = []}
          else if Var.same(BV.red_sum, rator)
          (* the identity for addition is 0.0, not 1.0.  This assumes
           * RealLit.zero : bool -> t, where `true` selects -0.0, so `false`
           * gives +0.0. *)
            then {rator = BV.add_tt, init = L.Real(R.zero false), mvs = [SimpleTypes.SHAPE[]]}
          else if Var.same(BV.red_variance, rator)
            then raise Fail "FIXME: 'variance' reduction not yet supported"
            else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")

  (* makeFunction (name, blk, resTy)
   *
   * Close `blk` over its free variables to produce a function with result type
   * `resTy`.
   *
   * Each variable that is used in `blk` without being bound there becomes a new
   * parameter (a copy with kind FunParam).  Variables introduced inside `blk`
   * are renamed to fresh LocalVar copies.
   *
   * Returns the new function and the list of original free variables.  The
   * function is named `name` plus a unique numeric suffix.  The free variables
   * are listed in the same order as the function's parameters, i.e., they are
   * the arguments to pass at the call site.
   *
   * Raises Fail if `blk` contains a nested S_MapReduce statement.
   *)
    local
    (* counter used to make generated function names unique *)
      val n = ref 0
    (* create a fresh function variable named `name` suffixed with the counter *)
      fun mkFuncId (name, ty) = let
            val id = !n
            in
              n := id + 1;
              SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
            end
    in
    fun makeFunction (name, blk, resTy) = let
        (* free variables found so far, most recent first, each paired with the
         * parameter that replaces it in the function body *)
          val freeVars = ref []
        (* rename a use of `x`.  If `x` is already in the environment, reuse its
         * copy.  Otherwise `x` is free in the block, so record it and make a
         * new parameter for it. *)
          fun cvtVar (env, x) = (case VMap.find(env, x)
                 of SOME x' => (env, x')
                  | NONE => let
                      val x' = SimpleVar.copy(x, SimpleVar.FunParam)
                      in
                        freeVars := (x, x') :: !freeVars;
                        (VMap.insert(env, x, x'), x')
                      end
                (* end case *))
        (* rename a list of variable uses, threading the environment *)
          fun cvtVars (env, xs) = let
                fun cvt (x, (env, xs')) = let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, x'::xs')
                      end
                in
                  List.foldr cvt (env, []) xs
                end
        (* bind a variable that is defined inside the block to a fresh local copy *)
          fun newVar (env, x) = let
                val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
                in
                  (VMap.insert(env, x, x'), x')
                end
        (* The environment is threaded through every statement in order, including
         * from the "then" arm of a conditional into the "else" arm.  As a result,
         * a free variable maps to the same parameter everywhere it is used. *)
          fun cvtBlock (env, S.Block stms) = let
                fun cvtStms (env, [], stms') = (env, S.Block(List.rev stms'))
                  | cvtStms (env, stm::stms, stms') = let
                      val (env, stm') = cvtStm (env, stm)
                      in
                        cvtStms (env, stms, stm'::stms')
                      end
                in
                  cvtStms (env, stms, [])
                end
          and cvtStm (env, stm) = (case stm
                 of S.S_Var(x, NONE) => let
                      val (env, x') = newVar (env, x)
                      in
                        (env, S.S_Var(x', NONE))
                      end
                  | S.S_Var(x, SOME e) => let
                    (* convert the initializer before `x` comes into scope *)
                      val (env, e') = cvtExp (env, e)
                      val (env, x') = newVar (env, x)
                      in
                        (env, S.S_Var(x', SOME e'))
                      end
                  | S.S_Assign(x, e) => let
                      val (env, e') = cvtExp (env, e)
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.S_Assign(x', e'))
                      end
                  | S.S_IfThenElse(x, b1, b2) => let
                      val (env, x') = cvtVar (env, x)
                      val (env, b1') = cvtBlock (env, b1)
                      val (env, b2') = cvtBlock (env, b2)
                      in
                        (env, S.S_IfThenElse(x', b1', b2'))
                      end
                  | S.S_Foreach(x, xs, b) => let
                    (* The sequence `xs` comes from outside the loop, so it may be
                     * free.  The iteration variable `x` is bound by the loop, so it
                     * gets a fresh local copy instead of becoming a parameter.
                     * NOTE(review): newVar gives the copy kind LocalVar; confirm
                     * that no later pass expects a distinct iteration-variable kind. *)
                      val (env, xs') = cvtVar (env, xs)
                      val (env, x') = newVar (env, x)
                      val (env, b') = cvtBlock (env, b)
                      in
                        (env, S.S_Foreach(x', xs', b'))
                      end
                  | S.S_New(name, args) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.S_New(name, args'))
                      end
                  | S.S_Continue => (env, stm)
                  | S.S_Die => (env, stm)
                  | S.S_Stabilize => (env, stm)
                  | S.S_Return x => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.S_Return x')
                      end
                  | S.S_Print xs => let
                      val (env, xs') = cvtVars (env, xs)
                      in
                        (env, S.S_Print xs')
                      end
                  | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
                (* end case *))
          and cvtExp (env, exp) = (case exp
                 of S.E_Var x => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Var x')
                      end
                  | S.E_Lit _ => (env, exp)
                  | S.E_Select(x, fld) => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Select(x', fld))
                      end
                  | S.E_Apply(f, args, ty) => let
                    (* the function variable `f` is left as is; only the arguments
                     * are renamed *)
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Apply(f, args', ty))
                      end
                  | S.E_Prim(f, mvs, args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Prim(f, mvs, args', ty))
                      end
                  | S.E_Tensor(args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Tensor(args', ty))
                      end
                  | S.E_Seq(args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Seq(args', ty))
                      end
                  | S.E_Slice(x, indices, ty) => let
                    (* rename the optional index variables, keeping NONE slots *)
                      fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
                        | cvt (SOME x, (env, idxs)) = let
                            val (env, x') = cvtVar (env, x)
                            in
                              (env, SOME x' :: idxs)
                            end
                      val (env, x') = cvtVar (env, x)
                      val (env, indices') = List.foldr cvt (env, []) indices
                      in
                        (env, S.E_Slice(x', indices', ty))
                      end
                  | S.E_Coerce{srcTy, dstTy, x} => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
                      end
                  | S.E_LoadSeq _ => (env, exp)
                  | S.E_LoadImage _ => (env, exp)
                (* end case *))
        (* the final environment is not needed; only the collected free variables *)
          val (_, blk) = cvtBlock (VMap.empty, blk)
        (* freeVars is most-recent-first, so reverse it to get discovery order *)
          val (args, params) = ListPair.unzip (List.rev (! freeVars))
          val fnTy = SimpleTypes.T_Fun(List.map SimpleVar.typeOf params, resTy)
          in
            (S.Func{f=mkFuncId(name, fnTy), params=params, body=blk}, args)
          end
    end (* local *)

  end
Click to toggle
does not end with </html> tag
does not end with </body> tag
The output has ended thus: (S.Func{f=mkFuncId(name, fnTy), params=params, body=blk}, args) end end (* local *) end