(* evalKrn
 *
 * Evaluate an EIN kernel expression to low-IL ops: iterate over the range of
 * the support, determine the differentiation level of each kernel, and expand
 * each kernel's polynomial segments into vector operations.
 *)
structure EvalKrnSet = struct
    local
      structure DstIL = LowIL
      structure DstTy = LowILTypes
      structure Var = LowIL.Var
      structure E = Ein
      structure H = HelperSet
    in

  (* thin wrappers around HelperSet *)
    fun lookup e = H.lookup e
    fun insert e = H.insert e
    fun find e = H.find e
    fun intToReal n = H.intToReal n
    fun assgnCons e = H.assgnCons e
    fun indexTensor e = H.indexTensor e
    fun mkAddVec e = H.mkAddVec e
    fun mkSubSca e = H.mkSubSca e
    fun mkProdVec e = H.mkProdVec e
  (* note the argument reordering: HelperSet.mkDotVec takes (set, last, [a, b]) *)
    fun mkDotVec (setO, a, b, last) = H.mkDotVec(setO, last, [a, b])
    fun iTos n = Int.toString n
    fun err str = raise Fail(str)

    val realTy = DstTy.TensorTy []
    val intTy = DstTy.IntTy

  (* debugging output: when testing is true, testp prints the concatenation of
   * its string list; it always returns 1.
   *)
    val testing = false
    fun testp n = if (testing) then (print(String.concat n); 1) else 1

  (* convert a rational to a FloatLit.float value.  We do this by long division
   * with a cutoff when we get to 12 digits.
   *)
    fun ratToFloat r = (case Rational.explode r
           of {sign=0, ...} => FloatLit.zero false
            | {sign, num, denom=1} => FloatLit.fromInt(IntInf.fromInt sign * num)
            | {sign, num, denom} => let
              (* normalize so that num <= denom *)
                val (denom, exp) = let
                      fun lp (n, denom) = if (denom < num) then lp(n+1, denom*10) else (denom, n)
                      in
                        lp (1, denom)
                      end
              (* normalize so that num <= denom < 10*num *)
                val (num, exp) = let
                      fun lp (n, num) = if (10*num < denom) then lp(n-1, 10*num) else (num, n)
                      in
                        lp (exp, num)
                      end
              (* divide num/denom, computing the resulting digits *)
                fun divLp (n, a) = let
                      val (q, r) = IntInf.divMod(a, denom)
                      in
                        if (r = 0) then (q, [])
                        else if (n < 12) then let
                          val (d, dd) = divLp(n+1, 10*r)
                          in
                            if (d < 10)
                              then (q, (IntInf.toInt d)::dd)
                              else (q+1, 0::dd)
                          end
                        else if (IntInf.div(10*r, denom) < 5)
                          then (q, [])
                          else (q+1, []) (* round up *)
                      end
                val digits = let
                      val (d, dd) = divLp (0, num)
                      in
                        (IntInf.toInt d)::dd
                      end
                in
                  FloatLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
                end
          (* end case *))

  (* expandEvalKernel (setOrig, pre, d, h, k, x)
   *
   * Expand the EvalKernel operations into vector operations.  The parameters are
   *
   *    setOrig -- the operation set threaded through lowSet.filter and the
   *               HelperSet constructors
   *    pre     -- a name for the evaluation (currently not used here)
   *    d       -- the vector width of the operation, which should be equal
   *               to twice the support of the kernel
   *    h       -- the kernel
   *    k       -- the derivative of the kernel to evaluate
   *    x       -- the variable holding the point(s) to evaluate at
   *
   * The result is (set, result, code), where result is the variable holding
   * the evaluated d-wide vector and code is the list of LowIL assignments
   * (coefficient-vector construction first, then the evaluation).
   *
   * The generated code is computing
   *
   *    result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
   *
   * as a d-wide vector operation, where n is the degree of the kth derivative
   * of h and the a_i are coefficient vectors that have an element for each
   * piece of h.  The computation is implemented as follows
   *
   *    m_n     = x * a_n
   *    s_{n-1} = a_{n-1} + m_n
   *    m_{n-1} = x * s_{n-1}
   *    s_{n-2} = a_{n-2} + m_{n-1}
   *    m_{n-2} = x * s_{n-2}
   *    ...
   *    s_1     = a_1 + m_2
   *    m_1     = x * s_1
   *    result  = a_0 + m_1
   *
   * When n = 0 (constant polynomial) no arithmetic is generated and
   * result = a_0.
   *
   * Note that the coefficient vectors are flipped (cf high-to-low/probe.sml).
   *)
    fun expandEvalKernel (setOrig, pre, d, h, k, x) = let
          val {segs, ...} = Kernel.curve (h, k)
        (* degree of polynomial *)
          val deg = List.length(hd segs) - 1
        (* segs has one entry per piece (2*support); each inner list holds the
         * coefficients of that piece's polynomial
         *)
          val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
        (* coefficient i j = the degree-i coefficient of piece j, as a literal *)
          fun coefficient i j =
                Literal.Float(ratToFloat (Vector.sub (Vector.sub(segs, j), i)))
          val ty = DstTy.vecTy d
          val coeffs = List.tabulate (deg+1, fn i => Var.new("P"^Int.toString i, ty))
        (* filterLit (assigns, vars, code, opset): run each (lhs, rhs) through
         * lowSet.filter.  When the filter returns SOME v, v is used in place of
         * lhs and the assignment is dropped; otherwise lhs is used and the
         * assignment is kept.  Returns (vars, keptAssigns, opset), in order.
         *)
          fun filterLit ([], vars, code, opset) = (vars, code, opset)
            | filterLit ((lhs, rhs)::es, vars, code, opset) = let
                val (opset, var) = lowSet.filter(opset, (lhs, rhs))
                in
                  (case var
                     of NONE => filterLit(es, vars@[lhs], code@[(lhs, rhs)], opset)
                      | SOME v => filterLit(es, vars@[v], code, opset)
                    (* end case *))
                end
        (* for each coefficient variable P_i, build the d literal elements (one
         * per piece) and the CONS that packs them into P_i.  The fold runs from
         * P_0 upward, pairing P_i with degree index i.
         *)
          val (coeffVecs, setP) = let
                fun mk (pVar, (i::idxs, code, opset)) = let
                      val lits = List.tabulate(d, coefficient i)
                      val vars = List.tabulate(d, fn _ => Var.new("_f", DstTy.realTy))
                      val code0 = ListPair.map (fn (v, lit) => (v, DstIL.LIT lit)) (vars, lits)
                      val (vars, code0, opset) = filterLit(code0, [], [], opset)
                      val code = code@code0@[(pVar, DstIL.CONS(Var.ty pVar, vars))]
                      in
                        (idxs, code, opset)
                      end
                  | mk (_, ([], _, _)) =
                      raise Fail "expandEvalKernel: coefficient index list exhausted"
                val (_, code, opset) =
                      List.foldr mk (List.tabulate(deg+1, fn i => i), [], setOrig) (List.rev coeffs)
                in
                  (code, opset)
                end
          val coeffVecsCode = List.map (fn (lhs, rhs) => DstIL.ASSGN(lhs, rhs)) coeffVecs
        (* Horner evaluation over the coefficient variables given highest degree
         * first, [a_n, ..., a_1, a_0]; returns (set, resultVar, code).
         *)
          fun mkdot (setO, [a1, a0], code) = let
                val (setA, vA, A) = mkProdVec(setO, d, [x, a1])
                val (setB, vB, B) = mkAddVec(setA, d, [a0, vA])
                in
                  (setB, vB, code@A@B)
                end
            | mkdot (setO, ai::aj::rest, code) = let
                val (setA, vA, A) = mkProdVec(setO, d, [x, ai])
                val (setB, vB, B) = mkAddVec(setA, d, [aj, vA])
                in
                  mkdot(setB, vB::rest, code@A@B)
                end
          (* constant polynomial: result = a_0 (previously this emitted x * a_0) *)
            | mkdot (setO, [a0], []) = (setO, a0, [])
            | mkdot _ = raise Fail "0 or 1 item in Kernel coeffs"
          val (setC, vC, code) = mkdot(setP, List.rev coeffs, [])
          val _ = testp["\n coeffVecs code :", Int.toString(length coeffVecsCode),
                " other code code: ", Int.toString(length code)]
          in
            (setC, vC, coeffVecsCode@code)
          end

  (* mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1)
   *
   * Generate the low-IL code that evaluates each kernel in krns over the
   * support.  Parameters:
   *
   *    setOrig   -- operation set threaded through the HelperSet constructors
   *    mappOrig  -- index mapping used by insert/find/H.deltaToInt
   *    lhs, params, args -- forwarded to indexTensor
   *    krns      -- list of (id, dels, pos): dels are (i, j) pairs passed to
   *                 H.deltaToInt, pos is the kernel's position expression
   *    h         -- the kernel
   *    sid       -- the index id bound for the first kernel (incremented per kernel)
   *    lb        -- lower bound of the offsets bound to that index
   *    range0    -- offsets run from lb+range0 down to lb
   *    range1    -- the vector width passed to expandEvalKernel
   *
   * Returns (set, ids, code), where ids holds one result variable per kernel,
   * in the order of krns (h0, h1, h2, ...).
   *
   * Local helpers:
   *    evalDels   -- sum of the delta values of a kernel, i.e. its
   *                  differentiation level
   *    mkSimpleOp -- turn a position expression T[ix] - v into low-IL ops
   *    mkpos      -- bind the index to each offset and evaluate the position
   *    consfn     -- cons each kernel's position values into a vector
   *    evalK      -- evaluate each kernel on its position vector
   *)
    fun mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1) = let
          fun evalDels (mapp, dels) =
                List.foldl (fn (x, y) => x+y) 0
                  (List.map (fn (i, j) => H.deltaToInt(mapp, i, j)) dels)
        (* only T[ix] - v is supported; v is subtracted only when it maps to a
         * non-zero integer
         *)
          fun mkSimpleOp (setO, mapp, lhs, params, args,
                E.Op2(E.Sub, E.Tensor(t1, ix1), E.Value v1)) = let
                val (setA, vA, A) = indexTensor(setO, mapp, (lhs, params, args, t1, ix1, realTy))
                in
                  (case (find(v1, mapp))
                     of 0 => (setA, vA, A)
                      | j => let
                          val (setB, vB, B) = intToReal(setA, j)
                          val (setC, vC, C) = mkSubSca(setB, [vA, vB])
                          in
                            (setC, vC, A@B@C)
                          end
                    (* end case *))
                end
            | mkSimpleOp _ = raise Fail "mkSimpleOp: unsupported kernel position expression"
        (* For each kernel in k, bind index n to lb+i for i = range0 down to 0,
         * evaluating the position expression at each binding.  The values for
         * one kernel accumulate in l; when i reaches 0 they are appended to fin
         * and the next kernel starts with n+1 and i reset to range0.
         *)
          fun mkpos (setO, k, fin, l, dict, i, code, n) = (case (k, i)
                 of ([], _) => (setO, fin, code)
                  | ((_, _, pos1)::ks, 0) => let
                      val _ = testp ["\n insert", iTos n, "->", iTos lb]
                      val mapp = insert (n, lb) dict
                      val (opset', l', code') = mkSimpleOp(setO, mapp, lhs, params, args, pos1)
                      val e = l@[l']
                      val mapp' = insert (n, 0) dict
                      in
                        mkpos(opset', ks, fin@[e], [], mapp', range0, code@code', n+1)
                      end
                  | ((_, _, pos1)::_, _) => let
                      val _ = testp ["\n insert", iTos n, "->", iTos (lb+i)]
                      val mapp = insert (n, lb+i) dict
                      val (opset', l', code') = mkSimpleOp(setO, mapp, lhs, params, args, pos1)
                      in
                        mkpos(opset', k, fin, l@[l'], dict, i-1, code@code', n)
                      end
                (* end case *))
        (* cons each list of position values (reversed, so lowest offset first)
         * into a vector named "h<n>"
         *)
          fun consfn (setO, [], _, rest, code) = (setO, rest, code)
            | consfn (setO, e1::es, n, rest, code) = let
                val (setA, vA, A) = assgnCons(setO, "h"^iTos n, [length(e1)], List.rev e1)
                in
                  consfn(setA, es, n+1, rest@[vA], code@A)
                end
        (* evaluate kernel h (at derivative dk) on each position vector *)
          fun evalK (setO, [], [], newId, code) = (setO, newId, code)
            | evalK (setO, (_, dk, pos)::kns, x::xs, newId, code) = let
                val directionX = (case pos
                       of E.Op2(E.Sub, E.Tensor(_, [E.C directionX]), _) => directionX
                        | _ => 0
                      (* end case *))
                val name = String.concat["h", iTos directionX, "_", iTos dk]
                val _ = testp["\n", Var.toString x, " = ", name]
                val (opsetK, id, kcode) = expandEvalKernel (setO, name, range1, h, dk, x)
                in
                  evalK(opsetK, kns, xs, newId@[id], code@kcode)
                end
            | evalK _ = raise Fail "Non-equal variable list, error in mkKrns"
          val newkrns = List.map (fn (id, d1, pos) => (id, evalDels(mappOrig, d1), pos)) krns
          val _ = testp["\n\n ***** Differentiation value of kernels ****\n "]
          val (set2, lftkrn, poscode) = mkpos(setOrig, newkrns, [], [], mappOrig, range0, [], sid)
          val (set3, lft, conscode) = consfn(set2, lftkrn, 0, [], [])
          val (set4, ids, evalKcode) = evalK(set3, newkrns, lft, [], [])
          val _ = List.app (fn e => ignore (testp["\n IDS", Var.toString e, ","])) ids
          in
            (set4, ids, poscode@conscode@evalKcode)
          end

    end (* local *)
end
Click to toggle
does not end with </html> tag
does not end with </body> tag
The output has ended thus: (set4, ids, poscode@conscode@evalKcode) end end (* local *) end