Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/type-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/type-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2498 - (download) (annotate)
Wed Oct 30 17:35:23 2013 UTC (7 years, 11 months ago) by cchiw
File size: 8302 byte(s)
add type checking
structure TypeEin = struct

local

structure E = Ein
structure P = Printer

in



(* Tags for the two Greek-letter symbols (presumably epsilon and delta; confirm).
 * No function visible in this structure refers to this datatype. *)
datatype greek_type=epsTy|deltaTy

(* Result of type checking an EIN term.  Every constructor except realTy and
 * errTy carries a list of indices (E.mu).
 *   realTy     scalar; produced for E.Const and E.Value terms
 *   G          produced for E.Delta and E.Epsilon terms, carrying their indices
 *   ten        tensor (printed with the tag TEN-)
 *   fld        field (printed FLD-HIGH-); produced by checkFldParam
 *   fldmid     field (printed FLD-MID-); produced by image/kernel products
 *   imageTy    image (printed Image-)
 *   kernelTy   kernel (printed KRN-)
 *   partialTy  partial derivative (printed Partial-)
 *   errTy      the term is ill-typed
 *)
datatype ein_type   = realTy
    | G of E.mu list
    | ten of E.mu list
    | fld of E.mu list
    | fldmid of E.mu list
    | imageTy of E.mu list
    | kernelTy of E.mu list
    | partialTy of E.mu list
    | errTy

(* position tys: type of a list of position components.
 * Returns realTy when tys is non-empty and every element is realTy;
 * returns errTy for the empty list or as soon as any other type appears. *)
fun position tys = (case tys
    of [] => errTy
    | _ => if List.all (fn realTy => true | _ => false) tys
        then realTy
        else errTy
    (*end case*))


(* err msg: the result used for every reported type error.  The message is
 * accepted for readability at the call site but is discarded; the result is
 * always errTy. *)
fun err _ = errTy


(* printIndex ix: concatenates the integer carried by each index of ix (both
 * E.C and E.V), in order and without separators.  The empty list gives "". *)
fun printIndex ix = let
    fun digits (E.C c) = Int.toString c
      | digits (E.V v) = Int.toString v
    in
        String.concat (List.map digits ix)
    end

(* printTy ty: writes a short tag for ty to standard output.  Types that carry
 * indices are printed as the tag followed by printIndex of those indices;
 * the indices of a G type are not shown. *)
fun printTy ty = let
    fun tagged (label, ix) = String.concat [label, printIndex ix]
    val text = (case ty
        of realTy => "realTy"
        | G _ => "Greek Type"
        | ten ix => tagged ("TEN-", ix)
        | fld ix => tagged ("FLD-HIGH-", ix)
        | fldmid ix => tagged ("FLD-MID-", ix)
        | imageTy ix => tagged ("Image-", ix)
        | kernelTy ix => tagged ("KRN-", ix)
        | partialTy ix => tagged ("Partial-", ix)
        | errTy => "err"
        (*end case*))
    in
        print text
    end



(* sortIndex ilist: the distinct non-constant indices of ilist, in order of
 * first occurrence.  Constant indices (E.C) are dropped and every other
 * index is kept once. *)
fun sortIndex ilist = let
    fun add (E.C _, acc) = acc
      | add (e, acc) =
            if List.exists (fn x => x = e) acc
                then acc
                else acc @ [e]
    in
        List.foldl add [] ilist
    end


(* sortFldIndex (part, ilist): extends ilist with the indices of part.
 * Each non-constant index of part that is not already present is appended at
 * the end, in order; constant indices (E.C) of part are skipped.  ilist itself
 * is returned unchanged as the prefix of the result. *)
fun sortFldIndex (part, ilist) = let
    fun add (E.C _, acc) = acc
      | add (e, acc) =
            if List.exists (fn x => x = e) acc
                then acc
                else acc @ [e]
    in
        List.foldl add ilist part
    end


(* removeSumIndex (ilist, sum): the indices of ilist that survive a summation
 * over sum.  Constant indices (E.C) and every index that occurs in sum are
 * removed; the remaining indices keep their order and are not de-duplicated. *)
fun removeSumIndex (ilist, sum) = let
    fun summedOver e = List.exists (fn s => s = e) sum
    fun keep (E.C _) = false
      | keep e = not (summedOver e)
    in
        List.filter keep ilist
    end



(* evalAdd tys: type of a sum whose operands have the types tys.
 * All operands must be of the same kind (all realTy, all fld, all ten or all
 * imageTy); the result is the type of the first operand.  The empty list, a
 * mix of kinds, or any other kind of operand gives errTy.
 * NOTE: the index lists of fld/ten operands are not compared with each other;
 * only the kind is checked.
 * NOTE(review): evalSub accepts a field or image combined with a realTy, but
 * evalAdd does not; confirm whether addition should allow that as well. *)
fun evalAdd [fld f]= fld f
    | evalAdd [ten t]= ten t
    | evalAdd [realTy]= realTy
    | evalAdd [imageTy i]= imageTy i
    (* The head used to be spelt realTyy, which is a variable pattern and so
     * matched ANY type: [ten t, realTy] was reported as realTy instead of
     * errTy.  It must be the realTy constructor. *)
    | evalAdd(realTy::realTy::el)=evalAdd(realTy::el)
    | evalAdd(fld F1:: fld _::el)=evalAdd(fld F1::el)
    | evalAdd(ten T1::ten _::el)=evalAdd(ten T1::el)
    | evalAdd(imageTy i::imageTy _::el)=evalAdd(imageTy i::el)
    | evalAdd _= errTy

(* evalProd tys: type of a product whose factors have the types tys.
 * The clauses overlap and are tried from top to bottom, so their order is
 * part of the behaviour:
 *   - the empty product and a product whose first factor is errTy are errTy;
 *   - a single factor has its own type;
 *   - a leading realTy factor is dropped;
 *   - a leading G followed by a ten or fld is merged into that factor by
 *     appending the G indices to its index list;
 *   - two factors are looked up in the table below, errTy if not listed;
 *   - three or more factors: the tail is evaluated first and its type is then
 *     combined with the head as a two-factor product.
 *)
fun evalProd([])=errTy
    | evalProd(errTy::el)= errTy
    | evalProd([e1])=e1
    | evalProd(realTy::el)=evalProd el
    | evalProd(G g::ten t::es)= evalProd (ten(t@g)::es)
    | evalProd(G g::fld t::es)= evalProd (fld(t@g)::es)
    (* Fields.  Any second factor not listed here is an error. *)
    (* NOTE(review): field * partial gives a ten, not a fld; confirm this is intended. *)
    | evalProd [fld f,realTy] =fld f
    | evalProd [fld f,G _] =fld f
    | evalProd [fld t, partialTy p]= ten(sortIndex(t@p))
    | evalProd([fld f,_])= err "can not multiply field and other type "
    (* Tensors: the result indices are the distinct indices of both factors *)
    | evalProd [ten t ,realTy]= ten t
    | evalProd [ten t,G _]= ten t
    | evalProd [ten t,ten T2]=ten(sortIndex(t@T2))
    | evalProd [ten t, partialTy p]= ten(sortIndex(t@p))
    (* Kernels: kernel * image gives a fldmid with the image indices first *)
    | evalProd [kernelTy k,realTy]= kernelTy k
    | evalProd [kernelTy k,kernelTy _]= kernelTy k
    | evalProd [kernelTy k,imageTy i]= fldmid(i@k)
    (* Partials *)
    | evalProd [partialTy p,realTy]= partialTy p
    | evalProd [partialTy p,G _]= partialTy p
    | evalProd [partialTy p,ten T]= ten(sortIndex(T@p))
    | evalProd [partialTy p,fld T]= ten(sortIndex(T@p))
    | evalProd [partialTy p,partialTy p2]= partialTy(p@p2)
    (* Images *)
    | evalProd [imageTy i,realTy]= imageTy i
    | evalProd [imageTy i ,G _]= imageTy i
    | evalProd [imageTy i ,imageTy i2]=  imageTy(i@i2)
    | evalProd [imageTy i, kernelTy k]= fldmid(i@k)
    (* every other pair of factors is ill-typed *)
    | evalProd [a,b]= errTy
    (* only reached with three or more factors: fold from the right *)
    | evalProd (e1::es)= evalProd [e1,evalProd(es)]

(* evalSub (a, b): type of the difference a - b.
 * Two operands of the same kind (realTy, ten, fld, imageTy, fldmid) give the
 * type of the left operand; the index lists are not compared.  A fld, fldmid
 * or imageTy combined with a realTy on either side keeps the non-scalar type.
 * Every other combination is errTy. *)
fun evalSub (realTy, realTy) = realTy
  (* same kind on both sides: keep the left operand *)
  | evalSub (ten t, ten _) = ten t
  | evalSub (fld f, fld _) = fld f
  | evalSub (fldmid f, fldmid _) = fldmid f
  | evalSub (imageTy i, imageTy _) = imageTy i
  (* scalar on the right *)
  | evalSub (fld f, realTy) = fld f
  | evalSub (fldmid f, realTy) = fldmid f
  | evalSub (imageTy i, realTy) = imageTy i
  (* scalar on the left *)
  | evalSub (realTy, fld f) = fld f
  | evalSub (realTy, fldmid f) = fldmid f
  | evalSub (realTy, imageTy i) = imageTy i
  | evalSub _ = errTy

(* evalDiv (a, b): type of the quotient a / b.
 * The divisor must be realTy and the dividend must be realTy, fld or ten; the
 * result is then the type of the dividend.  Anything else is errTy. *)
fun evalDiv (num, realTy) = (case num
        of realTy => num
        | fld _ => num
        | ten _ => num
        | _ => errTy
        (*end case*))
  | evalDiv _ = errTy

(* evalProbe (a, b, phase): type of probing a at position b.
 * Only allowed while phase <= 1.  a must be a fld and b a ten or a realTy; the
 * result is the field type a.  Every other case is an error (errTy). *)
fun evalProbe (a, b, phase) =
    if phase <= 1
        then (case (a, b)
            of (fld f, ten _) => fld f
            | (fld f, realTy) => fld f
            | (fld _, _) => err "wrong pos for field probe"
            | _ => err "Not a fieldTy in probe"
            (*end case*))
        else err "wrong phase for Probe op"

(* evalKrn (dels, phase): type of a kernel term.
 * Only allowed when phase >= 3; otherwise the result is errTy.  The kernel
 * type carries the second component of every pair in dels, in order. *)
fun evalKrn (dels, phase) =
    if phase >= 3
        then kernelTy (List.map (fn (_, j) => j) dels)
        else err "wrong phrase for kernel"

(* evalApply (e1, e2, phase): type of applying e1 to e2.
 * Only allowed while phase <= 1, and only for a partialTy applied to a fld;
 * the result is a fld whose indices are those of the partial followed by
 * those of the field.  Everything else is errTy. *)
fun evalApply (e1, e2, phase) =
    if phase <= 1
        then (case (e1, e2)
            of (partialTy a, fld b) => fld (a @ b)
            | _ => errTy
            (*end case*))
        else err "wrong phase for apply"


(* evalSum (sx, m): type of a summation over the indices sx of a body of
 * type m.
 *   ten / fld   the summed indices are removed (see removeSumIndex); if no
 *               index is left the result is realTy
 *   fldmid      returned unchanged
 *   realTy      realTy
 *   otherwise   errTy
 *)
fun evalSum (sx, m) = let
    (* rebuild a type with constructor mk from the indices that survive *)
    fun contract (mk, ix) = (case removeSumIndex (ix, sx)
        of [] => realTy
        | rest => mk rest
        (*end case*))
    in
        case m
            of ten ix => contract (ten, ix)
            | fld ix => contract (fld, ix)
            | fldmid _ => m
            | realTy => realTy
            | _ => errTy
        (*end case*)
    end


(* checkTenParam (id, params, ix): type of a reference to tensor parameter
 * number id (0-based position in params) used with the indices ix.
 * The parameter must be an E.TEN.  The result is ten over the distinct
 * non-constant indices of ix, or realTy when there are none.  A parameter of
 * any other kind, or an id outside params, gives errTy.
 *)
fun checkTenParam(id,params, ix)=
    (* List.nth is only defined for 0 <= id < length params.  The former test
     * id > length params let id = length params (and negative ids) through,
     * which raised Subscript instead of reporting errTy. *)
    if (id < 0 orelse id >= length(params)) then (print "in here";errTy)
    else (case List.nth(params,id)
        of E.TEN => (case sortIndex(ix)
            of [] => realTy
            | m => ten m
            (*end case*))
        | _ => errTy
        (*end case*))

(* checkFldParam (id, params, ix): type of a reference to field parameter
 * number id (0-based position in params) used with the indices ix.
 * The parameter must be an E.FLD; the result is then fld ix with the indices
 * kept exactly as given.  A parameter of any other kind, or an id outside
 * params, gives errTy.
 *)
fun checkFldParam(id, params,ix)=
    (* List.nth is only defined for 0 <= id < length params.  The former test
     * id > length params let id = length params (and negative ids) through,
     * which raised Subscript instead of reporting errTy. *)
    if (id < 0 orelse id >= length(params)) then (print "in here";errTy)
    else (case List.nth(params,id)
        of E.FLD _ => fld ix
        | _ => errTy
        (*end case*))


(* checker (ein, phase): computes the type of the body of an EIN operator.
 * params is used to resolve tensor and field parameter references; phase
 * decides which operators are legal: Field, Apply, Probe and Conv require
 * phase <= 1, while Krn and Img require phase >= 3.  An ill-typed body gives
 * errTy. *)
fun checker (Ein.EIN{params, index = _, body},phase) = let
    (* type of one term, computed bottom-up from its sub-terms *)
    fun checkbody (E.Const _) = realTy
      | checkbody (E.Tensor (id, ix)) = checkTenParam (id, params, ix)
      | checkbody (E.Delta (i, j)) = G [i, j]
      | checkbody (E.Value _) = realTy
      | checkbody (E.Epsilon (i, j, k)) = G [E.V i, E.V j, E.V k]
      | checkbody (E.Sum (sx, e)) = evalSum (sx, checkbody e)
      | checkbody (E.Neg e) = checkbody e
      | checkbody (E.Partial a) = partialTy (sortIndex a)
      | checkbody (E.Add es) = evalAdd (List.map checkbody es)
      | checkbody (E.Sub (e1, e2)) = evalSub (checkbody e1, checkbody e2)
      | checkbody (E.Prod es) = evalProd (List.map checkbody es)
      | checkbody (E.Div (e1, e2)) = evalDiv (checkbody e1, checkbody e2)
      (* operators that are only legal while phase <= 1 *)
      | checkbody (E.Field (id, alpha)) =
            if phase <= 1
                then checkFldParam (id, params, alpha)
                else err "wrong phase for Field"
      | checkbody (E.Apply (e1, e2)) =
            evalApply (checkbody e1, checkbody e2, phase)
      | checkbody (E.Probe (e1, e2)) =
            evalProbe (checkbody e1, checkbody e2, phase)
      | checkbody (E.Conv (fid, alpha, _, beta)) =
            if phase <= 1
                then (case checkFldParam (fid, params, alpha)
                    of fld f => fld (sortFldIndex (beta, f))
                    | _ => errTy
                    (*end case*))
                else err "wrong phase for convolution"
      (* operators that are only legal when phase >= 3 *)
      | checkbody (E.Krn (_, dels, _)) = evalKrn (dels, phase)
      | checkbody (E.Img (id, ix, pos)) =
            if phase < 3
                then errTy
                else (case position (List.map checkbody pos)
                    (* the position is checked before the image parameter *)
                    of errTy => err "Not an image position"
                    | _ => (case checkFldParam (id, params, [])
                        of errTy => errTy
                        | _ => imageTy (sortIndex ix)
                        (*end case*))
                    (*end case*))
    in
        checkbody body
    end


end; (* local *)

end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0