Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-scalar.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-scalar.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3649, Tue Feb 2 15:37:23 2016 UTC revision 3653, Tue Feb 2 22:50:44 2016 UTC
# Line 11  Line 11 
11  structure EinToScalar : sig  structure EinToScalar : sig
12    
13      val expand :      val expand :
14            AvailRHS.t * ?? IntRedBlackMap.map * (Ein.param_kind list * Ein.ein_exp * LowIR.var list)            AvailRHS.t * int IntRedBlackMap.map * (Ein.param_kind list * Ein.ein_exp * LowIR.var list)
15              -> ??              -> LowIR.var
16    
17    end = struct    end = struct
18    
# Line 22  Line 22 
22      structure Var = LowIR.Var      structure Var = LowIR.Var
23      structure E = Ein      structure E = Ein
24      structure Mk = MkLowIR      structure Mk = MkLowIR
     structure P = EinPP  
25      structure IMap = IntRedBlackMap      structure IMap = IntRedBlackMap
26    
     fun evalField e =  FieldToLow.evalField e  
27      fun indexTensor e = Mk.indexTensor e      fun indexTensor e = Mk.indexTensor e
28      fun mkSubSca e =  Mk.mkSubSca e      fun mkSubSca e =  Mk.mkSubSca e
29      fun mkProdSca e = Mk.mkProdSca e      fun mkProdSca e = Mk.mkProdSca e
# Line 33  Line 31 
31      fun mkMultiple e = Mk.mkMultiple e      fun mkMultiple e = Mk.mkMultiple e
32      fun evalG e =  Mk.evalG e      fun evalG e =  Mk.evalG e
33      fun mkOp1 e =  Mk.mkOp1 e      fun mkOp1 e =  Mk.mkOp1 e
     fun insert  (k, v) d =  IMap.insert (d, k, v)  
     fun errField e = raise FaIR ("Invalid Field Here:"^ (P.expToString e))  
34    
35      fun mapIndex (mapp, id) => (case IMap.find(mapp, id)      fun mapIndex (mapp, id) = (case IMap.find(mapp, id)
36             of SOME x => x             of SOME x => x
37              | NONE => raise Fail(concat["mapIndex(_, V ", Int.toString id, "): out of bounds"])              | NONE => raise Fail(concat["mapIndex(_, V ", Int.toString id, "): out of bounds"])
38            (* end case *))            (* end case *))
39    
40      fun expand (avail, mapp, (params, body, args)) = let      fun expand (avail, mapp, (params, body, args)) = let
41          val mapp = ref dict            fun gen (mapp, body) = let
         val info =  (e, args)  
         fun gen (avaIR, body) =  let  
42              (*********sumexpression ********)              (*********sumexpression ********)
43              fun tb n =  List.tabulate (n, fn e =>e)              fun tb n =  List.tabulate (n, fn e =>e)
44              fun Sumcheck (avaIR, sumx, e) = let                  fun sumCheck (mapp, (E.V v, lb, ub) :: sumx, e) = let
45                  fun sumloop (avaIR, mapsum) = (mapp:= mapsum; gen (avaIR, e))                        fun sumloop mapp = gen (mapp, e)
46                  fun sumI1 (avaIR, left, (v, [i], lb1), [], rest ) = let                        fun sumI1 (left, (v, [i], lb1), [], rest) = let
47                      val dict = insert (v, lb1+i)  left                              val mapp = IMap.insert (mapp, v, lb1+i)
48                      val  (avaIR, vD) =  sumloop (avaIR, dict)                              val vD = gen (mapp, e)
49                      in  (avaIR, rest@[vD])  end                              in
50                  |  sumI1 (avaIR, left, (v, i::es, lb1), [], rest) = let                                rest@[vD]
51                      val dict = insert (v, (i+lb1))  left                              end
52                      val  (avaIR, vD) = sumloop  (avaIR, dict)                          | sumI1 (left, (v, i::es, lb1), [], rest) = let
53                      in sumI1 (avaIR, dict, (v, es, lb1), [], rest@[vD])  end                              val mapp = IMap.insert (left, v, i+lb1)
54                  | sumI1 (avaIR, left, (v, [i], lb1), (E.V a, lb2, ub2) ::sx, rest) =                              val vD = gen (mapp, e)
55                      sumI1 (avaIR,  insert (v, lb1+i)  left, (a, tb (ub2-lb2+1), lb2), sx, rest)                              in
56                  | sumI1 (avaIR, left, (v, s::es, lb1), (E.V a, lb2, ub2) ::sx, rest) = let                                sumI1 (mapp, (v, es, lb1), [], rest@[vD])
57                      val dict = insert (v, (s+lb1))  left                              end
58                            | sumI1 (left, (v, [i], lb1), (E.V a, lb2, ub2) ::sx, rest) =
59                                sumI1 (IMap.insert (left, v, lb1+i), (a, tb (ub2-lb2+1), lb2), sx, rest)
60                            | sumI1 (left, (v, s::es, lb1), (E.V a, lb2, ub2) ::sx, rest) = let
61                                val mapp = IMap.insert (left, v, s+lb1)
62                      val xx = tb (ub2-lb2+1)                      val xx = tb (ub2-lb2+1)
63                      val  (avaIR, rest') = sumI1 (avaIR, dict, (a, xx, lb2), sx, rest)                              val rest' = sumI1 (mapp, (a, xx, lb2), sx, rest)
64                      in sumI1 (avaIR, dict, (v, es, lb1), (E.V a, lb2, ub2) ::sx, rest')  end                              in
65                  | sumI1 _ = raise FaIR"None Variable-index in summation"                                sumI1 (mapp, (v, es, lb1), (E.V a, lb2, ub2) ::sx, rest')
66                  val  (E.V v, lb, ub) = hd (sumx)                              end
67                  in                          | sumI1 _ = raise Fail "None Variable-index in summation"
68                      sumI1 (avaIR, !mapp, (v, tb (ub-lb+1), lb), tl (sumx), [])                        in
69                  end                          sumI1 (mapp, (v, tb (ub-lb+1), lb), sumx, [])
70              in  (case body                        end
71                  of E.Field _           => errField (body)                  in
72                  | E.Partial _          => errField (body)                    case body
73                  | E.Apply _            => errField (body)                     of E.Value v => Mk.intToRealLit (avail, mapIndex (mapp, v))
74                  | E.Probe _            => errField (body)                      | E.Const c => Mk.intToRealLit (avail, c)
75                  | E.Conv _             => errField (body)                      | E.Delta(i, j) => Mk.delta (avail, mapp, i, j)
76                  | E.Krn _              => errField (body)                      | E.Epsilon(i, j, k) => Mk.epsilon3 (avail, mapp, i, j, k)
77                  | E.Img _              => errField (body)                      | E.Eps2(i, j) => Mk.epsilon2 (avail, mapp, i, j)
78                  | E.Lift _             => errField (body)                      | E.Tensor(id, ix) =>
79                  | E.Value v            => Mk.intToRealLit (avaIR, mapIndex (!mapp, v))                          indexTensor (avail, mapp, (params, args, id, ix, Ty.realTy))
80                  | E.Const c            => Mk.intToRealLit (avaIR, c)                      | E.Op1(E.Neg, e1) =>
81                  | E.Delta _            => evalG (avaIR, !mapp, body)  (* QUESTION: why not just negate the tensor? *)
82                  | E.EpsIRon _          => evalG (avaIR, !mapp, body)                          mkProdSca (avail, [Mk.intToRealLit (avail, ~1), gen (mapp, e1)])
83                  | E.Eps2 _             => evalG (avaIR, !mapp, body)                      | E.Op1(op1, e1) => mkOp1 (op1, gen (mapp, e1))
84                  | E.Tensor (id, ix)    => indexTensor (avaIR, !mapp, (params, args, id, ix, Ty.TensorTy []))                      | E.Op2(E.Sub, e1, e2) => mkSubSca (avail, [gen (mapp, e1), gen (mapp, e2)])
85                  | E.Op1 (E.Neg, e1)    => let                      | E.Opn(E.Add, es) => let
86                      val (avaIR, vA) = gen (avaIR, e1)                          fun iter ([], ids) =
87                      val (avaIR, vB ) =  Mk.intToRealLit (avaIR, ~1)                                mkMultiple (avail, List.rev ids, Op.addSca, Ty.realTy)
88                      in                            | iter (e1::es, ids) = iter (es, (gen (mapp, e1))::ids)
89                        mkProdSca  (avaIR, [vB, vA])                          in
90                      end                            iter (es, [])
91                  | E.Op1 (op1, e1) => mkOp1 (op1, gen (avaIR, e1))                          end
92                  | E.Op2 (E.Sub, e1, e2)   => let                      | E.Opn(E.Prod, es) => let
93                      val  (avaIR, vA) = gen (avaIR, e1)                          fun iter ([], ids) =
94                      val  (avaIR, vB) = gen (avaIR, e2)                                mkMultiple (avail, List.rev ids, Op.prodSca, Ty.realTy)
95                      in                            | iter (e1::es, ids) = iter (es, (gen (mapp, e1))::ids)
96                        mkSubSca (avaIR, [vA, vB])                          in
97                      end                            iter (es, [])
98                  | E.Opn (E.Add, e)       =>                          end
                     let  
                         fun iter (avaIR, [], ids) =  mkMultiple (avaIR, List.rev ids, Op.addSca, Ty.TensorTy [])  
                           | iter (avaIR, e1::es, ids) = let  
                             val  (avaIR, a) = gen (avaIR, e1)  
                             in  iter (avaIR,es,a::ids) end  
                     in iter (avaIR, e, []) end  
                 | E.Opn (E.Prod, e)      =>  
                     let  
                         fun iter (avaIR, [], ids) =  mkMultiple (avaIR, List.rev ids, Op.prodSca, Ty.TensorTy [])  
                           | iter (avaIR, e1::es, ids) = let  
                         val  (avaIR, a) = gen (avaIR, e1)  
                         in  iter (avaIR,es,a::ids) end  
                     in iter (avaIR, e, []) end  
99                  | E.Op2 (E.Div, e1 as E.Tensor (_, [_]), e2 as E.Tensor (_, [])) =>                  | E.Op2 (E.Div, e1 as E.Tensor (_, [_]), e2 as E.Tensor (_, [])) =>
100                          gen (avaIR, E.Opn (E.Prod, [E.Op2 (E.Div, E.Const 1, e2), e1]))                          gen (mapp, E.Opn(E.Prod, [E.Op2 (E.Div, E.Const 1, e2), e1]))
101                  | E.Op2 (E.Div, e1, e2)    =>                      | E.Op2(E.Div, e1, e2) => mkDivSca (avail, [gen (mapp, e2)])
102                      let                      | E.Sum(x, E.Opn (E.Prod, (E.Img (Vid, _, _) ::E.Krn (Hid, _, _) ::_))) =>
103                          val  (avaIR, vA ) = gen (avaIR, e1)                          FieldToLow.expand (avail, mapp, params, body, args)
104                          val  (avaIR, vB) = gen (avaIR, e2)                      | E.Sum(sumx, e) => let
105                      in mkDivSca (avaIR, [vA, vB]) end                          val ids = sumCheck (mapp, sumx, e)
106                  | E.Sum (x, E.Opn (E.Prod, (E.Img (Vid, _, _) ::E.Krn (Hid, _, _) ::_)))                          in
107                                         => evalField (avaIR, !mapp, (body, info))                            mkMultiple (avail, ids, Op.addSca, Ty.realTy)
108                  | E.Sum (sumx, e)        =>                          end
109                      let                      | _ => raise Fail("unsupported ein-exp: " ^ EinPP.expToString body)
110                          val (avaIR,ids)= Sumcheck (avaIR, sumx, e)                    (*end case*)
                     in mkMultiple (avaIR, ids, Op.addSca, Ty.TensorTy []) end  
                 | _                    => raise FaIR"unsupported ein-exp "  
                  (*end case*))  
111                  end                  end
112           in           in
113             gen (setOrig, body)              gen (mapp, body)
114           end           end
115    
116      end      end

Legend:
Removed from v.3649  
changed lines
  Added in v.3653

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0