Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/ein/check-ein.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/ein/check-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2495, Wed Oct 23 21:28:25 2013 UTC revision 2496, Tue Oct 29 05:29:46 2013 UTC
# Line 1  Line 1 
1  structure TypeCheck = struct  structure TypeEin = struct
2    
3  local  local
4    
5  structure E = Ein  structure E = Ein
6  structure P = Printer  structure P = Printer
 in  
   
 val errTy=0  
 val realTy=1  
 val intTy=2  
 val tensorTy=3  
 val fieldTy=4  
 val kernelTy=5  
 val partialTy=6  
 val epsTy=7  
 val deltaTy=8  
 val imageTy=10  
7    
8  datatype const_type=realTy|intTy  in
9    
10    
11    
12  datatype greek_type=epsTy|deltaTy  datatype greek_type=epsTy|deltaTy
13    
14  datatype ein_type  datatype ein_type   = realTy
     = C of const_type  
     | fld of int  
     | ten of int  
     | kernelTy  
15      | G of greek_type      | G of greek_type
16      | partialTy of int      | ten of E.mu list
17      | imageTy      | fld of E.mu list
18        | fldmid of E.mu list
19        | imageTy of E.mu list
20        | kernelTy of E.mu list
21        | partialTy of E.mu list
22      | errTy      | errTy
23    
   
   
 fun printType(Eintype)=(case Eintype  
     of C realTy=> "real type"  
     | C intTy=> "int type"  
     |  (ten _) =>"tensor type"  
     |  (fld _)=>"field type"  
     | imageTy =>"Image type"  
     |  kernelTy=>"kernel type"  
     |  (partialTy _) =>"partial type"  
     | G epsTy=> "epsilon type"  
     | G deltaTy =>"delta type"  
     | errTy=> "error"  
 (*end case*))  
   
24  (*can only be C type*)  (*can only be C type*)
25      fun position([C c])=C c  fun position([realTy])=realTy
26      | position((C c1)::(C c2)::es)= position([C c1]@es)      | position((realTy)::es)= position(es)
27      | position _= errTy      | position _= errTy
28    
29    
30  fun err(msg)= (print msg;errTy)  fun err(msg)= (print msg;errTy)
31    
 fun checker (Ein.EIN{params, index, body},phase) = let  
32    
33      fun checkParam(id, a,b)= let  fun sortIndex(ilist)=let
34              val p=List.nth(params,id)      fun sort([],ix)=ix
35          in (case p      | sort (E.C _::es,ix)=sort(es,ix)
36              of E.TEN => a      | sort(e::es,ix)=let
37              | E.FLD _ =>b)          val r=List.find (fn(x) => x=e) ix
38            in (case r
39                of NONE=>sort(es, ix@[e])
40                |_=> sort(es, ix)
41                (*end case*))
42          end          end
43        in sort(ilist,[]) end
44    
     (*make sure the finished expression is the shape it is suppose to be *)  
     fun checkbody term = (case term  
         of E.Const r => C realTy  
         | E.Tensor(id, []) => checkParam(id,C realTy, errTy)  
         | E.Tensor(id, [E.C _ ])=> checkParam(id,C realTy,errTy)  
         | E.Tensor(id, a) =>checkParam(id,ten(length a),errTy)  
         | E.Delta(ix)=>G deltaTy  
         | E.Value(ix) =>C intTy  
         | E.Epsilon _ => G epsTy  
         | E.Sum (_,e1)=>checkbody e1  
         | E.Neg e1=> checkbody e1  
         | E.Partial a =>   (partialTy(length a))  
         | E.Add([e1])=>checkbody e1  
         | E.Add es => let  
             val ep=List.map checkbody es  
             fun distribute [fld f]= fld f  
                 | distribute [ten t]= ten t  
                 | distribute [C c]= C c  
                 | distribute [imageTy]= imageTy  
                 | distribute(G _::el)= errTy  
                 | distribute(partialTy _::el)=errTy  
                 | distribute(kernelTy::el)=errTy  
                 | distribute(errTy::el)= errTy  
   
                 | distribute(C realTy::C _::el)=distribute(C realTy::el)  
                 | distribute(C _::C realTy::el)=distribute(C realTy::el)  
                 | distribute(C c::C _::el)=distribute(C c::el)  
45    
                 | distribute(ten T1::ten T2::el)=distribute(ten T1::el)  
                    (* if(T1=T2) then distribute(ten T1::el)  
                     else errTy*)  
                 | distribute(ten T::el)=errTy  
                 | distribute(fld F1:: fld F2::el)=distribute(fld F1::el)  
                     (*if (F1=F2) then distribute(fld F1::el)  
                     else errTy*)  
                 | distribute(imageTy::imageTy::el)=distribute(imageTy::el)  
                 | distribute _= errTy  
             in  
                 distribute ep  
             end  
         | E.Sub(e1, e2) =>(case (checkbody e1 ,checkbody e2)  
             of (C c, C _)=> C c  
             | (ten T1, ten T2)=>ten T1  
                 (*if (T1=T2) then ten T1  
                 else errTy*)  
             |(fld F1,  fld F2)=>fld F1  
                (* if (F1=F2) then fld F1  
                 else errTy*)  
             | (imageTy, imageTy)=>  imageTy  
             | (fld f, C _) => fld f  
             | (C _, fld f) => fld f  
             | (imageTy, C _) => imageTy  
             | (C _, imageTy) => imageTy  
             |_=>errTy)  
46    
47          | E.Prod el =>let  fun sortFldIndex(part,ilist)=let
48              val ep=List.map checkbody el      fun sort([],ix)=ix
49              (*Product is very complicated          | sort (E.C _::es,ix)=sort(es,ix)
50              Need to examine size of tensors and fields to see if a scalar is produced          | sort(e::es,ix)=let
51              *)              val r=List.find (fn(x) => x=e) ix
52                in (case r
53              (*Currently allow tensor * Field. *)                  of NONE=>sort(es, ix@[e])
54              fun distribute ([C c])= C c                  |_=> sort(es, ix)
55                  | distribute([fld f])= fld f                  (*end case*))
56                  | distribute([ten t])=ten t              end
57                  | distribute([partialTy p])= partialTy p          in sort(part,ilist) end
                 | distribute([imageTy ])=imageTy  
   
                 | distribute([e1])= errTy  
                 | distribute(errTy::el)= errTy  
   
                 (*Constants and greeks*)  
                 | distribute(C _::e1::el)=distribute(e1::el)  
                 | distribute(G _::e1::el)=distribute(e1::el)  
   
                 (*R Partials*)  
   
                 | distribute( partialTy p::C _::el)= distribute(partialTy p::el)  
                 | distribute( partialTy p::G _::el)= distribute(partialTy p::el)  
                 | distribute( partialTy p::ten T::el)= distribute(ten T::el)  
                 | distribute( partialTy p::fld F::el)= err "Do not take derivative of fieldty"  
                 | distribute( partialTy p1::partialTy p2::el)= distribute(partialTy(p1+p2)::el)  
                 | distribute( partialTy p::kernelty::el)= err "Do not take derivative of kernelty"  
                 (*partial of imagety ?*)  
   
   
                 (* B Tensors *)  
                 | distribute(ten T ::C _::el)= distribute(ten T::el)  
                 | distribute(ten T::G _::el)= distribute(ten T::el)  
                 | distribute(ten T::ten T2::el)=distribute(ten T::el)  
   
                 |distribute(ten T::fld F::el)= distribute(fld F::el)  
                (* | distribute(ten 0::fld F::el)= err "can not multiply tensor and field"  
                 | distribute(ten T::fld F::el)= err "can not multiply tensor and field"*)  
58    
                 | distribute(ten T::partialTy p ::el)=distribute(ten T::el)  
                 | distribute(ten T::kernelty::el)= err"can not multiply tensor and kernel"  
59    
60    fun evalAdd [fld f]= fld f
61        | evalAdd [ten t]= ten t
62        | evalAdd [realTy]= realTy
63        | evalAdd [imageTy i]= imageTy i
64        | evalAdd(realTyy::realTy::el)=evalAdd(realTy::el)
65        | evalAdd(fld F1:: fld F2::el)=evalAdd(fld F1::el)
66            (*if (F1=F2) then evalAdd(fld F1::el)
67            else errTy*)
68        | evalAdd(ten T1::ten T2::el)=evalAdd(ten T1::el)
69            (* if(T1=T2) then evalAdd(ten T1::el)
70            else errTy*)
71        | evalAdd(imageTy i::imageTy _::el)=evalAdd(imageTy i::el)
72        | evalAdd _= errTy
73    
74    fun evalProd([])=errTy
75        | evalProd(errTy::el)= errTy
76        | evalProd([e1])=e1
77        | evalProd(realTy::el)=evalProd el
78        | evalProd(G g::el)= evalProd el
79                  (*Fields..*)                  (*Fields..*)
80                  | distribute(fld F::C _::el)=distribute(fld F::el)      | evalProd [fld f,realTy] =fld f
81                  | distribute(fld F::G _::el)=distribute(fld F::el)      | evalProd [fld f,G _] =fld f
82        | evalProd([fld f,_])= err "can not multiply field and other type "
83        (* Tensors *)
84        | evalProd [ten t ,realTy]= ten t
85        | evalProd [ten t,G _]= ten t
86        | evalProd [ten t,ten T2]=ten(sortIndex(t@T2))
87        | evalProd [ten t, partialTy p]= ten(sortIndex(t@p))
88        (*kernels*)
89        | evalProd [kernelTy k,realTy]= kernelTy k
90        | evalProd [kernelTy k,kernelTy _]= kernelTy k
91        | evalProd [kernelTy k,imageTy i]= fldmid(i@k)
92        (*Partials*)
93        | evalProd [partialTy p,realTy]= partialTy p
94        | evalProd [partialTy p,G _]= partialTy p
95        | evalProd [partialTy p,ten T]= ten(sortIndex(T@p))
96        | evalProd [partialTy p,partialTy p2]= partialTy(p@p2)
97        (*Image *)
98        | evalProd [imageTy i,realTy]= imageTy i
99        | evalProd [imageTy i ,G _]= imageTy i
100        | evalProd [imageTy i ,imageTy i2]=  imageTy(i@i2)
101        | evalProd [imageTy i, kernelTy k]= fldmid(i@k)
102        | evalProd [a,b]= errTy
103        | evalProd (e1::es)= evalProd [e1,evalProd(es)]
104    
105                  | distribute(fld F::ten t::el)=fld F  fun evalSub(a,b)=(case (a,b)
106              (*    | distribute(fld F::ten t ::el)= err "can not multiply field and tendot type"*)      of (realTy, realTy)=>realTy
107        | (ten T1, ten T2)=>ten T1
108                  | distribute(fld F:: _::el)= err "can not multiply field and other type "      |(fld F1,  fld F2)=>fld F1
109        | (imageTy i, imageTy _)=>  imageTy i
110                  | distribute(imageTy::el)= imageTy      | (fldmid f, fldmid _)=> fldmid f
111                  | distribute _ = errTy      | (fldmid f, realTy)=> fldmid f
112              in distribute ep      | (realTy, fldmid  f)=> fldmid f
113              end      | (fld f, realTy) => fld f
114          | E.Div(e1,e2)=>(case (checkbody e1,checkbody e2)      | (realTy, fld f) => fld f
115              of(C _ ,C a)=>C a      | (imageTy i, realTy) => imageTy i
116              | (fld f, C_)=> fld f      | (realTy, imageTy i) => imageTy i
             | (ten t, C _)=>ten t  
117              | _=>errTy)              | _=>errTy)
         (*Phase dependent operators*)  
118    
119          (* Phase 1, After normalize Before Probe-ein *)  fun evalDiv(a,b)=(case (a,b)
120          | E.Apply(e1, e2)=>      of(realTy ,realTy)=>realTy
121              if (phase =1) then      | (fld f, realTy)=> fld f
122                  (case (checkbody e1,checkbody e2)      | (ten t, realTy)=>ten t
                 of (partialTy a, fld b)=>fld( a+ b)  
123                  |_ =>errTy)                  |_ =>errTy)
124              else err "wrong phase for apply"  
125          (*Phase 1 - 2*)  fun evalProbe(a,b,phase)=if (phase>1) then err "wrong phase for Probe op"
126          | E.Probe (e1,e2)=>      else (case (a,b)
             if (3>phase ) then  
                 (case (checkbody e1, checkbody e2)  
127                  of (fld f,ten _)=>fld f                  of (fld f,ten _)=>fld f
128                  | (fld f, C c)=>fld f          | (fld f, realTy)=>fld f
129                  | (fld f, _)=> err "wrong pos for field probe"                  | (fld f, _)=> err "wrong pos for field probe"
130                  |  _=>err "Not a fieldTy in probe"                  |  _=>err "Not a fieldTy in probe"
131                  (*end case*))                  (*end case*))
             else err "wrong phase for Probe op"  
132    
133    fun evalKrn(dels,phase)=  if (3>phase) then err "wrong phrase for kernel"
134        else let
135            fun size([])=[]
136            | size((i ,j)::dels)= [j]@ size(dels)
137            in  kernelTy(size(dels)) end
138    
139    fun evalApply(e1,e2,phase)=
140        if (phase>1) then err "wrong phase for apply"
141        else (case (e1,e2)
142            of (partialTy a, fld b)=>fld(sortFldIndex(a,b))
143            |_ =>errTy
144            (*end case*))
145    
146    fun checkTenParam(id,params, ix)= let
147        val p=List.nth(params,id)
148        in (case p
149        of E.TEN => let
150            val m = (sortIndex(ix))
151            in(case m
152                of []=> realTy
153                | _=> ten m)
154            end
155        |  _ =>errTy)
156        end
157    
158    fun checkFldParam(id, params,ix)= let
159        val p=List.nth(params,id)
160        in (case p
161            of  E.FLD _ =>fld ix
162            |_=> errTy)
163        end
164    
165    
166    fun checker (Ein.EIN{params, index, body},phase) = let
167    
168    
169    
170    
171        (*make sure the finished expression is the shape it is suppose to be *)
172        fun checkbody term = (case term
173            of E.Const r => realTy
174            | E.Tensor(id, ix) =>checkTenParam(id,params,ix)
175            | E.Delta(ix)=>G deltaTy
176            | E.Value(ix) =>realTy
177            | E.Epsilon _ => G epsTy
178            | E.Sum (_,e1)=>checkbody e1
179            | E.Neg e1=> checkbody e1
180            | E.Partial a =>   partialTy(sortIndex(a))
181            | E.Add es => evalAdd(List.map checkbody es)
182            | E.Sub(e1, e2) =>evalSub(checkbody e1 ,checkbody e2)
183            | E.Prod el => evalProd(List.map checkbody el)
184            | E.Div(e1,e2)=> evalDiv (checkbody e1,checkbody e2)
185            (*Phase dependent operators*)
186          | E.Field(id, alpha)=>          | E.Field(id, alpha)=>
187              if (3>phase) then checkParam(id,errTy, fld(length alpha))              if (phase>1) then err "wrong phase for Field"
188               else err "wrong phase for Field"              else checkFldParam(id,params,alpha)
189          (*Phase 2, After Probe-ein*)          | E.Apply(e1, e2)=> evalApply(checkbody e1,checkbody e2,phase)
190          | E.Conv(e1, a)=>          | E.Probe (e1,e2)=>evalProbe(checkbody e1, checkbody e2,phase)
191              if (phase=2) then          | E.Conv (fid,alpha, tid, beta)=>
192                  (case checkbody e1              if (phase>1) then err "wrong phase for convolution"
193                  of fld f=>fld(f+(length a))              else (case checkFldParam(fid, params,alpha)
194                  | _ =>err "convolution does not have a fieldtype")                  of fld f=> fld(sortFldIndex(beta,f))
195              else err "wrong phase for convolution"                  |_=> errTy
         (*Phase 3 Mid-IL*)  
         | E.Krn _ =>  
             if (3>phase) then err "wrong phrase for kernel"  
             else kernelTy  
         | E.Img((id,ix,pos),hs)=>  
             if(3 >phase) then errTy  
             else (let  
                 fun  findK([])=kernelTy  
                 | findK(kernelTy::es)=findK(es)  
                 | findK _=errTy  
   
                 val p'= position(List.map checkbody pos)  
                 val k'=findK (List.map checkbody hs)  
   
                 in  (case  (p',k')  
                 of (errTy,_)=> err "Not an image position"  
                 | (_,errTy)=>err "Not all kernel types"  
                 | _ => imageTy  
196                  (*end case*))                  (*end case*))
             end)  
197    
198            (*Phase 2 Mid-IL*)
199            | E.Krn (_,dels,_) =>evalKrn(dels,phase)
200            | E.Img(id,ix,pos)=>
201                if(3 >phase) then errTy
202                else (case position(List.map checkbody pos)
203                    of errTy=> err "Not an image position"
204                    | _ =>(case checkFldParam(id, params,[])
205                        of errTy=> errTy
206                        |_=>imageTy(sortIndex(ix))
207                        (*end case*))
208                    (*end case*))
209          (* end case *))          (* end case *))
210      in      in
211          checkbody body          checkbody body

Legend:
Removed from v.2495  
changed lines
  Added in v.2496

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0