Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis12/src/compiler/typechecker/util.sml
ViewVC logotype

Annotation of /branches/vis12/src/compiler/typechecker/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 81 - (view) (download)
Original Path: trunk/src/typechecker/util.sml

1 : jhr 80 (* util.sml
2 :     *
3 :     * COPYRIGHT (c) 2010 The Diderot Project (http://diderot.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *
6 :     * Utilities for typechecking
7 :     *)
8 :    
structure Util =
  struct

    structure Ty = Types
    structure MV = MetaVar

  (* prune out instantiated meta variables from a type.  Bound type, differentiation,
   * dimension, and shape variables are replaced by (the pruned form of) their bindings;
   * unbound variables are left in place.
   *)
    fun prune ty = let
          fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind
                 of NONE => ty
                  | SOME ty => prune' ty
                (* end case *))
            | prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff)
            | prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape)
            | prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{
                  diff = pruneDiff diff,
                  dim = pruneDim dim,
                  shape = pruneShape shape
                }
            | prune' (Ty.T_Fun(tys1, ty2)) = Ty.T_Fun(List.map prune' tys1, prune' ty2)
            | prune' ty = ty
          in
            prune' ty
          end

  (* prune a differentiation level.  DiffVar(dv, i) stands for the value of dv plus the
   * offset i, so offsets are accumulated while following a chain of bindings.
   *)
    and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = (
          case pruneDiff diff
           of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i')
            | Ty.DiffConst i' => Ty.DiffConst(i+i')
          (* end case *))
      | pruneDiff diff = diff

  (* prune a dimension by following the bindings of dimension variables *)
    and pruneDim dim = (case dim
           of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim
            | dim => dim
          (* end case *))

  (* prune a shape.  The dimensions of an explicit shape are pruned too, so that no
   * instantiated dimension variable survives inside a pruned shape.
   *)
    and pruneShape shape = (case shape
           of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape
            | Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim)
            | Ty.Shape dd => Ty.Shape(List.map pruneDim dd)
            | _ => shape
          (* end case *))

  (* match two differentiation levels, binding variables as necessary; returns true on
   * success.  Since DiffVar(dv, i) stands for dv+i (see pruneDiff), matching it against
   * the constant k requires dv = k-i, which must be non-negative.
   * FIXME: what about the bounds?
   *)
    fun matchDiff (diff1, diff2) = let
        (* bind the variable of DiffVar(dv, i) so that dv+i = k *)
          fun bindConst (Ty.DfV{bind, ...}, i, k) = let
                val k' = k-i
                in
                  if k' < 0 then false
                  else (bind := SOME(Ty.DiffConst k'); true)
                end
          in
            case (pruneDiff diff1, pruneDiff diff2)
             of (Ty.DiffConst k1, Ty.DiffConst k2) => (k1 = k2)
              | (Ty.DiffConst k, Ty.DiffVar(dv, i)) => bindConst (dv, i, k)
              | (Ty.DiffVar(dv, i), Ty.DiffConst k) => bindConst (dv, i, k)
              | (Ty.DiffVar(Ty.DfV{bind=b1, ...}, i1), Ty.DiffVar(Ty.DfV{bind=b2, ...}, i2)) =>
                (* the same variable (identified by its binding cell): dv+i1 = dv+i2 iff i1 = i2 *)
                  if b1 = b2
                    then (i1 = i2)
                    else raise Fail "unimplemented" (* FIXME *)
            (* end case *)
          end

  (* match two dimensions, binding a dimension variable if necessary; returns true on
   * success.  A variable matched against itself is left unbound, since binding it to
   * itself would make pruneDim loop forever.
   *)
    fun matchDim (dim1, dim2) = (case (pruneDim dim1, pruneDim dim2)
           of (Ty.DimConst d1, Ty.DimConst d2) => (d1 = d2)
            | (Ty.DimVar(Ty.DV{bind=b1, ...}), dim2 as Ty.DimVar(Ty.DV{bind=b2, ...})) => (
                if b1 = b2 then () else (b1 := SOME dim2);
                true)
            | (Ty.DimVar(Ty.DV{bind, ...}), dim2) => (bind := SOME dim2; true)
            | (dim1, Ty.DimVar(Ty.DV{bind, ...})) => (bind := SOME dim1; true)
          (* end case *))

  (* match two shapes, binding dimension and shape variables as necessary; returns true
   * on success.  A shape variable is never bound to a shape that contains it (that
   * would be a cyclic shape on which pruneShape does not terminate).
   *)
    fun matchShape (shape1, shape2) = let
        (* does the shape variable whose binding cell is b occur in shape? *)
          fun occurs (b, Ty.ShapeVar(Ty.SV{bind, ...})) = (b = bind) orelse (case !bind
                 of SOME shape => occurs (b, shape)
                  | NONE => false
                (* end case *))
            | occurs (b, Ty.ShapeExt(shape, _)) = occurs (b, shape)
            | occurs _ = false
          in
            case (pruneShape shape1, pruneShape shape2)
             of (Ty.Shape dd1, Ty.Shape dd2) => ListPair.allEq matchDim (dd1, dd2)
              | (Ty.Shape dd, Ty.ShapeExt(shape, d2)) => let
                (* split dd into its front and its last element; the last element must
                 * match d2 and the front must match shape.
                 *)
                  fun chk ([], _) = false
                    | chk ([d], revDD) =
                        matchDim(d, d2) andalso matchShape(Ty.Shape(List.rev revDD), shape)
                    | chk (d::dd, revDD) = chk(dd, d::revDD)
                  in
                    chk (dd, [])
                  end
              | (Ty.ShapeVar(Ty.SV{bind, ...}), shape) =>
                  if not(occurs(bind, shape))
                    then (bind := SOME shape; true)
                    else (case shape
                       of Ty.ShapeVar _ => true  (* the same variable on both sides *)
                        | _ => false             (* sv = ...sv... has no finite solution *)
                      (* end case *))
              | (Ty.ShapeExt(shape1, d1), Ty.ShapeExt(shape2, d2)) =>
                  matchDim(d1, d2) andalso matchShape(shape1, shape2)
            (* remaining combinations are handled by the cases above once swapped *)
              | (shape1, shape2) => matchShape(shape2, shape1)
            (* end case *)
          end

  (* match two types, binding type variables (and, via matchDiff, matchDim, and
   * matchShape, the other kinds of meta variable) as necessary; returns true on
   * success.  Bindings made before a failure is detected are not undone.  An occurs
   * check rejects bindings that would create a cyclic type, on which prune would not
   * terminate.
   *)
    fun matchType (ty1, ty2) = let
          fun setBind (Ty.TV{bind=ref(SOME _), ...}, _) = raise Fail "prune fail"
            | setBind (Ty.TV{bind, ...}, ty) = bind := SOME ty
        (* does the type variable with stamp id occur in ty?  This follows the same
         * structure as prune (only T_Fun has type components).
         *)
          fun occursIn (id, Ty.T_Var(Ty.TV{id=id', bind, ...})) =
                Stamp.same(id, id') orelse (case !bind
                   of SOME ty => occursIn (id, ty)
                    | NONE => false
                  (* end case *))
            | occursIn (id, Ty.T_Fun(tys1, ty2)) =
                List.exists (fn ty => occursIn (id, ty)) tys1 orelse occursIn (id, ty2)
            | occursIn _ = false
        (* bind tv to the non-variable type ty, failing if tv occurs in ty *)
          fun bindVar (tv as Ty.TV{id, ...}, ty) =
                if occursIn (id, ty) then false else (setBind (tv, ty); true)
          fun matchVar (tv1 as Ty.TV{id=id1, ...}, tv2 as Ty.TV{id=id2, ...}) =
                if Stamp.same(id1, id2)
                  then ()
                  else setBind (tv1, Ty.T_Var tv2)
          fun match (Ty.T_Var tv1, Ty.T_Var tv2) = (matchVar(tv1, tv2); true)
            | match (Ty.T_Var tv1, ty2) = bindVar(tv1, ty2)
            | match (ty1, Ty.T_Var tv2) = bindVar(tv2, ty1)
            | match (Ty.T_Bool, Ty.T_Bool) = true
            | match (Ty.T_Int, Ty.T_Int) = true
            | match (Ty.T_String, Ty.T_String) = true
            | match (Ty.T_Kernel k1, Ty.T_Kernel k2) = matchDiff (k1, k2)
            | match (Ty.T_Tensor s1, Ty.T_Tensor s2) = matchShape (s1, s2)
            | match (Ty.T_Image{dim=d1, shape=s1}, Ty.T_Image{dim=d2, shape=s2}) =
                matchDim (d1, d2) andalso matchShape(s1, s2)
            | match (Ty.T_Field{diff=k1, dim=d1, shape=s1}, Ty.T_Field{diff=k2, dim=d2, shape=s2}) =
                matchDiff (k1, k2) andalso matchDim (d1, d2) andalso matchShape(s1, s2)
            | match (Ty.T_Fun(tys11, ty12), Ty.T_Fun(tys21, ty22)) =
              (* matching the argument types may bind variables that occur in the result
               * types, so compare the results with matchType, which prunes first.
               *)
                matchTypes (tys11, tys21) andalso matchType (ty12, ty22)
            | match _ = false
          in
            match (prune ty1, prune ty2)
          end

  (* match two lists of types pairwise; lists of different lengths do not match *)
    and matchTypes (tys1, tys2) = ListPair.allEq matchType (tys1, tys2)

  (* instantiate a type scheme, returning the argument meta variables and the resulting type.
   * Note that we assume that the scheme is closed.
   * NOTE(review): only the case with no scheme variables is implemented; any other
   * input raises Fail "unimplemented".  The copies are accumulated with foldl, so they
   * come out in the reverse order of mvs -- confirm the intended order when finishing
   * this function.
   *)
    fun instantiate ([], ty) = ([], ty)
      | instantiate (mvs, ty) = let
          fun instantiateVar (mv, (mvs, env)) = let
                val mv' = MV.copy mv
                in
                  (mv'::mvs, MV.Map.insert(env, mv, mv'))
                end
          val (mvs, env) = List.foldl instantiateVar ([], MV.Map.empty) mvs
        (* NOTE(review): incomplete -- only handles type variables so far *)
          fun ity (Ty.T_Var tv) = (case MV.Map.find(env, Ty.TYPE tv)
                 of SOME(Ty.TYPE tv) => Ty.T_Var tv
                  | _ => raise Fail "impossible"
                (* end case *))
          in
            raise Fail "unimplemented"
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0