Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/low-opt/low-contract.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/low-opt/low-contract.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3805 - (view) (download)

1 : jhr 3714 (* low-contract.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 :     * Contraction phase for LowIR.
9 :     *)
10 :    
11 :     structure LowContract : sig
12 :    
13 :     val transform : LowIR.program -> LowIR.program
14 :    
15 :     end = struct
16 :    
17 :     structure IR = LowIR
18 :     structure Op = LowOps
19 :     structure Ty = LowTypes
20 :     structure V = IR.Var
21 :     structure ST = Stats
22 :    
23 :     (********** Counters for statistics **********)
24 : jhr 3799 val cntAddNeg = ST.newCounter "low-contract:add-neg"
25 :     val cntSubNeg = ST.newCounter "low-contract:sub-neg"
26 :     val cntSubSame = ST.newCounter "low-contract:sub-same"
27 :     val cntNegNeg = ST.newCounter "low-contract:neg-neg"
28 :     val cntIntToReal = ST.newCounter "low-contract:int-to-real"
29 :     val cntTensorIndex = ST.newCounter "low-contract:tensor-index"
30 :     val cntProjectLast = ST.newCounter "low-contract:project-last"
31 :     val cntUnused = ST.newCounter "low-contract:unused"
32 : jhr 3714 val firstCounter = cntAddNeg
33 :     val lastCounter = cntUnused
34 :    
35 :     structure UnusedElim = UnusedElimFn (
36 :     structure IR = IR
37 :     val cntUnused = cntUnused)
38 :    
39 :     fun useCount (IR.V{useCnt, ...}) = !useCnt
40 :    
41 :     (* adjust a variable's use count *)
42 :     fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
43 :     fun decUse (IR.V{useCnt, ...}) = (useCnt := !useCnt - 1)
44 :     fun use x = (incUse x; x)
45 :    
46 :     fun getRHSOpt x = (case V.getDef x
47 :     of IR.OP arg => SOME arg
48 :     | _ => NONE
49 :     (* end case *))
50 :    
51 :     (* TODO: tensor selection operations *)
52 :     fun doAssign (lhs, IR.OP rhs) = (case rhs
53 :     of (Op.IAdd, [a, b]) => (case getRHSOpt b
54 :     of SOME(Op.INeg, [c]) => (
55 :     (* rewrite to "a-c" *)
56 :     ST.tick cntAddNeg;
57 :     decUse b;
58 :     SOME[(lhs, IR.OP(Op.ISub, [a, use c]))])
59 :     | _ => NONE
60 :     (* end case *))
61 :     | (Op.ISub, [a, b]) => if IR.Var.same(a, b)
62 :     then ( (* rewrite to 0 *)
63 :     ST.tick cntSubSame;
64 :     decUse a; decUse b;
65 :     SOME[(lhs, IR.LIT(Literal.Int 0))])
66 :     else (case getRHSOpt b
67 :     of SOME(Op.INeg, [c]) => (
68 :     (* rewrite to "a+c" *)
69 :     ST.tick cntSubNeg;
70 :     decUse b;
71 :     SOME[(lhs, IR.OP(Op.IAdd, [a, use c]))])
72 :     | _ => NONE
73 :     (* end case *))
74 :     | (Op.INeg, [a]) => (case getRHSOpt a
75 :     of SOME(Op.INeg, [b]) => (
76 :     (* rewrite to "b" *)
77 :     ST.tick cntNegNeg;
78 :     decUse a;
79 :     SOME[(lhs, IR.VAR(use b))])
80 :     | _ => NONE
81 :     (* end case *))
82 :     | (Op.IntToReal, [a]) => (case V.getDef a
83 :     of IR.LIT(Literal.Int n) => (
84 :     (* rerite to a real literal *)
85 :     ST.tick cntIntToReal;
86 :     decUse a;
87 :     SOME[(lhs, IR.LIT(Literal.Real(RealLit.fromInt n)))])
88 :     | _ => NONE
89 :     (* end case *))
90 : jhr 3799 | (Op.TensorIndex(Ty.TensorTy dims, idxs), [t]) => let
91 :     fun get ([], [], x) = (
92 :     SOME[(lhs, IR.VAR(use x))])
93 :     | get (ix::ixs, d::ds, x) = (case V.getDef x
94 :     of IR.CONS(ys, _) => get(ixs, ds, List.nth(ys, ix))
95 : jhr 3805 | _ => SOME[(lhs, IR.OP(Op.TensorIndex(Ty.tensorTy(d::ds), ix::ixs), [use x]))]
96 : jhr 3799 (* end case *))
97 :     | get _ = raise Fail "malformed TensorIndex"
98 :     in
99 :     case V.getDef t
100 :     of IR.CONS _ => (ST.tick cntTensorIndex; decUse t; get(idxs, dims, t))
101 :     | _ => NONE
102 :     (* end case *)
103 :     end
104 :     | (Op.ProjectLast(Ty.TensorTy dims, idxs), [t]) => let
105 :     fun get ([], [_], x) = (
106 :     SOME[(lhs, IR.VAR(use x))])
107 :     | get (ix::ixs, d::ds, x) = (case V.getDef x
108 :     of IR.CONS(ys, _) => get(ixs, ds, List.nth(ys, ix))
109 :     | _ => SOME[(lhs, IR.OP(Op.ProjectLast(Ty.tensorTy ds, ix::ixs), [use x]))]
110 :     (* end case *))
111 :     | get _ = raise Fail "malformed ProjectLast"
112 :     in
113 :     case V.getDef t
114 :     of IR.CONS _ => (ST.tick cntProjectLast; decUse t; get(idxs, dims, t))
115 :     | _ => NONE
116 :     (* end case *)
117 :     end
118 : jhr 3714 | _ => NONE
119 :     (* end case *))
120 :     | doAssign _ = NONE
121 :    
122 : jhr 3747 fun doMAssign _ = NONE
123 : jhr 3714
124 :     structure Rewrite = RewriteFn (
125 :     struct
126 :     structure IR = IR
127 :     val doAssign = doAssign
128 :     val doMAssign = doMAssign
129 :     val elimUnusedVars = UnusedElim.reduce
130 :     end)
131 :    
132 :     val transform = Rewrite.transform
133 :    
134 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0