Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/specialize.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/specialize.sml

Parent Directory | Revision Log


Revision 2485 - (download) (annotate)
Mon Oct 21 16:34:57 2013 UTC (6 years ago) by cchiw
File size: 4581 byte(s)
 
(* specialize.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure Specialize : sig

    val transform : GenericEin.ein * (int list) list * (int list) -> Ein.ein

  end = struct

    structure G = GenericEin
    structure E = Ein

  (* transform (ein, inst, paraminst)
   *
   * Specializes a shape-polymorphic (generic) Einstein expression to a particular
   * shape, producing an Ein.ein.
   *
   *   inst      -- one entry per element of the generic index list, in order:
   *                  G.MX  takes a list; each element becomes one E.IX, so a
   *                        multi-index of width n expands to n indices;
   *                  G.IX  takes a singleton [d], which becomes E.IX d;
   *                  G.SX  takes [lb, ub], which becomes E.SX(lb, ub).
   *   paraminst -- one entry per G.FLD parameter, in order; each entry is
   *                passed to E.FLD.  G.TEN parameters consume no entry.
   *
   * Because a multi-index expands to several indices, the number of every
   * later index shifts.  A remapping array records, for each generic index
   * position, the list of specialized index numbers it became, and every index
   * variable in the body is rewritten through that array.
   *
   * Raises Fail if inst does not match the index list, if a single-index
   * context (Epsilon, Delta, Krn) names a variable bound to a multi-index whose
   * width is not 1, or if paraminst does not match the parameter list.
   * (An index variable outside the index list makes Array.sub raise Subscript.)
   *)
    fun transform (G.EIN{params, index, body}, inst, paraminst) = let
        (* mapping[i] = specialized index numbers for generic index position i.
         * The [~1] filler is a placeholder; lp overwrites every slot before the
         * body is rewritten. *)
          val mapping = Array.array(length index, [~1])
          fun mapIx ix = Array.sub(mapping, ix)

        (* expand a generic mu into specialized mu's; a variable bound to a
         * multi-index becomes one E.V per component *)
          fun expandMu (G.C ix) = [E.C ix]
            | expandMu (G.V ix) = List.map E.V (mapIx ix)
          fun expandMus mus = List.concat (List.map expandMu mus)

        (* a variable that must denote exactly one specialized index *)
          fun mapSingleIx ix = (case mapIx ix
                 of [ix'] => ix'
                  | _ => raise Fail "unexpected multiindex"
                (* end case *))

        (* a mu in a single-index position (Delta, Krn).  Variables are remapped
         * exactly like Epsilon's so they stay correct after a preceding
         * multi-index has been expanded. *)
          fun mapMu (G.C ix) = E.C ix
            | mapMu (G.V ix) = E.V (mapSingleIx ix)
          fun mapBeta (a, b) = (mapMu a, mapMu b)

        (* initialize the remapping and compute the new index list.
         * i = generic index position, j = next specialized (de Bruijn) number *)
          fun lp (_, _, [], []) = []
            | lp (i, j, G.MX::r, dn::inst) = let
                val n = length dn
                in
                  Array.update(mapping, i, List.tabulate(n, fn k => j+k));
                  List.map E.IX dn @ lp(i+1, j+n, r, inst)
                end
            | lp (i, j, G.IX::r, [dn]::inst) = (
                Array.update(mapping, i, [j]);
                E.IX dn :: lp(i+1, j+1, r, inst))
            | lp (i, j, G.SX::r, [lb, ub]::inst) = (
                Array.update(mapping, i, [j]);
                E.SX(lb, ub) :: lp(i+1, j+1, r, inst))
            | lp _ = raise Fail "multiindex/instantiation mismatch"
          val index' = lp (0, 0, index, inst)

        (* rewrite the body using the mapping *)
          fun rewrite e = (case e
                 of G.Const r => E.Const r
                  | G.Tensor(tid, ix) => E.Tensor(tid, expandMus ix)
                  | G.Field(fid, alpha) => E.Field(fid, expandMus alpha)
                  | G.Krn(betas, pos) => E.Krn(List.map mapBeta betas, rewrite pos)
                  | G.Delta ix => E.Delta(mapBeta ix)
                  (* NOTE(review): Value's argument is passed through unchanged;
                   * if it is a generic index position it likely needs the same
                   * remapping as Epsilon -- confirm against GenericEin. *)
                  | G.Value ix => E.Value ix
                  | G.Epsilon(ix, jx, kx) =>
                      E.Epsilon(mapSingleIx ix, mapSingleIx jx, mapSingleIx kx)
                  | G.Sum(ix, e1) => E.Sum(expandMus ix, rewrite e1)
                  | G.Neg e1 => E.Neg(rewrite e1)
                  | G.Add el => E.Add(List.map rewrite el)
                  | G.Sub(e1, e2) => E.Sub(rewrite e1, rewrite e2)
                  | G.Prod el => E.Prod(List.map rewrite el)
                  | G.Div(e1, e2) => E.Div(rewrite e1, rewrite e2)
                  | G.Partial alpha => E.Partial(expandMus alpha)
                  | G.Apply(e1, e2) => E.Apply(rewrite e1, rewrite e2)
                  | G.Conv(e1, alpha) => E.Conv(rewrite e1, expandMus alpha)
                  | G.Probe(e1, e2) => E.Probe(rewrite e1, rewrite e2)
                  | G.Img((e1, pos), h) =>
                      E.Img((rewrite e1, List.map rewrite pos), List.map rewrite h)
                (* end case *))

        (* specialize the parameter list: each G.FLD consumes the next entry of
         * paraminst and G.TEN consumes none.  Entries left over once every
         * parameter is consumed are always an error (previously they were only
         * caught when exactly one remained after a trailing G.TEN). *)
          fun writeParam ([], []) = []
            | writeParam ([], _::_) = raise Fail "param index too many"
            | writeParam (G.TEN::ps, ds) = E.TEN :: writeParam(ps, ds)
            | writeParam (G.FLD::ps, d::ds) = E.FLD d :: writeParam(ps, ds)
            | writeParam _ = raise Fail "param index mismatch"
          val params' = writeParam(params, paraminst)
          in
            E.EIN{
                params = params',
                index = index',
                body = rewrite body
              }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0