Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/mid-to-low/ein-to-scalar.sml
ViewVC logotype

View of /branches/vis15/src/compiler/mid-to-low/ein-to-scalar.sml

Parent Directory | Revision Log


Revision 3649 - (download) (annotate)
Tue Feb 2 15:37:23 2016 UTC (4 years ago) by jhr
File size: 6166 byte(s)
working on merge
(* ein-to-scalar.sml
 *
 * Generate LowIR scalar computations that implement Ein expressions.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure EinToScalar : sig

  (* expand (avail, mapp, (params, body, args))
   * generates LowIR scalar code for the Ein expression `body`, where
   *   - `mapp` maps Ein index variables to their current integer values,
   *   - `params` and `args` are the Ein parameters and the LowIR variables
   *     bound to them, and
   *   - `avail` is threaded through code generation.
   * The result pairs the final `avail` with the variable produced for `body`.
   *
   * NOTE(review): the signature previously contained `??` placeholders.  The
   * map's element type is `int` (it holds index values); the result type is
   * inferred from how `gen` threads its results -- confirm against MkLowIR.
   *)
    val expand :
	  AvailRHS.t * int IntRedBlackMap.map * (Ein.param_kind list * Ein.ein_exp * LowIR.var list)
	    -> AvailRHS.t * LowIR.var

  end = struct

    structure IR = LowIR
    structure Ty = LowTypes
    structure Op = LowOps
    structure Var = LowIR.Var
    structure E = Ein
    structure Mk = MkLowIR
    structure P = EinPP
    structure IMap = IntRedBlackMap

  (* local aliases for the code-generation helpers *)
    fun evalField e = FieldToLow.evalField e
    fun indexTensor e = Mk.indexTensor e
    fun mkSubSca e = Mk.mkSubSca e
    fun mkProdSca e = Mk.mkProdSca e
    fun mkDivSca e = Mk.mkDivSca e
    fun mkMultiple e = Mk.mkMultiple e
    fun evalG e = Mk.evalG e
    fun mkOp1 e = Mk.mkOp1 e

  (* insert (k, v) d: curried insertion of the binding k -> v into d *)
    fun insert (k, v) d = IMap.insert (d, k, v)

  (* field-level Ein terms cannot be lowered to scalar code here *)
    fun errField e = raise Fail ("Invalid Field Here:" ^ P.expToString e)

  (* mapIndex (mapp, id) returns the value bound to index variable `id`;
   * raises Fail when `id` is unbound.
   *)
    fun mapIndex (mapp, id) = (case IMap.find(mapp, id)
	   of SOME x => x
	    | NONE => raise Fail(concat["mapIndex(_, V ", Int.toString id, "): out of bounds"])
	  (* end case *))

    fun expand (avail, mapp, (params, body, args)) = let
	(* current values of the Ein index variables; rebound while unrolling sums *)
	  val mapp = ref mapp
	(* NOTE(review): `e` is not bound in this scope (the parameters are avail,
	 * mapp, params, body, and args).  It is passed on to FieldToLow.evalField
	 * and must be supplied before this file will compile -- TODO.
	 *)
	  val info = (e, args)
	(* tb n = [0, 1, ..., n-1] *)
	  fun tb n = List.tabulate (n, fn i => i)
	(* gen (avaIR, body) generates code for `body`, returning the updated
	 * avail paired with the variable produced.
	 *)
	  fun gen (avaIR, body) = let
	      (* generate code for each expression in turn; the resulting variables
	       * are returned in the same order as the expressions.
	       *)
		fun genList (avaIR, es) = let
		      fun iter (avaIR, [], ids) = (avaIR, List.rev ids)
			| iter (avaIR, e1::es, ids) = let
			    val (avaIR, a) = gen (avaIR, e1)
			    in
			      iter (avaIR, es, a::ids)
			    end
		      in
			iter (avaIR, es, [])
		      end
	      (* unroll the summation `sumx` over `e`, generating one copy of `e` for
	       * each assignment of values to the summation indices (the last index
	       * varies fastest).  Returns the generated variables in iteration order.
	       * The index bindings in effect before the sum are restored afterwards.
	       *)
		fun sumCheck (avaIR, sumx, e) = let
		      val saved = !mapp
		      fun sumloop (avaIR, mapsum) = (mapp := mapsum; gen (avaIR, e))
		    (* `rest` accumulates the generated variables in reverse order *)
		      fun sumI1 (avaIR, left, (v, [i], lb1), [], rest) = let
			    val (avaIR, vD) = sumloop (avaIR, insert (v, lb1+i) left)
			    in
			      (avaIR, vD :: rest)
			    end
			| sumI1 (avaIR, left, (v, i::es, lb1), [], rest) = let
			    val dict = insert (v, i+lb1) left
			    val (avaIR, vD) = sumloop (avaIR, dict)
			    in
			      sumI1 (avaIR, dict, (v, es, lb1), [], vD :: rest)
			    end
			| sumI1 (avaIR, left, (v, [i], lb1), (E.V a, lb2, ub2)::sx, rest) =
			    sumI1 (avaIR, insert (v, lb1+i) left, (a, tb (ub2-lb2+1), lb2), sx, rest)
			| sumI1 (avaIR, left, (v, s::es, lb1), (E.V a, lb2, ub2)::sx, rest) = let
			    val dict = insert (v, s+lb1) left
			    val (avaIR, rest') = sumI1 (avaIR, dict, (a, tb (ub2-lb2+1), lb2), sx, rest)
			    in
			      sumI1 (avaIR, dict, (v, es, lb1), (E.V a, lb2, ub2)::sx, rest')
			    end
			| sumI1 _ = raise Fail "None Variable-index in summation"
		    (* the outermost summation index must be a variable *)
		      val (v, lb, ub, sx) = (case sumx
			     of (E.V v, lb, ub) :: sx => (v, lb, ub, sx)
			      | _ => raise Fail "None Variable-index in summation"
			    (* end case *))
		      val (avaIR, ids) = sumI1 (avaIR, saved, (v, tb (ub-lb+1), lb), sx, [])
		      in
		      (* restore the bindings that were in effect outside the sum *)
			mapp := saved;
			(avaIR, List.rev ids)
		      end
		in
		  (case body
		   of E.Field _ => errField body
		    | E.Partial _ => errField body
		    | E.Apply _ => errField body
		    | E.Probe _ => errField body
		    | E.Conv _ => errField body
		    | E.Krn _ => errField body
		    | E.Img _ => errField body
		    | E.Lift _ => errField body
		    | E.Value v => Mk.intToRealLit (avaIR, mapIndex (!mapp, v))
		    | E.Const c => Mk.intToRealLit (avaIR, c)
		    | E.Delta _ => evalG (avaIR, !mapp, body)
		    | E.Epsilon _ => evalG (avaIR, !mapp, body)
		    | E.Eps2 _ => evalG (avaIR, !mapp, body)
		    | E.Tensor(id, ix) =>
			indexTensor (avaIR, !mapp, (params, args, id, ix, Ty.TensorTy []))
		  (* negation is generated as a product with -1 *)
		    | E.Op1(E.Neg, e1) => let
			val (avaIR, vA) = gen (avaIR, e1)
			val (avaIR, vB) = Mk.intToRealLit (avaIR, ~1)
			in
			  mkProdSca (avaIR, [vB, vA])
			end
		  (* NOTE(review): unlike the other helpers, mkOp1 receives avail inside
		   * the pair returned by gen; confirm this matches Mk.mkOp1's signature.
		   *)
		    | E.Op1(op1, e1) => mkOp1 (op1, gen (avaIR, e1))
		    | E.Op2(E.Sub, e1, e2) => let
			val (avaIR, vA) = gen (avaIR, e1)
			val (avaIR, vB) = gen (avaIR, e2)
			in
			  mkSubSca (avaIR, [vA, vB])
			end
		    | E.Opn(E.Add, es) => let
			val (avaIR, ids) = genList (avaIR, es)
			in
			  mkMultiple (avaIR, ids, Op.addSca, Ty.TensorTy [])
			end
		    | E.Opn(E.Prod, es) => let
			val (avaIR, ids) = genList (avaIR, es)
			in
			  mkMultiple (avaIR, ids, Op.prodSca, Ty.TensorTy [])
			end
		  (* a one-index tensor divided by a zero-index tensor is rewritten as
		   * the product (1 / e2) * e1
		   *)
		    | E.Op2(E.Div, e1 as E.Tensor(_, [_]), e2 as E.Tensor(_, [])) =>
			gen (avaIR, E.Opn(E.Prod, [E.Op2(E.Div, E.Const 1, e2), e1]))
		    | E.Op2(E.Div, e1, e2) => let
			val (avaIR, vA) = gen (avaIR, e1)
			val (avaIR, vB) = gen (avaIR, e2)
			in
			  mkDivSca (avaIR, [vA, vB])
			end
		  (* a summed image/kernel product is handed off to FieldToLow.evalField *)
		    | E.Sum(_, E.Opn(E.Prod, E.Img _ :: E.Krn _ :: _)) =>
			evalField (avaIR, !mapp, (body, info))
		    | E.Sum(sumx, e) => let
			val (avaIR, ids) = sumCheck (avaIR, sumx, e)
			in
			  mkMultiple (avaIR, ids, Op.addSca, Ty.TensorTy [])
			end
		    | _ => raise Fail "unsupported ein-exp "
		  (* end case *))
		end
	  in
	    gen (avail, body)
	  end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0