6 |
|
|
7 |
structure TypeUtil : sig |
structure TypeUtil : sig |
8 |
|
|
9 |
|
(* prune out instantiated meta variables *) |
10 |
|
val prune : Types.ty -> Types.ty |
11 |
|
val pruneDiff : Types.diff -> Types.diff |
12 |
|
val pruneShape : Types.shape -> Types.shape |
13 |
|
val pruneDim : Types.dim -> Types.dim |
14 |
|
|
15 |
|
(* prune the head of a type *) |
16 |
|
val pruneHead : Types.ty -> Types.ty |
17 |
|
|
18 |
|
(* resolve meta variables to their instantiations (or else variable) *) |
19 |
|
val resolve : Types.ty_var -> Types.ty |
20 |
|
val resolveDiff : Types.diff_var -> Types.diff |
21 |
|
val resolveShape : Types.shape_var -> Types.shape |
22 |
|
val resolveDim : Types.dim_var -> Types.dim |
23 |
|
|
24 |
(* string representations of types, etc *) |
(* string representations of types, etc *) |
25 |
val toString : Types.ty -> string |
val toString : Types.ty -> string |
26 |
val diffToString : Types.diff -> string |
val diffToString : Types.diff -> string |
32 |
structure Ty = Types |
structure Ty = Types |
33 |
structure MV = MetaVar |
structure MV = MetaVar |
34 |
|
|
35 |
|
(* prune out instantiated meta variables from a type *) |
36 |
|
fun prune ty = (case ty |
37 |
|
of (ty as Ty.T_Var(Ty.TV{bind, ...})) => (case !bind |
38 |
|
of NONE => ty |
39 |
|
| SOME ty => prune ty |
40 |
|
(* end case *)) |
41 |
|
| (Ty.T_Kernel diff) => Ty.T_Kernel(pruneDiff diff) |
42 |
|
| (Ty.T_Tensor shape) => Ty.T_Tensor(pruneShape shape) |
43 |
|
| (Ty.T_Image{dim, shape}) => Ty.T_Image{ |
44 |
|
dim = pruneDim dim, |
45 |
|
shape = pruneShape shape |
46 |
|
} |
47 |
|
| (Ty.T_Field{diff, dim, shape}) => Ty.T_Field{ |
48 |
|
diff = pruneDiff diff, |
49 |
|
dim = pruneDim dim, |
50 |
|
shape = pruneShape shape |
51 |
|
} |
52 |
|
| (Ty.T_Fun(tys1, ty2)) => Ty.T_Fun(List.map prune tys1, prune ty2) |
53 |
|
| ty => ty |
54 |
|
(* end case *)) |
55 |
|
|
56 |
|
and pruneDiff (Ty.DiffVar(Ty.DfV{bind=ref(SOME diff), ...}, i)) = ( |
57 |
|
case pruneDiff diff |
58 |
|
of Ty.DiffVar(dv, i') => Ty.DiffVar(dv, i+i') |
59 |
|
| Ty.DiffConst i' => Ty.DiffConst(i+i') |
60 |
|
(* end case *)) |
61 |
|
| pruneDiff diff = diff |
62 |
|
|
63 |
|
and pruneDim dim = (case dim |
64 |
|
of Ty.DimVar(Ty.DV{bind=ref(SOME dim), ...}) => pruneDim dim |
65 |
|
| dim => dim |
66 |
|
(* end case *)) |
67 |
|
|
68 |
|
and pruneShape shape = (case shape |
69 |
|
of Ty.ShapeVar(Ty.SV{bind=ref(SOME shape), ...}) => pruneShape shape |
70 |
|
| Ty.ShapeExt(shape, dim) => Ty.shapeExt(pruneShape shape, pruneDim dim) |
71 |
|
| _ => shape |
72 |
|
(* end case *)) |
73 |
|
|
74 |
|
(* resolve meta variables to their instantiations (or else variable) *) |
75 |
|
fun resolve (tv as Ty.TV{bind, ...}) = (case !bind |
76 |
|
of NONE => Ty.T_Var tv |
77 |
|
| SOME ty => prune ty |
78 |
|
(* end case *)) |
79 |
|
|
80 |
|
fun resolveDiff (dv as Ty.DfV{bind, ...}) = (case !bind |
81 |
|
of NONE => Ty.DiffVar(dv, 0) |
82 |
|
| SOME diff => pruneDiff diff |
83 |
|
(* end case *)) |
84 |
|
|
85 |
|
fun resolveShape (sv as Ty.SV{bind, ...}) = (case !bind |
86 |
|
of NONE => Ty.ShapeVar sv |
87 |
|
| SOME shape => pruneShape shape |
88 |
|
(* end case *)) |
89 |
|
|
90 |
|
fun resolveDim (dv as Ty.DV{bind, ...}) = (case !bind |
91 |
|
of NONE => Ty.DimVar dv |
92 |
|
| SOME dim => pruneDim dim |
93 |
|
(* end case *)) |
94 |
|
|
95 |
|
(* prune the head of a type *) |
96 |
|
fun pruneHead ty = let |
97 |
|
fun prune' (ty as Ty.T_Var(Ty.TV{bind, ...})) = (case !bind |
98 |
|
of NONE => ty |
99 |
|
| SOME ty => prune' ty |
100 |
|
(* end case *)) |
101 |
|
| prune' (Ty.T_Kernel diff) = Ty.T_Kernel(pruneDiff diff) |
102 |
|
| prune' (Ty.T_Tensor shape) = Ty.T_Tensor(pruneShape shape) |
103 |
|
| prune' (Ty.T_Image{dim, shape}) = Ty.T_Image{ |
104 |
|
dim = pruneDim dim, |
105 |
|
shape = pruneShape shape |
106 |
|
} |
107 |
|
| prune' (Ty.T_Field{diff, dim, shape}) = Ty.T_Field{ |
108 |
|
diff = pruneDiff diff, |
109 |
|
dim = pruneDim dim, |
110 |
|
shape = pruneShape shape |
111 |
|
} |
112 |
|
| prune' ty = ty |
113 |
|
in |
114 |
|
prune' ty |
115 |
|
end |
116 |
|
|
117 |
fun listToString fmt sep items = String.concatWith sep (List.map fmt items) |
fun listToString fmt sep items = String.concatWith sep (List.map fmt items) |
118 |
|
|
119 |
fun diffToString (Ty.DiffConst n) = Int.toString n |
fun diffToString diff = (case pruneDiff diff |
120 |
| diffToString (Ty.DiffVar(dv, 0)) = MV.diffVarToString dv |
of Ty.DiffConst n => Int.toString n |
121 |
| diffToString (Ty.DiffVar(dv, i)) = if i < 0 |
| Ty.DiffVar(dv, 0) => MV.diffVarToString dv |
122 |
|
| Ty.DiffVar(dv, i) => if i < 0 |
123 |
then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"] |
then String.concat["(", MV.diffVarToString dv, "-", Int.toString(~i), ")"] |
124 |
else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"] |
else String.concat["(", MV.diffVarToString dv, "+", Int.toString i, ")"] |
125 |
|
(* end case *)) |
126 |
|
|
127 |
fun shapeToString (Ty.Shape shape) = |
fun shapeToString shape = (case pruneShape shape |
128 |
concat["[", listToString dimToString "," shape, "]"] |
of Ty.Shape shape => concat["[", listToString dimToString "," shape, "]"] |
129 |
| shapeToString (Ty.ShapeVar sv) = MV.shapeVarToString sv |
| Ty.ShapeVar sv => MV.shapeVarToString sv |
130 |
| shapeToString (Ty.ShapeExt(shape, d)) = let |
| Ty.ShapeExt(shape, d) => let |
131 |
fun toS (Ty.Shape shape) = (listToString dimToString "," shape) ^ "," |
fun toS (Ty.Shape shape) = (listToString dimToString "," shape) ^ "," |
132 |
| toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";" |
| toS (Ty.ShapeVar sv) = MV.shapeVarToString sv ^ ";" |
133 |
| toS (Ty.ShapeExt(shape, d)) = concat[toS shape, dimToString d, ","] |
| toS (Ty.ShapeExt(shape, d)) = concat[toS shape, dimToString d, ","] |
134 |
in |
in |
135 |
toS shape ^ dimToString d |
toS shape ^ dimToString d |
136 |
end |
end |
137 |
|
(* end case *)) |
138 |
|
|
139 |
and dimToString (Ty.DimConst n) = Int.toString n |
and dimToString dim = (case pruneDim dim |
140 |
| dimToString (Ty.DimVar v) = MV.dimVarToString v |
of Ty.DimConst n => Int.toString n |
141 |
|
| Ty.DimVar v => MV.dimVarToString v |
142 |
|
(* end case *)) |
143 |
|
|
144 |
fun toString ty = (case ty |
fun toString ty = (case pruneHead ty |
145 |
of Ty.T_Var tv => MV.tyVarToString tv |
of Ty.T_Var tv => MV.tyVarToString tv |
146 |
| Ty.T_Bool => "bool" |
| Ty.T_Bool => "bool" |
147 |
| Ty.T_Int => "int" |
| Ty.T_Int => "int" |