SCM Repository
Annotation of /branches/vis15/src/compiler/simplify/inliner.sml
Parent Directory
|
Revision Log
Revision 3485 - (view) (download)
1 : | jhr | 3437 | (* inliner.sml |
2 : | * | ||
3 : | * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu) | ||
4 : | * | ||
5 : | * COPYRIGHT (c) 2015 The University of Chicago | ||
6 : | * All rights reserved. | ||
7 : | * | ||
8 : | * This pass eliminates the function definitions by inlining them. | ||
9 : | *) | ||
10 : | |||
structure Inliner : sig

  (* inline-expand the user-defined function calls in the program; the resulting
   * program has an empty function list.
   *)
    val transform : Simple.program -> Simple.program

  end = struct

    structure S = Simple
    structure V = SimpleVar

  (* beta reduce the application "lhs = f(args)" by creating a fresh copy of f's body
   * while mapping the parameters to arguments.  Every variable that the body binds
   * (local "var" declarations and foreach iterators) is replaced by a fresh copy, so
   * inlining the same function more than once does not share variables.  Each
   * "return x" in the body becomes the assignment "lhs = x"; if some return occurs
   * inside a nested block (a conditional branch or a loop body), a declaration of
   * lhs (with no initializer) is prepended to the resulting block.
   *
   * Raises Fail if the body references a non-global variable that is neither a
   * parameter nor bound earlier in the body, or if the body contains a MapReduce.
   * Raises ListPair.UnequalLengths if the argument count differs from the
   * parameter count.
   *)
    fun beta (lhs, S.Func{params, body, ...}, args) = let
          val needsLHSPreDecl = ref false (* set to true if the lhs needs to be declared before the body *)
        (* map a variable to its copy; variables with global scope are shared, not copied *)
          fun rename env x = (case V.Map.find(env, x)
                 of SOME x' => x'
                  | NONE => if V.hasGlobalScope x
                      then x
                      else raise Fail("unknown variable " ^ V.uniqueNameOf x)
                (* end case *))
        (* make a fresh copy of a locally-bound variable and extend the environment with it *)
          fun bindFresh (env, x) = let
                val x' = V.copy(x, V.kindOf x)
                in
                  (V.Map.insert(env, x, x'), x')
                end
        (* copy a block; bindings introduced in the block do not escape it *)
          fun doBlock (env, isTop, S.Block stms) = let
                fun step (stm, (env, stms)) = let
                      val (env, stm) = doStmt (env, isTop, stm)
                      in
                        (env, stm::stms)
                      end
                val (_, stms) = List.foldl step (env, []) stms
                in
                  S.Block(List.rev stms)
                end
        (* copy a statement, returning the environment for the statements that follow it *)
          and doStmt (env, isTop, stm) = (case stm
                 of S.S_Var(x, optE) => let
                      val (env', x') = bindFresh (env, x)
                    (* the initializer is renamed in the scope that does not yet include x *)
                      val optE' = Option.map (doExp env) optE
                      in
                        (env', S.S_Var(x', optE'))
                      end
                  | S.S_Assign(x, e) => (env, S.S_Assign(rename env x, doExp env e))
                  | S.S_IfThenElse(x, b1, b2) =>
                      (env, S.S_IfThenElse(rename env x, doBlock(env, false, b1), doBlock(env, false, b2)))
                  | S.S_Foreach(x, xs, blk) => let
                    (* the iterator x is bound by the loop itself, so it gets a fresh copy
                     * that is visible only in the loop body; the sequence xs is renamed
                     * in the enclosing scope.
                     *)
                      val (bodyEnv, x') = bindFresh (env, x)
                      in
                        (env, S.S_Foreach(x', rename env xs, doBlock(bodyEnv, false, blk)))
                      end
                  | S.S_New(strnd, xs) => (env, S.S_New(strnd, List.map (rename env) xs))
                  | S.S_Continue => (env, stm)
                  | S.S_Die => (env, stm)
                  | S.S_Stabilize => (env, stm)
                  | S.S_Return x => (
                    (* a return inside a nested block assigns lhs in an inner scope, so lhs
                     * must be declared ahead of the inlined code.
                     *)
                      if not isTop then needsLHSPreDecl := true else ();
                      (env, S.S_Assign(lhs, S.E_Var(rename env x))))
                  | S.S_Print xs => (env, S.S_Print(List.map (rename env) xs))
                  | S.S_MapReduce _ => raise Fail "unexpected MapReduce in function"
                (* end case *))
        (* copy an expression by renaming its variable operands; function names in
         * E_Apply are left unchanged.
         *)
          and doExp env exp = (case exp
                 of S.E_Var x => S.E_Var(rename env x)
                  | S.E_Select(x, fld) => S.E_Select(rename env x, fld)
                  | S.E_Lit _ => exp
                  | S.E_Apply(f, xs, ty) => S.E_Apply(f, List.map (rename env) xs, ty)
                  | S.E_Prim(f, tys, xs, ty) =>
                      S.E_Prim(f, tys, List.map (rename env) xs, ty)
                  | S.E_Tensor(xs, ty) => S.E_Tensor(List.map (rename env) xs, ty)
                  | S.E_Seq(xs, ty) => S.E_Seq(List.map (rename env) xs, ty)
                  | S.E_Slice(x, xs, ty) =>
                      S.E_Slice(rename env x, List.map (Option.map (rename env)) xs, ty)
                  | S.E_Coerce{srcTy, dstTy, x} =>
                      S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=rename env x}
                  | S.E_LoadSeq _ => exp
                  | S.E_LoadImage _ => exp
                (* end case *))
        (* build the initial environment by mapping parameters to arguments *)
          val env = ListPair.foldlEq
                (fn (x, x', env) => V.Map.insert(env, x, x'))
                  V.Map.empty (params, args)
          val blk as S.Block stms = doBlock (env, true, body)
          in
            if !needsLHSPreDecl
              then S.Block(S.S_Var(lhs, NONE) :: stms)
              else blk
          end

  (* inline expand user-function calls in a block.  Calls of the form "x = f(args)",
   * where f is in funcTbl, are replaced by the beta-reduced body of f.  Calls to
   * functions not in the table are left in place.  Nested blocks (conditional
   * branches and foreach bodies) are expanded recursively.
   *)
    fun expandBlock funcTbl = let
          val findFunc = V.Tbl.find funcTbl
          fun expandBlk (S.Block stms) =
                S.Block(List.foldr expandStm [] stms)
          and expandStm (stm, stms') = (case stm
                 of S.S_Assign(x, S.E_Apply(f, xs, _)) => (case findFunc f
                       of NONE => stm :: stms'
                        | SOME func => let
                            val S.Block stms = beta(x, func, xs)
                            in
                              stms @ stms'
                            end
                      (* end case *))
                  | S.S_IfThenElse(x, b1, b2) =>
                      S.S_IfThenElse(x, expandBlk b1, expandBlk b2) :: stms'
                (* loop bodies may also contain calls; the iterator and sequence are unchanged *)
                  | S.S_Foreach(x, xs, blk) =>
                      S.S_Foreach(x, xs, expandBlk blk) :: stms'
                  | _ => stm :: stms'
                (* end case *))
          in
            expandBlk
          end

  (* inline expand the calls in a function's body and record the expanded function
   * in funcTbl, so that later callers inline the already-expanded body.
   *)
    fun expandFunc funcTbl (S.Func{f, params, body}) = let
          val body' = expandBlock funcTbl body
          val func' = S.Func{f=f, params=params, body=body'}
          in
            V.Tbl.insert funcTbl (f, func')
          end

  (* inline expand the calls in each of a strand's code blocks *)
    fun expandStrand funcTbl = let
          val expandBlock = expandBlock funcTbl
          fun expand (S.Strand{name, params, state, stateInit, initM, updateM, stabilizeM}) =
                S.Strand{
                    name = name,
                    params = params,
                    state = state,
                    stateInit = expandBlock stateInit,
                    initM = Option.map expandBlock initM,
                    updateM = expandBlock updateM,
                    stabilizeM = Option.map expandBlock stabilizeM
                  }
          in
            expand
          end

  (* a program with no functions is returned unchanged.  Otherwise the function
   * bodies are expanded in definition order.  Then the global init block, the
   * strand's code, the create block and the optional update block are expanded.
   * The resulting program has an empty function list.  The consts, inputs,
   * constInit and globals components are passed through unchanged.
   *)
    fun transform (prog as S.Program{funcs=[], ...}) = prog
      | transform prog = let
          val S.Program{props, consts, inputs, constInit, globals, funcs, init, strand, create, update} = prog
        (* a table that maps function names to their definitions *)
          val funcTbl = V.Tbl.mkTable (List.length funcs, Fail "funcTbl")
        (* first we inline expand the function bodies in definition order *)
          val _ = List.app (expandFunc funcTbl) funcs
          val expandBlock = expandBlock funcTbl
          in
            S.Program{
                props = props,
                consts = consts,
                inputs = inputs,
                constInit = constInit,
                globals = globals,
                init = expandBlock init,
                funcs = [],
                strand = expandStrand funcTbl strand,
                create = (case create
                   of S.Create{dim, code} => S.Create{dim = dim, code = expandBlock code}
                  (* end case *)),
                update = Option.map expandBlock update
              }
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |