Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3553, Thu Jan 7 17:39:31 2016 UTC revision 3602, Mon Jan 18 18:27:32 2016 UTC
# Line 2  Line 2 
2   *Evaluate an EIN kernel expression to low-IL ops   *Evaluate an EIN kernel expression to low-IL ops
3   *iterate over the range of the support, determine differentiation, and calls segements   *iterate over the range of the support, determine differentiation, and calls segements
4   *)   *)
5  structure EvalKrnSet = struct  structure EvalKrn = struct
6      local      local
7      structure DstIL = LowIL      structure DstIL = LowIL
8      structure DstTy = LowILTypes      structure DstTy = LowILTypes
9      structure Var = LowIL.Var      structure Var = LowIL.Var
10      structure E=Ein      structure E=Ein
11      structure H=HelperSet      structure H=Helper
12        structure IMap = IntRedBlackMap
13      in      in
14    
   
     fun lookup e =H.lookup e  
     fun insert e=H.insert e  
15      fun find e=H.find e      fun find e=H.find e
16      fun intToReal n=H.intToReal n      fun intToReal n=H.intToReal n
17      fun assgnCons e=H.assgnCons e      fun assgnCons e=H.assgnCons e
# Line 22  Line 19 
19      fun mkAddVec e=H.mkAddVec e      fun mkAddVec e=H.mkAddVec e
20      fun mkSubSca e= H.mkSubSca e      fun mkSubSca e= H.mkSubSca e
21      fun mkProdVec e =H.mkProdVec e      fun mkProdVec e =H.mkProdVec e
22      fun mkDotVec (setO,a,b,last) =H.mkDotVec (setO,last,[a,b])      fun mkDotVec (avail,a,b,last) =H.mkDotVec (avail,last,[a,b])
     fun iTos n=Int.toString n  
23      fun err str=raise Fail (str)      fun err str=raise Fail (str)
24      val realTy=DstTy.TensorTy []      fun lookup k d = IMap.find (d, k)
25      val intTy=DstTy.IntTy      fun insert  (k, v) d =  IMap.insert (d, k, v)
     val testing=false  
     fun testp n =if  (testing)  then  (print (String.concat n) ;1)  else 1  
   
26        (* convert a rational to a FloatLit.float value.  We do this by long division        (* convert a rational to a FloatLit.float value.  We do this by long division
27       * with a cutoff when we get to 12 digits.       * with a cutoff when we get to 12 digits.
28       *)       *)
# Line 108  Line 101 
101      *      *
102      * Note that the coeffient vectors are flipped  (cf high-to-low/probe.sml) .      * Note that the coeffient vectors are flipped  (cf high-to-low/probe.sml) .
103      *)      *)
104      fun expandEvalKernel  (setOrig,pre,d, h, k, x)  = let      fun expandEvalKernel  (avail,pre,d, h, k, x)  = let
105    
106           val {isCont, segs} = Kernel.curve  (h, k)           val {isCont, segs} = Kernel.curve  (h, k)
107            (* degree of polynomial *)            (* degree of polynomial *)
108           val deg = List.length (hd segs)  - 1           val deg = List.length (hd segs)  - 1
# Line 118  Line 112 
112                   Literal.Float (ratToFloat  (Vector.sub  (Vector.sub (segs, i) , d)))                   Literal.Float (ratToFloat  (Vector.sub  (Vector.sub (segs, i) , d)))
113           val ty = DstTy.vecTy d           val ty = DstTy.vecTy d
114           val coeffs = List.tabulate  (deg+1,fn i => Var.new ("P"^Int.toString i, ty))           val coeffs = List.tabulate  (deg+1,fn i => Var.new ("P"^Int.toString i, ty))
115           fun filterLit ([], vars, code, opset) = (vars, code, opset)          val (_, avail, coeffs)  = let
116            | filterLit ((lhs, rhs) ::es, vars, code, opset) = let              fun mk  (x,  (i::xs, avail,current) )  = let
                val (opset, var)  = lowSet.filter (opset, (lhs, rhs))  
                in  (case var  
                     of NONE=> filterLit (es, vars@[lhs], code@[(lhs, rhs)], opset)  
                     | SOME v =>filterLit (es, vars@[v], code, opset)  
                  (*end case*) )  
               end  
          val (coeffVecs,setOrig)  = let  
              fun mk  (x,  (i::xs, code,opset) )  = let  
117                      val lits = List.tabulate (d, coefficient i)                      val lits = List.tabulate (d, coefficient i)
118                      val vars = List.tabulate (d, fn _ => Var.new ("_f", DstTy.realTy))                  val vars = List.tabulate (d, fn _ => Var.new ("_f", DstTy.TensorTy []))
119                      val code0 = ListPair.map  (fn  (x, lit)  =>  (x, DstIL.LIT lit))   (vars, lits)                  val vars= ListPair.map (fn  (x, lit)  => (AvailRHS.addAssign avail (x, DstIL.LIT lit)))  (vars, lits)
120                      val (vars, code0, opset) =filterLit (code0,[],[],opset)                  val var = AvailRHS.addAssign avail (x, DstIL.CONS (Var.ty x, vars))
121                      val code = code@code0@[ (x, DstIL.CONS (Var.ty x, vars))]                  in
122                      in                      (xs, avail,var::current)
                          (xs, code, opset)  
                     end  
                 val n= List.tabulate (deg+1, fn e=>e )  
                 val (a,b,c) = (List.foldr mk  (n, [],setOrig)   (List.rev coeffs))  
                 in  
                      (b,c)  
                 end  
          val coeffVecsCode= List.map  (fn (x,y) =>DstIL.ASSGN  (x,y) )  coeffVecs  
          fun getSet ([], done, opset, cnt) = (done, opset, cnt)  
            | getSet (DstIL.ASSGN (lhs, rhs) ::es, done, opset, cnt) = let  
             val (opset,var)  = lowSet.filter (opset, (lhs, rhs) )  
                 in (case var  
                     of NONE => getSet (es,done@[DstIL.ASSGN (lhs, rhs) ], opset, cnt)  
                     | SOME v=>   getSet (es,done@[DstIL.ASSGN (lhs,DstIL.VAR v) ], opset, cnt+1)  
                     (*end case*))  
123                  end                  end
            | getSet  (e1::es, done, opset,cnt) =getSet (es,done@[e1],opset,cnt)  
          (*get dot product and addition of list of coeffs*)  
          fun mkdot (setO, [e2,e1], code) = let  
             val (setA, vA, A) = mkProdVec (setO, d, [x,e2])  
             val (setB, vB, B) = mkAddVec (setA, d, [e1,vA])  
124              in              in
125                   (setB,vB,code@A@B)                  (List.foldr mk  (List.tabulate (deg+1, fn e=>e ), avail,[])   (List.rev coeffs))
126              end              end
127            | mkdot (setO,e2::e1::es,code) = let  
128              val (setA, vA, A) = mkProdVec (setO, d, [x,e2])           (*get dot product and addition of list of coeffs*)
129              val (setB, vB, B) = mkAddVec (setA, d, [e1,vA])           fun mkdot (avail, [e2,e1]) = let
130                val (avail, vA) = mkProdVec (avail, d, [x,e2])
131                in  mkAddVec (avail, d, [e1,vA]) end
132              | mkdot (avail,e2::e1::es) = let
133                val (avail, vA) = mkProdVec (avail, d, [x,e2])
134                val (avail, vB) = mkAddVec (avail, d, [e1,vA])
135              in              in
136                  mkdot (setB,vB::es,code@A@B)                  mkdot (avail,vB::es)
137              end              end
138            | mkdot  (setO, [e1], []) = mkProdVec (setO,d,[x,e1])            | mkdot  (avail, [e1]) = mkProdVec (avail,d,[x,e1])
139            | mkdot _ = raise Fail "0 or 1 item in Kernel coeffs"            | mkdot _ = raise Fail "0 or 1 item in Kernel coeffs"
140          val (setC, vC, code) = mkdot (setOrig, List.rev coeffs, [])  
         val _ = (String.concat["\n coeffVecs code :",Int.toString (length (coeffVecsCode) ) ," other code code: ",Int.toString (length (code))])  
141          in          in
142               (setC,vC,coeffVecsCode@code)              mkdot (avail, coeffs)
143          end          end
144    
145    
# Line 196  Line 166 
166      * ->Var * LowIL.assign list      * ->Var * LowIL.assign list
167      * evaluate kernel with segments      * evaluate kernel with segments
168      *)      *)
169      fun mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1) = let      fun mkkrns (avail, mappOrig,params, args, krns, h, sid, lb, range0, range1) = let
170    
171          fun evalDels (mapp, dels) = List.foldl (fn (x, y) =>x+y)  0    (List.map  (fn (i, j) =>H.deltaToInt (mapp, i, j) )   dels)          fun evalDels (mapp, dels) = List.foldl (fn (x, y) =>x+y)  0    (List.map  (fn (i, j) =>H.deltaToInt (mapp, i, j) )   dels)
172            fun mkSimpleOp (avail, mapp,  params, args,  E.Op2 (E.Sub,E.Tensor (t1,ix1) ,E.Value v1) ) = let
173          fun mkSimpleOp (setO, mapp, lhs, params, args,  E.Op2 (E.Sub,E.Tensor (t1,ix1) ,E.Value v1) ) = let              val (avail, vA) = indexTensor (avail, mapp, ( params, args, t1, ix1, DstTy.TensorTy []) )
             val (setA, vA, A) = indexTensor (setO, mapp, (lhs, params, args, t1, ix1, realTy) )  
174              in  (case  (find (v1,mapp) )              in  (case  (find (v1,mapp) )
175                  of 0=> (setA, vA, A)                  of 0=> (avail, vA)
176                  | j=>let                  | j=>let
177                      val (setB, vB, B) = intToReal (setA, j)                      val (avail, vB) = intToReal (avail, j)
178                      val (setC, vC, C) = mkSubSca (setB, [vA,vB])                      in mkSubSca (avail, [vA,vB]) end
                     in  (setC, vC, A@B@C)  end  
179                   (*end case*) )                   (*end case*) )
180              end              end
181            fun mkpos (avail, k, fin, rest, dict, i, n) =  (case  (k, i)
182          fun mkpos (setO, k, fin, l, dict, i, code, n) =  (case  (k, i)              of  ([], _) => (avail, List.rev fin)
             of  ([], _) => (setO,fin,code)  
183              |  ( (_, _, pos1) ::ks,0) => let              |  ( (_, _, pos1) ::ks,0) => let
                 val _=testp ["\n insert",iTos n, "->",iTos lb]  
184                  val mapp=insert  (n, lb)  dict                  val mapp=insert  (n, lb)  dict
185                  val (opset',l', code') = mkSimpleOp (setO, mapp, lhs, params, args, pos1)                  val (avail,rest') = mkSimpleOp (avail, mapp, params, args, pos1)
186                  val e=l@[l']                  val e=rest@[rest']
187                  val mapp'= insert  (n, 0)  dict                  val mapp'= insert  (n, 0)  dict
188                  in                  in
189                      mkpos (opset', ks, fin@[e], [], mapp', range0, code@code', n+1)                      mkpos (avail, ks, e::fin, [], mapp', range0, n+1)
190                  end                  end
191              |  ( (_,_,pos1) ::es,_)  => let              |  ( (_,_,pos1) ::es,_)  => let
                 val _=testp ["\n insert",iTos n, "->",iTos  (lb+i) ]  
192                  val mapp= insert  (n, lb+i)  dict                  val mapp= insert  (n, lb+i)  dict
193                  val (opset', l', code') = mkSimpleOp (setO, mapp, lhs, params, args, pos1)                  val (avail, rest') = mkSimpleOp (avail, mapp, params, args, pos1)
194                  in                  in
195                      mkpos (opset', k, fin, l@[l'], dict, i-1, code@code', n)                      mkpos (avail, k, fin, rest@[rest'], dict, i-1, n)
196                  end                  end
197               (*end case*) )               (*end case*) )
198    
199          fun consfn (setO, [], _, rest, code) = (setO, rest, code)          fun consfn (avail, [], _, rest) = (avail, List.rev rest)
200            | consfn (setO, e1::es, n, rest, code) = let            | consfn (avail, e1::es, n, rest) = let
201                  val (setA,  vA, A) = assgnCons (setO,  "h"^iTos n, [length (e1)], List.rev e1)                  val (avail,  vA) = assgnCons (avail,  "h"^Int.toString(n), [length (e1)], List.rev e1)
202                  in                  in
203                      consfn (setA, es, n+1, rest@[vA], code@A)                      consfn (avail, es, n+1, vA::rest)
204                  end                  end
205    
206          fun evalK (setO, [], [], newId, code) = (setO, newId, code)          fun evalK (avail, [], [], newId) = (avail, List.rev newId)
207            | evalK (setO, kn::kns, x::xs, newId, code) = let            | evalK (avail, kn::kns, x::xs, newId) = let
208                  val (_, dk, pos)  = kn                  val (_, dk, pos)  = kn
209                  val directionX=  (case pos                  val directionX=  (case pos
210                      of E.Op2 (E.Sub,E.Tensor  (_,[E.C directionX]) ,_) => directionX                      of E.Op2 (E.Sub,E.Tensor  (_,[E.C directionX]) ,_) => directionX
211                      | _ => 0                      | _ => 0
212                       (*end case*) )                       (*end case*) )
213                  val name=String.concat["h",iTos directionX,"_",iTos dk]                  val name=String.concat["h",Int.toString(directionX),"_",Int.toString dk]
214                  val _ =testp["\n",Var.toString x," = ",name]                  val (avail, id) = expandEvalKernel  (avail, name, range1, h, dk, x)
                 val (opsetK, id, kcode) = expandEvalKernel  (setO, name, range1, h, dk, x)  
215                  in                  in
216                      evalK (opsetK, kns, xs, newId@[id], code@kcode)                      evalK (avail, kns, xs, id::newId)
217                  end                  end
218            | evalK _ =raise Fail "Non-equal variable list, error in mkKrns"            | evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
219    
220          val newkrns = List.map  (fn  (id, d1, pos) => (id, evalDels (mappOrig,d1) , pos) )  krns          val newkrns = List.map  (fn  (id, d1, pos) => (id, evalDels (mappOrig,d1) , pos) )  krns
221          val _ = testp["\n\n ***** Differentiation value of kernels ****\n "]          val (avail, lftkrn) = mkpos (avail, newkrns, [], [], mappOrig, range0, sid)
222          val (set2, lftkrn, poscode) = mkpos (setOrig, newkrns, [], [], mappOrig, range0, [], sid)          val (avail, lft) = consfn (avail, lftkrn, 0, [])
         val (set3, lft, conscode) = consfn (set2, lftkrn, 0, [], [])  
         val (set4, ids, evalKcode) = evalK (set3, newkrns, lft, [], [])  
         val _ = List.map  (fn e=>testp["\n IDS",Var.toString e,","])  ids  
          (*returns list in order h0, h1, h2*)  
223          in          in
224               (set4, ids, poscode@conscode@evalKcode)              evalK (avail, newkrns, lft, [])
225          end          end
226    
   
227  end  (* local *)  end  (* local *)
228    
229  end  end

Legend:
Removed from v.3553  
changed lines
  Added in v.3602

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0