Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 5570, Wed May 30 22:09:45 2018 UTC revision 5574, Thu May 31 22:28:40 2018 UTC
# Line 7  Line 7 
7   * COPYRIGHT (c) 2016 The University of Chicago   * COPYRIGHT (c) 2016 The University of Chicago
8   * All rights reserved.   * All rights reserved.
9   *)   *)
10    
11  structure TranslateCFExp : sig  structure TranslateCFExp : sig
12    
13    (* FIXME: add comment explaining function and arguments *)
14      val transform_CFExp : MidIR.var * Ein.ein * MidIR.var list      val transform_CFExp : MidIR.var * Ein.ein * MidIR.var list
15            -> MidIR.var list * Ein.param_kind list * Ein.ein_exp            -> MidIR.var list * Ein.param_kind list * Ein.ein_exp
16    
# Line 31  Line 33 
33        | paramToString (i, E.KRN) = "H" ^ i2s i        | paramToString (i, E.KRN) = "H" ^ i2s i
34        | paramToString (i, E.IMG(d, shp)) = concat["V", i2s i, "(", i2s d, ")[", shp2s shp, "]"]        | paramToString (i, E.IMG(d, shp)) = concat["V", i2s i, "(", i2s d, ")[", shp2s shp, "]"]
35    
     fun iterP es = let  
           fun iterPP ([], [r]) = r  
             | iterPP ([], rest) = E.Opn(E.Prod, rest)  
             | iterPP (E.Const 0::es, rest) = E.Const(0)  
             | iterPP (E.Const 1::es, rest) = iterPP(es, rest)  
             | iterPP (E.Delta(E.C c1, E.V v1)::E.Delta(E.C c2, E.V v2)::es, rest) =  
               (* variable can't be 0 and 1 '*)  
                 if (c1 = c2 orelse (not (v1 = v2)))  
                   then iterPP (es, E.Delta(E.C c1, E.V v1)::E.Delta(E.C c2, E.V v2)::rest)  
                   else  E.Const(0)  
             | iterPP (E.Opn(E.Prod, ys)::es, rest) = iterPP(ys@es, rest)  
             | iterPP (e1::es, rest) = iterPP(es, e1::rest)  
           in  
             iterPP(es, [])  
           end  
   
     fun iterA es =  let  
           fun iterAA ([], []) = E.Const 0  
             | iterAA ([], [r]) = r  
             | iterAA ([], rest) = E.Opn(E.Add, rest)  
             | iterAA (E.Const 0::es, rest) = iterAA(es, rest)  
             | iterAA (E.Opn(E.Add, ys)::es, rest) = iterAA(ys@es, rest)  
             | iterAA (e1::es, rest) = iterAA(es, e1::rest)  
           in  
             iterAA(es, [])  
           end  
   
36   (* The terms with a param_id in the mapp are replaced   (* The terms with a param_id in the mapp are replaced
37    * body - ein expression    * body - ein expression
38    * args - variable arguments    * args - variable arguments
# Line 96  Line 71 
71                    | E.Op1(E.PowInt n, e1) => let                    | E.Op1(E.PowInt n, e1) => let
72                          val tmp = rewrite e1                          val tmp = rewrite e1
73                          in                          in
74                            iterP (List.tabulate(n, fn _ => tmp))                          EinUtil.iterPP (List.tabulate(n, fn _ => tmp))
75                          end                          end
76                    | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)                    | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
77                    | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)                    | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
78                    | E.Opn(E.Prod, E.Opn(E.Add, ps)::es) => let                    | E.Opn(E.Prod, E.Opn(E.Add, ps)::es) => let
79                          val ps = List.map (fn e1 => iterP(e1::es)) ps                        val ps = List.map (fn e1 => EinUtil.iterPP(e1::es)) ps
80                          val body = E.Opn(E.Add, ps)                          val body = E.Opn(E.Add, ps)
81                          in                          in
82                            rewrite body                            rewrite body
83                          end                          end
84                    | E.Opn(E.Prod, ps) => iterP(List.map rewrite ps)                    | E.Opn(E.Prod, ps) => EinUtil.iterPP(List.map rewrite ps)
85                    | E.Opn(E.Add , ps) => iterA(List.map rewrite ps)                    | E.Opn(E.Add , ps) => EinUtil.iterAA(List.map rewrite ps)
86                    | _ => body                    | _ => body
87                  (* end case*))                  (* end case*))
88            in            in
# Line 127  Line 102 
102          (* rewrites a single variable          (* rewrites a single variable
103          * rewritement instances of arg at pid position with arg at idx position          * rewritement instances of arg at pid position with arg at idx position
104          *)          *)
105            fun single_TF (pid, args, params, idx, e) = let            fun singleTF (pid, args, params, idx, e) = let
106                      (*check if the current parameter is a sequence and get dimension*)                      (*check if the current parameter is a sequence and get dimension*)
107                      (*Note Dev branch supports sequence parameter*)                      (*Note Dev branch supports sequence parameter*)
108                      val dim = (case List.nth(params, idx)                      val dim = (case List.nth(params, idx)
# Line 136  Line 111 
111                              | p => raise Fail("unsupported argument type:"^paramToString(idx, p))                              | p => raise Fail("unsupported argument type:"^paramToString(idx, p))
112                          (* end case *))                          (* end case *))
113                      (*variable arg, and param*)                      (*variable arg, and param*)
114                      val arg_new = List.nth(args, idx)                      val newArg = List.nth(args, idx)
115                      val param_new = List.nth(params, idx)                      val newParam = List.nth(params, idx)
116                      val arg_rewrited = List.nth(args, pid)                      val rwArg = List.nth(args, pid)
117                      (*id keeps track of placement and puts it in mapp*)                      (*id keeps track of placement and puts it in mapp*)
118                      fun findArg(_, es, newargs, [], newparams, mapp) = ((List.rev newargs)@es, List.rev newparams, mapp)                      fun findArg(_, es, newargs, [], newparams, mapp) =
119                              (List.revAppend(newargs, es), List.rev newparams, mapp)
120                      | findArg(id, e1::es, newargs, p1::ps, newparams, mapp) =                      | findArg(id, e1::es, newargs, p1::ps, newparams, mapp) =
121                          if(IR.Var.same(e1, arg_rewrited))                            if (IR.Var.same(e1, rwArg))
122                          then findArg(id+1, es, arg_new::newargs, ps, param_new::newparams, ISet.add(mapp, id))                              then findArg(id+1, es, newArg::newargs, ps, newParam::newparams, ISet.add(mapp, id))
123                          else findArg(id+1, es, e1::newargs, ps , p1::newparams, mapp)                          else findArg(id+1, es, e1::newargs, ps , p1::newparams, mapp)
124                      val (args, params, mapp) = findArg(0, args, [], params, [], ISet.empty)                      val (args, params, mapp) = findArg(0, args, [], params, [], ISet.empty)
125                      (* get dimension of vector that is being broken into components*)                      (* get dimension of vector that is being broken into components*)
126                      val param_pos = List.nth(params, pid)                      val param_pos = List.nth(params, pid)
127                      (* rewrite position tensor with deltas in body *)                      (* rewrite position tensor with deltas in body *)
128                      val e = replace (e, dim, mapp)                      val e = replace (e, dim, mapp)
129                  in (args, params, e) end                  in
130                      (args, params, e)
131                    end
132          (*iterate over all the input tensor variable expressions *)          (*iterate over all the input tensor variable expressions *)
133            fun iter ([], args, params, _, e) = (args, params, e)            fun iter ([], args, params, _, e) = (args, params, e)
134              | iter ((pid, E.T)::es, args, params, idx::idxs, e) = let              | iter ((pid, E.T)::es, args, params, idx::idxs, e) = let
# Line 161  Line 139 
139                  end                  end
140              | iter ((pid, E.F)::es, args, params, idx::idxs, e) = let              | iter ((pid, E.F)::es, args, params, idx::idxs, e) = let
141                (*variable is treated as a field so it needs to be expanded into its components*)                (*variable is treated as a field so it needs to be expanded into its components*)
142                  val (args, params, e) = single_TF (pid, args, params, idx, e)                  val (args, params, e) = singleTF (pid, args, params, idx, e)
143                  in                  in
144                    iter(es, args, params, idxs, e)                    iter(es, args, params, idxs, e)
145                  end                  end
# Line 196  Line 174 
174          (*check that the number of into parameters matches number of probed arguments*)          (*check that the number of into parameters matches number of probed arguments*)
175            val n_pargs = length(cfexp_ids)            val n_pargs = length(cfexp_ids)
176            val n_probe = length(probe_ids)            val n_probe = length(probe_ids)
177            val _ = if(not(n_pargs = n_probe))            val _ = if (n_pargs <> n_probe)
178                    then raise  Fail(concat[" n_pargs:", Int.toString( n_pargs), "n_probe:", Int.toString(n_probe)])                      then raise Fail(concat[
179                    else 1                          "n_pargs:", Int.toString( n_pargs), "n_probe:", Int.toString(n_probe)
180                          ])
181                        else ()
182          (* replace polywrap args/params with probed position(s) args/params *)          (* replace polywrap args/params with probed position(s) args/params *)
183            val (args, params, e) = polyArgs(params, e, args, cfexp_ids, probe_ids)            val (args, params, e) = polyArgs(params, e, args, cfexp_ids, probe_ids)
184          (* normalize ein by cleaning it up and differntiating*)          (* normalize ein by cleaning it up and differntiating*)

Legend:
Removed from v.5570  
changed lines
  Added in v.5574

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0