Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2498 - (download) (annotate)
Wed Oct 30 17:35:23 2013 UTC (8 years, 1 month ago) by cchiw
File size: 6863 byte(s)
add type checking
(* expand-integrate.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)


(*Expand ProbeConv to Probe of individual field *)
structure Expand = struct

    local
   
    structure E = Ein

    in
 

(*
Dictionary created for Tensor positions
Tensor X=> fractional and integer tensors in mid-il operators
Save in variable names.

*)


(* insert (key, value) d
 * Extends the function-based dictionary d with one binding: the resulting
 * dictionary answers SOME value for key and defers to d for any other key.
 *)
fun insert (key, value) d probe =
      if probe = key
        then SOME value
        else d probe

(* lookup key dict
 * Queries the function-based dictionary dict for key by applying it.
 *)
fun lookup key dict = dict key

(*
createDels=> creates the Kronecker deltas for each kernel.
For each dimension a, and each index b in the derivative, create the element (a,b).
*)
(* createDels (ds, dim)
 * Pairs the constant index (E.C dim) with every element of ds, preserving
 * the order of ds.  An empty ds yields an empty list.
 *)
fun createDels (ds, dim) = List.map (fn d => (E.C dim, d)) ds
        



(* Position (idt, dict, params, args)
 * Returns (fid, nid, dict', params', args) for the tensor id idt.
 *   - If idt is already recorded in dict, its stored pair is returned and
 *     dict and params are passed through unchanged.
 *   - Otherwise two E.TEN entries are appended to params, their positions
 *     (length params and length params + 1) become the new pair, and the
 *     pair is recorded in dict under idt.
 * args is returned as given in both cases.
 *)
fun Position (idt, dict, params, args) = (case lookup idt dict
       of SOME(fid, nid) => (fid, nid, dict, params, args)
        | NONE => let
            (* the fresh ids are the positions the new parameters will occupy *)
            val fid = length params
            val nid = fid + 1
            in
              (fid, nid, insert (idt, (fid, nid)) dict, params @ [E.TEN, E.TEN], args)
            end
      (* end case *))



(* expandEinProbe (params, body, index, d, args)
 * Rewrites a body of the form
 *     E.Probe(E.Conv(id, alpha, kid, deltas), E.Tensor(idt, _))
 * whose parameter at position id is E.FLD dim into
 *     E.Sum(vars, E.Prod(E.Img(id, alpha, positions) :: kernels))
 * with one position term and one E.Krn term per dimension, and appends dim
 * summation ranges E.SX(1-s, s) to index (s is the fixed support, 2).
 * Returns (params', body', index', d', args').
 *
 * Fallbacks (each prints a message):
 *   - id is past the end of params: the input tuple is returned unchanged;
 *   - the parameter at id is not E.FLD: body is replaced by E.Const 0.0;
 *   - body has any other shape: the input tuple is returned unchanged.
 *)
fun expandEinProbe (params, body, index, d, args) = let
    (* result for the cases where nothing is rewritten *)
    val unchanged = (params, body, index, d, args)
    in
      case body
       of E.Probe(E.Conv(id, alpha, kid, deltas), E.Tensor(idt, _)) =>
            if (id+1 > length params)
              then (print "not enough params"; unchanged)
              else (case List.nth(params, id)
                 of E.FLD dim => let
                      val support = 2
                      val (fid, nid, d', params', args') = Position(idt, d, params, args)
                      (* new summation variables are numbered after the existing indices *)
                      val base = length index

                      (* one summation range per dimension, appended to index *)
                      fun sumRanges 0 = []
                        | sumRanges k = E.SX(1-support, support) :: sumRanges(k-1)

                      (* summation variables E.V base .. E.V (base+dim-1), in increasing order *)
                      fun sumVars (0, acc) = acc
                        | sumVars (k, acc) = sumVars(k-1, E.V(k+base-1) :: acc)

                      (* walks the dimensions from dim-1 down to 0, consing each new
                       * image position and kernel term onto the front of its list,
                       * then wraps the image term and the kernels in a product
                       *)
                      fun build (0, imgpos, krns) = E.Prod(E.Img(id, alpha, imgpos) :: krns)
                        | build (k, imgpos, krns) = let
                            val axis = k-1
                            val sumIx = axis + base
                            val dels = createDels(deltas, axis)
                            val pos = E.Add[E.Tensor(fid, [E.C axis]), E.Value sumIx]
                            val krn = E.Krn(kid, dels,
                                  E.Sub(E.Tensor(nid, [E.C axis]), E.Value sumIx))
                            in
                              build(axis, pos :: imgpos, krn :: krns)
                            end

                      val product = build(dim, [], [])
                      val vars = sumVars(dim, [])
                      in
                        (params', E.Sum(vars, product), index @ sumRanges dim, d', args')
                      end
                  | _ => (print "err: non field in param spot";
                          (params, E.Const(0.0), index, d, args))
                (* end case *))
        | _ => (print "unexpected body"; unchanged)
      (* end case *)
    end



(**********     transform to world space***********)

(* for gradients, etc. we have to transform back to world space *)
(*
val probeCode = if (*(k > 0)*)
length(deltas)>0
then let
(* for gradients, etc. we have to transform back to world space *)
val ty = DstV.ty result
val tensor = DstV.new("tensor", ty)



val alpha= List.take(dim, length(dim)-1)
val ilist=  [alpha,  tl(tensor), hd(tensor)]
val TensorToWorldSpace=S.tranform(EinOp.innerProduct, ilist,[])
val xform = assignEin(result, TensorToWorldSpace, [img, tensor])


*)



(*create set of positions
val dim = ImageInfo.dim v
val s = Kernel.support h
val vecsTy =createVec(2*s)
val vecDimTy = createVec(dim)
val translate=DstOp.Translate v
val transform=DstOp.Transform v

(* generate the transform code *)
val x = DstV.new ("x", vecDimTy)	(* image-space position *)
val f = DstV.new ("f", vecDimTy)
val nd = DstV.new ("nd", vecDimTy)
val n = DstV.new ("n", DstTy.iVecTy dim)
val M = DstV.new ("M", transform)
val T = DstV.new ("T", translate)

val sub= S.transform(EinOp.subTensor,[dim],[])


(* M_ij x_i*)
val MXop=S.transform(EinOp.innerProduct,[[dim],[],[dim]],[])
val MX = DstV.new ("MX", MXop)



val PosToImgSpace=S.transform(EinOp.addTensor,[[dim]],[])


val toImgSpaceCode = [
assignEin(MX, Mxop, [M,pos]), (*Just added*)
assignEin(x, PosToImgSpace,[MX,T])
assign(nd, DstOp.Floor dim, [x]),
assignEin(f, sub,[x,nd]),
assign(n, DstOp.RealToInt dim, [nd])
]

*)

(*copied from high-to-mid.sml*)
(* expandEinOp (EIN{params, index, body}, args)
 * Walks body and expands every E.Probe(E.Conv _, _) term via expandEinProbe,
 * threading the parameter list, index list, tensor-position dictionary and
 * args through the traversal from left to right.  Returns the rebuilt EIN
 * operator together with the resulting args.
 *
 * Leaf terms and probes of anything other than a convolution are kept as
 * they are; bare E.Conv, E.Field and E.Apply terms are replaced by a
 * placeholder constant.
 *)
fun expandEinOp (Ein.EIN{params, index, body}, args) = let

    (* placeholder substituted for terms that are not expanded here *)
    val placeholder = E.Const 0.0

    (* state is (params, term, index, dict, args); the result has the same shape *)
    fun rewriteBody (state as (p, term, ix, d, args')) = (case term
           of E.Const _ => state
            | E.Tensor _ => state
            | E.Krn _ => state
            | E.Delta _ => state
            | E.Value _ => state
            | E.Epsilon _ => state
            | E.Partial _ => state
            | E.Img _ => state
            | E.Conv _ => (p, placeholder, ix, d, args')
            | E.Field _ => (p, placeholder, ix, d, args')
            | E.Apply _ => (p, placeholder, ix, d, args')
            | E.Neg e => unary (E.Neg, e, state)
            | E.Sum(c, e) => unary (fn e' => E.Sum(c, e'), e, state)
            (* must precede the general E.Probe case below *)
            | E.Probe(E.Conv _, _) => expandEinProbe state
            | E.Sub(a, b) => binary (E.Sub, a, b, state)
            | E.Div(a, b) => binary (E.Div, a, b, state)
            | E.Add es => nary (E.Add, es, state)
            | E.Prod es => nary (E.Prod, es, state)
            | E.Probe _ => state
          (* end case *))

    (* rewrite the single subterm e and rebuild the node with mk *)
    and unary (mk, e, (p, _, ix, d, args0)) = let
          val (p1, e', ix1, d1, args1) = rewriteBody (p, e, ix, d, args0)
          in
            (p1, mk e', ix1, d1, args1)
          end

    (* rewrite a, then b starting from the state a produced *)
    and binary (mk, a, b, (p, _, ix, d, args0)) = let
          val (p1, a', ix1, d1, args1) = rewriteBody (p, a, ix, d, args0)
          val (p2, b', ix2, d2, args2) = rewriteBody (p1, b, ix1, d1, args1)
          in
            (p2, mk(a', b'), ix2, d2, args2)
          end

    (* rewrite each subterm left to right; results are collected in reverse
     * and put back in order before the node is rebuilt
     *)
    and nary (mk, es, (p, _, ix, d, args0)) = let
          fun step (e, (p1, revDone, ix1, d1, args1)) = let
                val (p2, e', ix2, d2, args2) = rewriteBody (p1, e, ix1, d1, args1)
                in
                  (p2, e' :: revDone, ix2, d2, args2)
                end
          val (pN, revDone, ixN, dN, argsN) = List.foldl step (p, [], ix, d, args0) es
          in
            (pN, mk (List.rev revDone), ixN, dN, argsN)
          end

    (* the traversal starts with a dictionary that has no entries *)
    val empty = fn key => NONE
    val (params', body', ix', _, args') = rewriteBody (params, body, index, empty, args)
    in
      (Ein.EIN{params=params', index=ix', body=body'}, args')
    end

  end; (* local *)

end (* structure Expand *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0