Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/ein/type-ein.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/ein/type-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2520, Mon Dec 30 05:08:47 2013 UTC revision 2521, Thu Jan 9 02:17:07 2014 UTC
# Line 12  Line 12 
12  datatype greek_type=epsTy|deltaTy  datatype greek_type=epsTy|deltaTy
13    
14  datatype ein_type   = realTy  datatype ein_type   = realTy
     | G of E.mu list  
15      | ten of E.mu list      | ten of E.mu list
16      | fld of E.mu list      | fld of E.mu list
     | fldmid of E.mu list  
     | imageTy of E.mu list  
     | kernelTy of E.mu list  
     | partialTy of E.mu list  
     | errTy  
17    
 (*can only be C type*)  
 fun position([realTy])=realTy  
     | position((realTy)::es)= position(es)  
     | position _= errTy  
   
   
 fun err(msg)= (errTy)  
18    
19    
20  fun printIndex([])=""  fun printIndex([])=""
21      | printIndex(E.C x::ix)=  String.concat[Int.toString(x),printIndex(ix)]      | printIndex(E.C x::ix)=  String.concat[Int.toString(x),printIndex(ix)]
22      | printIndex(E.V v::ix)= String.concat[Int.toString(v),printIndex(ix)]      | printIndex(E.V v::ix)= String.concat[Int.toString(v),printIndex(ix)]
23    
24  fun printTy(ty)= (case ty of realTy=> print "realTy"  fun printTy ty= (case ty
25      | G _=> print "Greek Type"      of realTy=>  "realTy"
26      | ten ix=> print (String.concat["TEN-" ,printIndex(ix)])      | ten ix=> String.concat["TEN-" ,printIndex(ix)]
27      | fld  ix=> print (String.concat["FLD-HIGH-" ,printIndex(ix)])      | fld  []=>  "Scalar FLD"
28      | fldmid ix=> print (String.concat["FLD-MID-" ,printIndex(ix)])      | fld  ix=>String.concat["FLD" ,printIndex(ix)]
     | imageTy  ix=> print (String.concat["Image-" ,printIndex(ix)])  
     | kernelTy  ix=> print (String.concat["KRN-" ,printIndex(ix)])  
     | partialTy  ix=> print (String.concat["Partial-" ,printIndex(ix)])  
     | errTy=> print "err")  
   
   
   
 (*single index no duplicates*)  
 fun sortIndex(ilist)=let  
     fun sort([],ix)=ix  
     | sort (E.C _::es,ix)=sort(es,ix)  
     | sort(e::es,ix)=let  
         val r=List.find (fn(x) => x=e) ix  
         in (case r  
             of NONE=>sort(es, ix@[e])  
             |_=> sort(es, ix)  
             (*end case*))  
             end  
     in sort(ilist,[]) end  
   
   
 (*compares one list of indices to another*)  
 fun sortFldIndex(part,ilist)=let  
     fun sort([],ix)=ix  
         | sort (E.C _::es,ix)=sort(es,ix)  
         | sort(e::es,ix)=let  
             val r=List.find (fn(x) => x=e) ix  
             in (case r  
                 of NONE=>sort(es, ix@[e])  
                 |_=> sort(es, ix)  
                 (*end case*))  
             end  
         in sort(part,ilist) end  
   
   
 fun removeSumIndex(ilist,sum)=let  
     fun sort([],ix,rest)=rest  
     | sort (E.C _::es,ix,rest)=sort(es,ix,rest)  
     | sort(e::es,ix,rest)=let  
         val r=List.find (fn((x,_,_)) => x=e) ix  
         in (case r  
         of NONE=>sort(es, ix,rest@[e])  
         |_=> sort(es, ix,rest)  
29          (*end case*))          (*end case*))
         end  
     in sort(ilist,sum,[]) end  
30    
31    
32    
 fun evalAdd [fld f]= fld f  
     | evalAdd [ten t]= ten t  
     | evalAdd [realTy]= realTy  
     | evalAdd [imageTy i]= imageTy i  
     | evalAdd(realTyy::realTy::el)=evalAdd(realTy::el)  
     | evalAdd(fld F1:: fld F2::el)=evalAdd(fld F1::el)  
         (*if (F1=F2) then evalAdd(fld F1::el)  
         else errTy*)  
     | evalAdd(ten T1::ten T2::el)=evalAdd(ten T1::el)  
         (* if(T1=T2) then evalAdd(ten T1::el)  
         else errTy*)  
     | evalAdd(imageTy i::imageTy _::el)=evalAdd(imageTy i::el)  
     | evalAdd _= errTy  
   
 fun evalProd([])=errTy  
     | evalProd(errTy::el)= errTy  
     | evalProd([e1])=e1  
     | evalProd(realTy::el)=evalProd el  
     | evalProd(G g::ten t::es)= evalProd (ten(t@g)::es)  
     | evalProd(G g::fld t::es)= evalProd (fld(t@g)::es)  
     (*Fields..*)  
     | evalProd [fld f,realTy] =fld f  
     | evalProd [fld f,G _] =fld f  
     | evalProd [fld t, partialTy p]= ten(sortIndex(t@p))  
     | evalProd([fld f,_])= err "can not multiply field and other type "  
     (* Tensors *)  
     | evalProd [ten t ,realTy]= ten t  
     | evalProd [ten t,G _]= ten t  
     | evalProd [ten t,ten T2]=ten(sortIndex(t@T2))  
     | evalProd [ten t, partialTy p]= ten(sortIndex(t@p))  
     (*kernels*)  
     | evalProd [kernelTy k,realTy]= kernelTy k  
     | evalProd [kernelTy k,kernelTy _]= kernelTy k  
     | evalProd [kernelTy k,imageTy i]= fldmid(i@k)  
     (*Partials*)  
     | evalProd [partialTy p,realTy]= partialTy p  
     | evalProd [partialTy p,G _]= partialTy p  
     | evalProd [partialTy p,ten T]= ten(sortIndex(T@p))  
     | evalProd [partialTy p,fld T]= ten(sortIndex(T@p))  
     | evalProd [partialTy p,partialTy p2]= partialTy(p@p2)  
     (*Image *)  
     | evalProd [imageTy i,realTy]= imageTy i  
     | evalProd [imageTy i ,G _]= imageTy i  
     | evalProd [imageTy i ,imageTy i2]=  imageTy(i@i2)  
     | evalProd [imageTy i, kernelTy k]= fldmid(i@k)  
     | evalProd [a,b]= errTy  
     | evalProd (e1::es)= evalProd [e1,evalProd(es)]  
   
 fun evalSub(a,b)=(case (a,b)  
     of (realTy, realTy)=>realTy  
     | (ten T1, ten T2)=>ten T1  
     |(fld F1,  fld F2)=>fld F1  
     | (imageTy i, imageTy _)=>  imageTy i  
     | (fldmid f, fldmid _)=> fldmid f  
     | (fldmid f, realTy)=> fldmid f  
     | (realTy, fldmid  f)=> fldmid f  
     | (fld f, realTy) => fld f  
     | (realTy, fld f) => fld f  
     | (imageTy i, realTy) => imageTy i  
     | (realTy, imageTy i) => imageTy i  
     |_=>errTy)  
   
 fun evalDiv(a,b)=(case (a,b)  
     of(realTy ,realTy)=>realTy  
     | (fld f, realTy)=> fld f  
     | (ten t, realTy)=>ten t  
     | _=>errTy)  
   
 fun evalProbe(a,b,phase)=if (phase>1) then err "wrong phase for Probe op"  
     else (case (a,b)  
         of (fld f,ten _)=>fld f  
         | (fld f, realTy)=>fld f  
         | (fld f, _)=> err "wrong pos for field probe"  
         |  _=>err "Not a fieldTy in probe"  
         (*end case*))  
   
 fun evalKrn(dels,phase)=  if (3>phase) then err "wrong phrase for kernel"  
     else let  
         fun size([])=[]  
         | size((i ,j)::dels)= [j]@ size(dels)  
         in  kernelTy(size(dels)) end  
   
 fun evalApply(e1,e2,phase)=  
     if (phase>1) then err "wrong phase for apply"  
     else (case (e1,e2)  
         of (partialTy a, fld b)=>fld(a@b)  
         |_ =>errTy  
         (*end case*))  
   
   
 fun evalSum(sx,m)=(case m  
         of ten ix=>(  
             let val ix'=removeSumIndex(ix,sx)  
                 in (case ix' of [] => realTy  
             |_ =>ten ix')end  
             (*end case*))  
             | fld ix=>(  
                 let val ix'=removeSumIndex(ix,sx)  
                 in (case ix' of [] => realTy  
                 |_ =>fld ix')end  
                 (*end case*))  
             | fldmid ix =>m  
             | realTy=>realTy  
             |_=> errTy  
         (*end case*))  
   
   
 fun checkTenParam(id,params, ix)=  
     if(id>length(params))then (print "in here";errTy)  
     else(let  
     val p=List.nth(params,id)  
     in (case p  
     of E.TEN => let  
         val m = (sortIndex(ix))  
         in(case m  
             of []=> realTy  
             | _=> ten m)  
         end  
     |  _ =>errTy)  
     end)  
   
 fun checkFldParam(id, params,ix)=  
         if(id>length(params))then (print "in here";errTy)  
         else(  
             let  
     val p=List.nth(params,id)  
     in (case p  
         of  E.FLD _ =>fld ix  
         |_=> errTy)  
     end)  
   
   
 fun checker (Ein.EIN{params, index, body},phase) = let  
   
   
   
   
     (*make sure the finished expression is the shape it is suppose to be *)  
     fun checkbody term = (case term  
         of E.Const r => realTy  
         | E.Tensor(id, ix) =>checkTenParam(id,params,ix)  
         | E.Delta(i,j)=>G [i,j]  
         | E.Value(ix) =>realTy  
         | E.Epsilon(i,j,k) => G [E.V i,E.V j,E.V k]  
         | E.Sum (sx,e1)=> evalSum(sx,checkbody e1)  
   
   
         | E.Neg e1=> checkbody e1  
         | E.Partial a =>   partialTy(sortIndex(a))  
         | E.Add es => evalAdd(List.map checkbody es)  
         | E.Sub(e1, e2) =>evalSub(checkbody e1 ,checkbody e2)  
         | E.Prod el => evalProd(List.map checkbody el)  
         | E.Div(e1,e2)=> evalDiv (checkbody e1,checkbody e2)  
         (*Phase dependent operators*)  
         | E.Field(id, alpha)=>  
             if (phase>1) then err "wrong phase for Field"  
             else checkFldParam(id,params,alpha)  
         | E.Apply(e1, e2)=> evalApply(checkbody e1,checkbody e2,phase)  
         | E.Probe (e1,e2)=>evalProbe(checkbody e1, checkbody e2,phase)  
         | E.Conv (fid,alpha, tid, beta)=>  
             if (phase>1) then err "wrong phase for convolution"  
             else (case checkFldParam(fid, params,alpha)  
                 of fld f=> fld(sortFldIndex(beta,f))  
                 |_=> errTy  
                 (*end case*))  
   
         (*Phase 2 Mid-IL*)  
         | E.Krn (_,dels,_) =>evalKrn(dels,phase)  
         | E.Img(id,ix,pos)=>  
             if(3 >phase) then errTy  
             else (case position(List.map checkbody pos)  
                 of errTy=> err "Not an image position"  
                 | _ =>(case checkFldParam(id, params,[])  
                     of errTy=> errTy  
                     |_=>imageTy(sortIndex(ix))  
                     (*end case*))  
                 (*end case*))  
         (* end case *))  
     in  
         checkbody body  
     end  
   
33    
34  end; (* local *)  end; (* local *)
35    

Legend:
Removed from v.2520  
changed lines
  Added in v.2521

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0