(* apply.sml
 *
 * Apply EIN operator arguments to EIN operator.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure Apply : sig

  (* apply (e1, place, e2) substitutes the EIN operator e2 for the parameter of e1 whose
   * id is `place`.  Returns (changed, e1'), where `changed` is true iff at least one Tensor
   * or Field term of e1 that refers to parameter `place` was replaced by a renamed copy of
   * e2's body.  The parameters of e1' are e1's parameters with e2's parameters spliced in
   * at position `place`.
   * Raises Subscript unless 0 <= place < number of e1's parameters, and
   * Fail "argument/parameter mismatch" when a reference to `place` has a different number
   * of index arguments than e2 has result indices.
   *)
    val apply : Ein.ein * int * Ein.ein -> (bool * Ein.ein)

  end = struct

    structure E = Ein
    structure IMap = IntRedBlackMap

  (* mapId (i, dict, shift): the image of parameter id `i` in `dict`, or `i + shift` when
   * `i` is not in the map.
   *)
    fun mapId (i, dict, shift) = (case IMap.find(dict, i)
           of NONE => i + shift
            | SOME j => j
          (* end case *))

  (* mapIndex (ix, dict, shift): the image of index variable `ix` in `dict`, or the
   * variable E.V(ix + shift) when `ix` is not in the map.
   *)
    fun mapIndex (ix, dict, shift) = (case IMap.find(dict, ix)
           of NONE => E.V(ix + shift)
            | SOME j => j
          (* end case *))

  (* mapId2 (i, dict, shift): like mapId, but a missing id is treated as an error.  The
   * error is printed on stdout and `i + shift` is returned so that rewriting continues.
   *)
    fun mapId2 (i, dict, shift) = (case IMap.find(dict, i)
           of NONE => (
                print(concat["Error: ", Int.toString i, " is out of range\n"]);
                i+shift)
            | SOME j => j
          (* end case *))

  (* rewriteSubst (e, subId, mx, paramShift, sumShift)
   * renames the body `e` of the operator being substituted so that it can replace a
   * Tensor/Field term whose index arguments are `mx`:
   *   - index variable n, for 0 <= n < length mx, becomes the n'th element of mx;
   *   - any other index variable ix becomes E.V(ix + shift'), where shift' is the larger of
   *     `sumShift` and the shift computed by insertIndex (presumably so that these variables
   *     do not clash with indices in scope at the use site -- TODO confirm);
   *   - parameter ids are renumbered through `subId` (see mapId2).
   * `paramShift` is currently unused.
   * Raises Fail "expression before expand" on Value, Img, or Krn terms, and Bind if an
   * Epsilon/Eps2 index is mapped to something other than an index variable.
   *)
    fun rewriteSubst (e, subId, mx, paramShift, sumShift) = let
        (* insertIndex (mx, 0, IMap.empty, 0) maps each position n of mx to its n'th element.
         * The shift it returns is computed from the LAST element of mx only (its value --
         * variable number or constant -- minus its position), or is 0 when mx is empty.
         *)
          fun insertIndex ([], _, dict, shift) = (dict, shift)
            | insertIndex (mu::rest, n, dict, _) = let
                val shift = (case mu of E.V ix => ix - n | E.C i => i - n)
                in
                  insertIndex(rest, n+1, IMap.insert(dict, n, mu), shift)
                end
          val (subMu, shift) = insertIndex(mx, 0, IMap.empty, 0)
          val shift' = Int.max(sumShift, shift)
          fun mapMu (E.V i) = mapIndex(i, subMu, shift')
            | mapMu c = c
          fun mapAlpha mx = List.map mapMu mx
        (* Epsilon/Eps2 hold bare variable numbers, so the image must be an index variable *)
          fun mapSingle i = let
                val E.V v = mapIndex(i, subMu, shift')
                in
                  v
                end
          fun mapSum l = List.map (fn (a, b, c) => (mapMu a, b, c)) l
          fun mapParam id = mapId2(id, subId, 0)
          fun apply e = (case e
                 of E.Const _ => e
                  | E.ConstR _ => e
                  | E.Tensor(id, mx) => E.Tensor(mapParam id, mapAlpha mx)
                  | E.Delta(i, j) => E.Delta(mapMu i, mapMu j)
                  | E.Epsilon(i, j, k) => E.Epsilon(mapSingle i, mapSingle j, mapSingle k)
                  | E.Eps2(i, j) => E.Eps2(mapSingle i, mapSingle j)
                  | E.Field(id, mx) => E.Field(mapParam id, mapAlpha mx)
                  | E.Lift e1 => E.Lift(apply e1)
                  | E.Conv(v, mx, h, ux) =>
                      E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)
                  | E.Partial mx => E.Partial(mapAlpha mx)
                  | E.Apply(e1, e2) => E.Apply(apply e1, apply e2)
                  | E.Probe(f, pos) => E.Probe(apply f, apply pos)
                  | E.Value _ => raise Fail "expression before expand"
                  | E.Img _ => raise Fail "expression before expand"
                  | E.Krn _ => raise Fail "expression before expand"
                  | E.Sum(c, esum) => E.Sum(mapSum c, apply esum)
                  | E.Op1(E.PowEmb(sx, n), e1) => E.Op1(E.PowEmb(mapSum sx, n), apply e1)
                  | E.Op1(op1, e1) => E.Op1(op1, apply e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, apply e1, apply e2)
                  | E.Opn(opn, e1) => E.Opn(opn, List.map apply e1)
                (* end case *))
          in
            apply e
          end

  (* rewriteParams (params, params2, place) splices the parameter list `params2` into
   * `params` in place of the parameter at position `place`.  Returns
   * (params', origId, subId, nbeg), where
   *   params' = params[0 .. place-1] @ params2 @ params[place+1 ..];
   *   origId  maps each id place+1 .. length params - 1 of the original operator to its
   *           new position (shifted by length params2 - 1); ids below `place` are not in
   *           the map, so mapId leaves them unchanged;
   *   subId   maps each id 0 .. length params2 - 1 of the spliced operator to its new
   *           position (shifted by place);
   *   nbeg    is the number of parameters preceding the spliced ones (= place).
   * Raises Subscript (from List.take/List.drop) unless 0 <= place < length params.
   *)
    fun rewriteParams (params, params2, place) = let
          val beg = List.take(params, place)
          val next = List.drop(params, place+1)
          val params' = beg@params2@next
          val n2 = length params2
          val nbeg = length beg
          val nnext = length next
        (* createDict (n, shift1, shift2, dict) adds the bindings k+shift1 => k+shift2,
         * for k = n, n-1, ..., 1, to dict.
         *)
          fun createDict (0, _, _, dict) = dict
            | createDict (n, shift1, shift2, dict) =
                createDict (n-1, shift1, shift2, IMap.insert (dict, n+shift1, n+shift2))
          val origId = createDict (nnext, place, place+n2-1, IMap.empty)
          val subId = createDict (n2, ~1, place-1, IMap.empty)
          in
            (params', origId, subId, nbeg)
          end

  (* Substitute e2 for parameter `place` of the first operator (see the signature).  The
   * image/kernel ids of a Conv term are only renumbered, never substituted.
   *)
    fun apply (E.EIN{params, index, body}, place, e2) = let
          val E.EIN{params=params2, index=index2, body=body2} = e2
          val changed = ref false
          val (params', origId, substId, paramShift) = rewriteParams(params, params2, place)
        (* the `sumShift` passed to rewriteSubst: it starts at the number of result indices,
         * and whenever the traversal enters a Sum or PowEmb it is set to the last index
         * variable bound there.  It is not restored on exit, so later sibling terms see the
         * updated value as well (NOTE(review): confirm this is intended).
         *)
          val sumIndex = ref(length index)
        (* rewrite the Tensor/Field term `e` with parameter id `id` and index arguments `mx`:
         * a reference to `place` is replaced by the renamed body of e2; any other parameter
         * id is renumbered through origId.
         *)
          fun rewrite (id, mx, e) = let
                val x = !sumIndex
                in
                  if (id = place)
                    then if (length mx = length index2)
                      then (
                        changed := true;
                        rewriteSubst (body2, substId, mx, paramShift, x))
                      else raise Fail "argument/parameter mismatch"
                    else (case e
                       of E.Tensor(id, mx) => E.Tensor(mapId(id, origId, 0), mx)
                        | E.Field(id, mx) => E.Field(mapId(id, origId, 0), mx)
                        | _ => raise Fail "term to be replaced is not a Tensor or Fields"
                      (* end case *))
                end
        (* the index variable of the last entry of a summation-range list; raises Empty on an
         * empty list and Bind if that entry is not an index variable
         *)
          fun sumI e = let val (E.V v,_,_) = List.last e in v end
        (* SML evaluates constructor arguments left to right, so subterms are visited in
         * order and the updates to sumIndex/changed made in one are seen by those after it.
         *)
          fun apply b = (case b
                 of E.Tensor(id, mx) => rewrite (id, mx, b)
                  | E.Field(id, mx) => rewrite (id, mx, b)
                  | E.Lift e1 => E.Lift(apply e1)
                  | E.Conv(v, mx, h, ux) =>
                      E.Conv(mapId(v, origId, 0), mx, mapId(h, origId, 0), ux)
                  | E.Apply(e1, e2) => E.Apply(apply e1, apply e2)
                  | E.Probe(f, pos) => E.Probe(apply f, apply pos)
                  | E.Value _ => raise Fail "expression before expand"
                  | E.Img _ => raise Fail "expression before expand"
                  | E.Krn _ => raise Fail "expression before expand"
                  | E.Sum(c, esum) => (
                      (* QUESTION: should we flag a change here? *)
                      sumIndex := sumI c;
                      E.Sum(c, apply esum))
                  | E.Op1(E.PowEmb(sx, n), e1) => (
                      (* QUESTION: should we flag a change here? *)
                      sumIndex := sumI sx;
                      E.Op1(E.PowEmb(sx, n), apply e1))
                  | E.Op1(op1, e1) => E.Op1(op1, apply e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, apply e1, apply e2)
                  | E.Opn(opn, es) => E.Opn(opn, List.map apply es)
                  | _ => b
                (* end case *))
          val body'' = apply body
          in
            (* QUESTION: can we do the following?
            if (! changed)
              then SOME(E.EIN{params=params', index=index, body=body''})
              else NONE
            *)
            (!changed, E.EIN{params=params', index=index, body=body''})
          end

  end
Click to toggle
does not end with </html> tag
does not end with </body> tag
The output has ended thus: dy''}) else NONE *) (!changed, E.EIN{params=params', index=index, body=body''}) end end