Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/low-opt/low-contract.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/low-opt/low-contract.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4317 - (view) (download)

1 : jhr 3714 (* low-contract.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * Contraction phase for LowIR.
9 :     *)
10 :    
11 :     structure LowContract : sig
12 :    
13 :     val transform : LowIR.program -> LowIR.program
14 :    
15 :     end = struct
16 :    
17 :     structure IR = LowIR
18 :     structure Op = LowOps
19 :     structure Ty = LowTypes
20 :     structure V = IR.Var
21 :     structure ST = Stats
22 :    
23 :     (********** Counters for statistics **********)
24 : jhr 3799 val cntAddNeg = ST.newCounter "low-contract:add-neg"
25 : jhr 3955 val cntAddConst = ST.newCounter "low-contract:add-const"
26 : jhr 3799 val cntSubNeg = ST.newCounter "low-contract:sub-neg"
27 :     val cntSubSame = ST.newCounter "low-contract:sub-same"
28 :     val cntNegNeg = ST.newCounter "low-contract:neg-neg"
29 :     val cntIntToReal = ST.newCounter "low-contract:int-to-real"
30 : jhr 4317 val cntTensorIndex = ST.newCounter "low-contract:tensor-index"
31 :     val cntProjectLast = ST.newCounter "low-contract:project-last"
32 :     val cntSubscript = ST.newCounter "low-contract:subscript"
33 : jhr 3799 val cntUnused = ST.newCounter "low-contract:unused"
34 : jhr 3714 val firstCounter = cntAddNeg
35 :     val lastCounter = cntUnused
36 :    
37 :     structure UnusedElim = UnusedElimFn (
38 :     structure IR = IR
39 :     val cntUnused = cntUnused)
40 :    
41 :     fun useCount (IR.V{useCnt, ...}) = !useCnt
42 :    
43 :     (* adjust a variable's use count *)
44 :     fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
45 :     fun decUse (IR.V{useCnt, ...}) = (useCnt := !useCnt - 1)
46 :     fun use x = (incUse x; x)
47 :    
48 :     fun getRHSOpt x = (case V.getDef x
49 :     of IR.OP arg => SOME arg
50 :     | _ => NONE
51 :     (* end case *))
52 :    
53 : jhr 3814 (* get the local definition of a variable. Unlike getDef, this function does
54 :     * not chase through global definitions, which means that we do not have to
55 :     * worry about lifting local variables in the globalInit to global scope.
56 :     *)
57 :     fun getLocalDef (x as IR.V{bind, ...}) = (case !bind
58 : jhr 4317 of IR.VB_RHS rhs => (case rhs
59 :     of IR.VAR x => getLocalDef x
60 :     | _ => rhs
61 :     (* end case *))
62 :     | _ => IR.VAR x
63 :     (* end case *))
64 : jhr 3814
65 : jhr 3714 (* TODO: tensor selection operations *)
66 :     fun doAssign (lhs, IR.OP rhs) = (case rhs
67 : jhr 3955 of (Op.IAdd, [a, b]) => (case (V.getDef a, V.getDef b)
68 : jhr 4317 of (IR.LIT(Literal.Int a'), IR.LIT(Literal.Int b')) => (
69 : jhr 3955 (* rewrite to sum of a' and b' *)
70 :     ST.tick cntAddConst;
71 :     decUse a; decUse b;
72 :     SOME[(lhs, IR.LIT(Literal.Int(a' + b')))])
73 : jhr 4317 | (IR.LIT(Literal.Int 0), _) => (
74 : jhr 3955 (* rewrite to b *)
75 :     ST.tick cntAddConst;
76 :     decUse a;
77 :     SOME[(lhs, IR.VAR b)])
78 : jhr 4317 | (_, IR.LIT(Literal.Int 0)) => (
79 : jhr 3955 (* rewrite to a *)
80 :     ST.tick cntAddConst;
81 :     decUse b;
82 :     SOME[(lhs, IR.VAR a)])
83 : jhr 4317 | (_, IR.OP(Op.INeg, [c])) => (
84 : jhr 3714 (* rewrite to "a-c" *)
85 :     ST.tick cntAddNeg;
86 :     decUse b;
87 :     SOME[(lhs, IR.OP(Op.ISub, [a, use c]))])
88 :     | _ => NONE
89 :     (* end case *))
90 :     | (Op.ISub, [a, b]) => if IR.Var.same(a, b)
91 :     then ( (* rewrite to 0 *)
92 :     ST.tick cntSubSame;
93 :     decUse a; decUse b;
94 :     SOME[(lhs, IR.LIT(Literal.Int 0))])
95 :     else (case getRHSOpt b
96 :     of SOME(Op.INeg, [c]) => (
97 :     (* rewrite to "a+c" *)
98 :     ST.tick cntSubNeg;
99 :     decUse b;
100 :     SOME[(lhs, IR.OP(Op.IAdd, [a, use c]))])
101 :     | _ => NONE
102 :     (* end case *))
103 :     | (Op.INeg, [a]) => (case getRHSOpt a
104 :     of SOME(Op.INeg, [b]) => (
105 :     (* rewrite to "b" *)
106 :     ST.tick cntNegNeg;
107 :     decUse a;
108 :     SOME[(lhs, IR.VAR(use b))])
109 :     | _ => NONE
110 :     (* end case *))
111 :     | (Op.IntToReal, [a]) => (case V.getDef a
112 :     of IR.LIT(Literal.Int n) => (
113 :     (* rerite to a real literal *)
114 :     ST.tick cntIntToReal;
115 :     decUse a;
116 :     SOME[(lhs, IR.LIT(Literal.Real(RealLit.fromInt n)))])
117 :     | _ => NONE
118 :     (* end case *))
119 : jhr 4317 | (Op.TensorIndex(Ty.TensorTy dims, idxs), [t]) => let
120 :     fun get ([], [], x) = (
121 :     SOME[(lhs, IR.VAR(use x))])
122 :     | get (ix::ixs, d::ds, x) = (case getLocalDef x
123 :     of IR.CONS(ys, _) => get(ixs, ds, List.nth(ys, ix))
124 :     | _ => let
125 :     val rator = if List.null ds
126 :     then Op.VIndex(d, ix)
127 :     else Op.TensorIndex(Ty.tensorTy(d::ds), ix::ixs)
128 :     in
129 :     SOME[(lhs, IR.OP(rator, [use x]))]
130 :     end
131 :     (* end case *))
132 :     | get _ = raise Fail "malformed TensorIndex"
133 :     in
134 :     case getLocalDef t
135 :     of IR.CONS _ => (ST.tick cntTensorIndex; decUse t; get(idxs, dims, t))
136 :     | _ => NONE
137 :     (* end case *)
138 :     end
139 :     | (Op.ProjectLast(Ty.TensorTy dims, idxs), [t]) => let
140 :     fun get ([], [_], x) = (
141 :     SOME[(lhs, IR.VAR(use x))])
142 :     | get (ix::ixs, d::ds, x) = (case getLocalDef x
143 :     of IR.CONS(ys, _) => get(ixs, ds, List.nth(ys, ix))
144 :     | _ => SOME[
145 :     (lhs, IR.OP(Op.ProjectLast(Ty.tensorTy(d::ds), ix::ixs), [use x]))
146 :     ]
147 :     (* end case *))
148 :     | get _ = raise Fail "malformed ProjectLast"
149 :     in
150 :     case getLocalDef t
151 :     of IR.CONS _ => (ST.tick cntProjectLast; decUse t; get(idxs, dims, t))
152 :     | _ => NONE
153 :     (* end case *)
154 :     end
155 :     | (Op.Subscript ty, [seq, idx]) => (case (getLocalDef seq, V.getDef idx)
156 :     of (IR.SEQ(vs, _), IR.LIT(Literal.Int i)) => (
157 :     ST.tick cntSubscript; decUse seq; decUse idx;
158 :     SOME[(lhs, IR.VAR(use (List.nth(vs, Int.fromLarge i))))])
159 :     | _ => NONE
160 :     (* end case *))
161 : jhr 3714 | _ => NONE
162 :     (* end case *))
163 :     | doAssign _ = NONE
164 :    
165 : jhr 3747 fun doMAssign _ = NONE
166 : jhr 3714
167 :     structure Rewrite = RewriteFn (
168 :     struct
169 :     structure IR = IR
170 :     val doAssign = doAssign
171 :     val doMAssign = doMAssign
172 :     val elimUnusedVars = UnusedElim.reduce
173 :     end)
174 :    
175 :     val transform = Rewrite.transform
176 :    
177 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0