Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/low-opt/low-contract.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/low-opt/low-contract.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3814 - (view) (download)

1 : jhr 3714 (* low-contract.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * Contraction phase for LowIR.
9 :     *)
10 :    
structure LowContract : sig

  (* apply the LowIR contraction rewrites (with dead-variable elimination)
   * to a program.
   *)
    val transform : LowIR.program -> LowIR.program

  end = struct

    structure IR = LowIR
    structure Op = LowOps
    structure Ty = LowTypes
    structure V = IR.Var
    structure ST = Stats

  (********** Counters for statistics **********)
    val cntAddNeg       = ST.newCounter "low-contract:add-neg"
    val cntSubNeg       = ST.newCounter "low-contract:sub-neg"
    val cntSubSame      = ST.newCounter "low-contract:sub-same"
    val cntNegNeg       = ST.newCounter "low-contract:neg-neg"
    val cntIntToReal    = ST.newCounter "low-contract:int-to-real"
    val cntTensorIndex  = ST.newCounter "low-contract:tensor-index"
    val cntProjectLast  = ST.newCounter "low-contract:project-last"
    val cntUnused       = ST.newCounter "low-contract:unused"
    val firstCounter    = cntAddNeg
    val lastCounter     = cntUnused

  (* unused-variable elimination; it reports through the cntUnused counter *)
    structure UnusedElim = UnusedElimFn (
        structure IR = IR
        val cntUnused = cntUnused)

    fun useCount (IR.V{useCnt, ...}) = !useCnt

  (* adjust a variable's use count *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
    fun decUse (IR.V{useCnt, ...}) = (useCnt := !useCnt - 1)

  (* record one additional use of x and return x *)
    fun use x = (incUse x; x)

  (* get the local definition of a variable. Unlike getDef, this function does
   * not chase through global definitions, which means that we do not have to
   * worry about lifting local variables in the globalInit to global scope.
   * VAR-to-VAR copies are followed; a variable that is not bound to a
   * right-hand side is returned as "IR.VAR x".
   *)
    fun getLocalDef (x as IR.V{bind, ...}) = (case !bind
           of IR.VB_RHS rhs => (case rhs
                 of IR.VAR y => getLocalDef y
                  | _ => rhs
                (* end case *))
            | _ => IR.VAR x
          (* end case *))

  (* return SOME(rator, args) when x is locally defined by an operator
   * application, and NONE otherwise.  This uses getLocalDef rather than
   * V.getDef: the rewrites below splice the operator's arguments into the
   * current assignment, and chasing through a global definition could hand
   * back variables that are local to the globalInit block, which are not in
   * scope at the rewrite site.
   *)
    fun getRHSOpt x = (case getLocalDef x
           of IR.OP arg => SOME arg
            | _ => NONE
          (* end case *))

  (* contract a selection operator (TensorIndex or ProjectLast) that is applied
   * to a tensor t whose local definition is a CONS.  One index is consumed per
   * level, and each step moves into the selected CONS argument.  The walk ends
   * in one of two ways:
   *   - the selection is complete, as decided by isDone on the remaining
   *     (indices, dims); lhs is then bound to the selected component; or
   *   - a component whose local definition is not a CONS is reached; the
   *     residual selection is then rebuilt with mkOp over the remaining dims
   *     and indices.
   * Returns NONE (and changes nothing) when t is not locally defined by a
   * CONS.  Raises Fail "malformed <name>" when indices and dims disagree.
   *)
    fun contractSelect {lhs, cntr, name, isDone, mkOp} (dims, idxs, t) = let
          fun get (ixs, ds, x) = if isDone (ixs, ds)
                then SOME[(lhs, IR.VAR(use x))]
                else (case (ixs, ds)
                   of (ix::ixs', _::ds') => (case getLocalDef x
                         of IR.CONS(ys, _) => get (ixs', ds', List.nth(ys, ix))
                          | _ => SOME[
                                (lhs, IR.OP(mkOp(Ty.tensorTy ds, ixs), [use x]))
                              ]
                        (* end case *))
                    | _ => raise Fail("malformed " ^ name)
                  (* end case *))
          in
            case getLocalDef t
             of IR.CONS _ => (ST.tick cntr; decUse t; get (idxs, dims, t))
              | _ => NONE
            (* end case *)
          end

  (* rewrite a single assignment "lhs = rhs".  Returns SOME of the replacement
   * assignments when a contraction applies, and NONE otherwise.  Use counts
   * are updated so that definitions that become dead can be removed by
   * UnusedElim.
   * TODO: tensor selection operations
   *)
    fun doAssign (lhs, IR.OP rhs) = (case rhs
           of (Op.IAdd, [a, b]) => (case getRHSOpt b
                 of SOME(Op.INeg, [c]) => (
                    (* rewrite "a + (-c)" to "a - c" *)
                      ST.tick cntAddNeg;
                      decUse b;
                      SOME[(lhs, IR.OP(Op.ISub, [a, use c]))])
                  | _ => NONE
                (* end case *))
            | (Op.ISub, [a, b]) => if V.same(a, b)
                then ( (* rewrite "a - a" to 0; both argument uses go away *)
                  ST.tick cntSubSame;
                  decUse a; decUse b;
                  SOME[(lhs, IR.LIT(Literal.Int 0))])
                else (case getRHSOpt b
                   of SOME(Op.INeg, [c]) => (
                      (* rewrite "a - (-c)" to "a + c" *)
                        ST.tick cntSubNeg;
                        decUse b;
                        SOME[(lhs, IR.OP(Op.IAdd, [a, use c]))])
                    | _ => NONE
                  (* end case *))
            | (Op.INeg, [a]) => (case getRHSOpt a
                 of SOME(Op.INeg, [b]) => (
                    (* rewrite "-(-b)" to "b" *)
                      ST.tick cntNegNeg;
                      decUse a;
                      SOME[(lhs, IR.VAR(use b))])
                  | _ => NONE
                (* end case *))
            | (Op.IntToReal, [a]) => (case V.getDef a
                 of IR.LIT(Literal.Int n) => (
                    (* rewrite to a real literal.  Chasing through global
                     * definitions with V.getDef is safe here, because the
                     * result is a literal and references no variables.
                     *)
                      ST.tick cntIntToReal;
                      decUse a;
                      SOME[(lhs, IR.LIT(Literal.Real(RealLit.fromInt n)))])
                  | _ => NONE
                (* end case *))
            | (Op.TensorIndex(Ty.TensorTy dims, idxs), [t]) =>
                contractSelect {
                    lhs = lhs, cntr = cntTensorIndex, name = "TensorIndex",
                  (* every index consumed and no dimensions left: a scalar *)
                    isDone = (fn ([], []) => true | _ => false),
                    mkOp = Op.TensorIndex
                  } (dims, idxs, t)
            | (Op.ProjectLast(Ty.TensorTy dims, idxs), [t]) =>
                contractSelect {
                    lhs = lhs, cntr = cntProjectLast, name = "ProjectLast",
                  (* every index consumed and only the last axis left *)
                    isDone = (fn ([], [_]) => true | _ => false),
                    mkOp = Op.ProjectLast
                  } (dims, idxs, t)
            | _ => NONE
          (* end case *))
      | doAssign _ = NONE

  (* no contractions are currently defined for multi-assignments *)
    fun doMAssign _ = NONE

    structure Rewrite = RewriteFn (
      struct
        structure IR = IR
        val doAssign = doAssign
        val doMAssign = doMAssign
        val elimUnusedVars = UnusedElim.reduce
      end)

    val transform = Rewrite.transform

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0