Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/ein/einsurface.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/ein/einsurface.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3686 - (download) (annotate)
Fri Feb 26 16:01:37 2016 UTC (4 years, 7 months ago) by cchiw
File size: 8083 byte(s)
added einsurface
(* creates EIN operators 
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure einSurface = struct

    local

    structure E = Ein
    structure P=Printer
    in
 
    (* specialize (alpha, inc)
     * Builds one index variable per element of alpha:
     *   [E.V inc, E.V (inc+1), ..., E.V (inc+length(alpha)-1)]
     * Only the length of alpha is used; its contents are ignored.
     *)
    fun specialize(alpha,inc) = let
        fun shifted k = E.V (k+inc)
        in List.tabulate(List.length alpha, shifted) end
    (* sumIds (n, inc, alpha)
     * Pairs the k-th variable E.V (k+inc), for k = 0..n-1, with the k-th
     * entry a of alpha, producing the triple (E.V (k+inc), 0, a-1).
     * The result is as long as the shorter of the two lists.
     *)
    fun sumIds(n,inc,alpha) = let
        fun mkRange(ix,ub) = (ix, 0, ub-1)
        val ids = List.tabulate(n, fn k => E.V (k+inc))
        in List.map mkRange (ListPair.zip(ids, alpha)) end
     (* sumIds2 (n, i)
      * Produces the n triples (E.V 0, 0, i), ..., (E.V (n-1), 0, i).
      *)
     fun sumIds2(n,i) = let
        fun mkRange k = (E.V k, 0, i)
        in List.tabulate(n, mkRange) end

    (* Integer flag fixed at 1. It is not referenced by any function visible
     * in this structure.
     * NOTE(review): presumably read by code outside this file — confirm
     * what the value 1 selects.
     *)
    val subst_flag = 1(*here*)

    (* mkCx: wraps every element c of es as E.C (c,true) *)
    fun mkCx es = let
        fun wrap c = E.C (c,true)
        in List.map wrap es end
    (* mkCxSingle: wraps a single value c as E.C (c,true) *)
    fun mkCxSingle c = let
        val tagged = (c,true)
        in E.C tagged end
    (* mkSxSingle: wraps a single value c as E.C (c,false) *)
    fun mkSxSingle c = let
        val tagged = (c,false)
        in E.C tagged end


    (* Ty.TensorTy [alpha] A;
    * 	λT <.. T[ixs] .. >_zeta (A)
    * Index ix is the nth index in ixs
    * range = List.nth(alpha,n)
    * bound = List.nth(zeta, ix)
    * range=bound
    *
    * ex.  Ty.TensorTy[d0,d1]A;  λT <T[j,i] .. >_zeta (A)
    * zeta[j]=d0
    * zeta[i]=d1
    * length(zeta) > i,j
    *)
    (* checkbound (ixs, n, alpha, zeta, argn)
     * Walks ixs from left to right. The index ix found at position n+k of
     * the argument must satisfy
     *     List.nth(alpha, n+k) = List.nth(zeta, ix)
     * i.e. the tensor dimension at that position equals the bound the
     * output type gives to that index.
     *
     * Returns true when every index agrees.
     * Raises Fail describing the first index that disagrees; argn is only
     * used to name the argument in that message.
     * List.nth raises Subscript if a position or an index is out of range.
     *)
    fun checkbound([],_,_,_,_) = true
      | checkbound(ix::ixs,n,alpha,zeta,argn) = let
        val range = List.nth(alpha,n)
        val bound = List.nth(zeta,ix)
        in
            if range = bound
            then checkbound(ixs,n+1,alpha,zeta,argn)
            else raise Fail(String.concat[
                "Index(",Int.toString ix,") to argument ",Int.toString argn,
                " is bound by the tensor dimension ",Int.toString range,
                (* trailing space keeps the final number separate from "as" *)
                " but bound by the output type as ",Int.toString bound])
        end

    (* shapeToTy: renders a shape [d0,d1,...] as the text "Ty.TensorTy[d0,d1,...]" *)
    fun shapeToTy alpha = let
        val dims = List.map Int.toString alpha
        in "Ty.TensorTy[" ^ String.concatWith "," dims ^ "]" end
    
    (* checkB (alpha, beta, zeta, ns1, ns2)
     * Checks that every index in ns1 and in ns2 is a valid position in
     * zeta, i.e. 0 <= index < length(zeta). ns1 is checked before ns2.
     * alpha and beta are not inspected.
     *
     * Returns true when all indices are in range.
     * Raises Fail naming the first offending index. Negative indices are
     * rejected here as well, since they can never be a list position.
     *)
    fun checkB(alpha,beta,zeta,ns1,ns2) = let
        val n = length(zeta)
        (* never returns false: an out-of-range index raises instead *)
        fun inRange ix =
            if 0 <= ix andalso ix < n
            then true
            else raise Fail(String.concat["Index(",Int.toString ix,") is outside range. Range=[0,",Int.toString n,")"])
        in List.all inRange ns1 andalso List.all inRange ns2 end
      
    (* matchLength (alpha, ns1, argn)
     * true exactly when alpha and ns1 have the same number of elements.
     * argn is accepted but not used.
     *)
    fun matchLength(alpha,ns1,argn) =
        (case Int.compare(List.length alpha, List.length ns1)
           of EQUAL => true
            | _ => false)
    
    
    
    
    (* checkA (alpha, beta, zeta, ns1, ns2)
     * Checks that each argument carries as many indices as its shape has
     * dimensions: length(alpha) = length(ns1) and length(beta) = length(ns2).
     * The first argument is checked before the second. zeta is not inspected.
     *
     * Returns true when both match.
     * Raises Fail naming the first argument whose shape does not match.
     *)
    fun checkA(alpha,beta,zeta,ns1,ns2) = let
        (* same wording for both arguments, with the shape set off by spaces *)
        fun mismatch(which,shape) = raise Fail(String.concat[
            "shape of ",which," variable ",shapeToTy shape,
            " does not match number of indices"])
        in
            if not (matchLength(alpha,ns1,1)) then mismatch("first",alpha)
            else if not (matchLength(beta,ns2,2)) then mismatch("second",beta)
            else true
        end
        
    (* checkZeta (body1, zeta)
     * Runs cleanIndex.cleanIndex on (body1, zeta, []) and compares the
     * second component of its result (sizes) with zeta.
     *
     * Always prints the returned body, the observed sizes, the expected
     * zeta and the returned tshape.
     * Returns true when sizes = zeta; otherwise raises Fail carrying the
     * observed/expected text.
     * NOTE(review): the meaning of cleanIndex's three results is taken
     * from the names used here — confirm against cleanIndex itself.
     *)
    fun checkZeta(body1,zeta) = let
        val (tshape,sizes,body) = cleanIndex.cleanIndex(body1,zeta,[])
        fun showInts es = String.concatWith "," (List.map Int.toString es)
        fun showVars es = String.concatWith ", " (List.map (fn E.V e => "E.V"^Int.toString e) es)
        (* leading space on " Set Type:" keeps it apart from the last observed size *)
        val str = String.concat["\n Observed output type:",showInts sizes," Set Type:",showInts zeta]
        in
            print(String.concat["\n Body:",P.printbody(body),str]);
            print(String.concat["\n tshape:",showVars tshape,"\n"]);
            if sizes = zeta then true else raise Fail str
        end
     
     
     (* m3: wraps every integer of es in the E.V constructor *)
     fun m3 es = List.map E.V es
     (* m: renders an int list as comma-separated text, e.g. [1,2] -> "1,2" *)
     fun m es = let
        val parts = List.map Int.toString es
        in String.concatWith "," parts end
     (* m2: renders a list of E.V indices as "E.V0, E.V1, ..." *)
     fun m2 es = let
        fun show (E.V e) = "E.V"^Int.toString(e)
        in String.concatWith ", " (List.map show es) end
     
    
     (* cleanToString (tshape, sizes, body)
      * Prints the three components of a cleanIndex result (body, sizes,
      * then tshape) and returns 1.
      *)
     fun cleanToString(tshape,sizes,body) = (
        print("\n Body:" ^ P.printbody(body));
        print("\n Observed output type:" ^ m sizes);
        print("\n tshape1:" ^ m2 tshape ^ "\n");
        1)



     
     (* checkTshape (alpha, expect, ns1, argN): is the tensor shape what we expect?
      *   expect   : expected shape of the tensor
      *   observed : observed shape, found by running cleanIndex on the
      *              tensor T_ns1 to get its tshape, then looking up each
      *              tshape index in alpha to get its size
      * Returns true when observed = expect; otherwise raises Fail with
      * both shapes. argN only names the argument in that message.
      *)
     fun checkTshape(alpha,expect,ns1,argN) = let
        val probe = E.Tensor(0,m3 ns1)
        val (tshape,_,_) = cleanIndex.cleanIndex(probe, expect, [])
        fun sizeOf (E.V e) = List.nth(alpha, e)
        val observed = List.map sizeOf tshape
        in
            if observed = expect
            then true
            else raise Fail(String.concat["\n\nArgument ",Int.toString(argN)," Expected shape:", m expect," Observed shape:",m observed, "\n  tshape:",m2 tshape])
        end
     
     (* iter (alphas, zeta, id)
      * Diagnostic search: for every position cnt of alphas, lists the
      * index values ui that could be assigned to U_cnt, prints them, and
      * returns 1.
      *   round 1 (f/g): keep every ui in 0..n-1 with zeta[ui] = alphas[cnt]
      *   round 2 (k/l): of those, keep every ui with alphas[ui] = zeta[cnt]
      * Only "\n Argument <id>", "\n Possible solutions:" and the final
      * solution table are printed. Every other `val _ = <string>` below
      * builds a trace message and discards it.
      * List.nth raises Subscript when an index falls outside alphas or zeta.
      * NOTE(review): n is length(alphas) but is used to scan zeta in round 1;
      * this looks like it assumes both lists have the same length — confirm.
      *)
     fun iter(alphas,zeta,id) = let
        val n=List.length(alphas)
        (* formats one candidate; every call site discards the result *)
        fun getSol (cnt,ui)= (String.concat[" U_",Int.toString(cnt)," =  E.V ",Int.toString(ui)])
        (* round 1: one candidate list per element of alphas, accumulated in reverse order *)
        fun f([],_,soln)  = soln
        | f(a1::a1n,cnt,soln)= let
            val cur="U_"^ Int.toString(cnt)
            (* trace string, not printed *)
            val _ = ("\nAttempting"^" to find solution for index "^ cur^". We know A["^Int.toString(cnt)^"]="^Int.toString(a1))
                 
            (* counts ui down to 0, testing position ui-1 of zeta; consing while
             * counting down leaves the matching positions in ascending order *)
            fun g(0,sol) = sol
            | g(ui,sol) = let
                val ui'=ui-1
                val x=List.nth(zeta,ui')
                (* trace string, not printed *)
                val _ = (String.concat["\n If ",cur,"=",Int.toString(ui'), " then zeta[",cur,"]=",Int.toString(x)," "])
                (*List.nth zeta  ui=ai*)
                in
                    if (x=a1)
                    then ((getSol(cnt,ui'));g(ui',ui'::sol))
                    else g(ui',sol)
                end
           val sol = g(n,[])
           in f(a1n,cnt+1,sol::soln)
           end

        (* round 2: filters each candidate list e1 for position cnt; result is again accumulated in reverse *)
        fun k([],_,soln) =soln
         | k(e1::es ,cnt,soln)=let
            (* trace strings, not printed *)
            val _ =("\n length e1: "^Int.toString(length(e1)))
            val _ =("\nAttempt: U_"^Int.toString((cnt)))
            (* keeps candidate f1 when alphas[f1] = zeta[cnt]; kept candidates end up in reverse order *)
            fun l([],sol)=sol
            | l(f1::fs,sol)= let

                val x = List.nth(alphas,f1)
                val z1=List.nth(zeta, cnt)
                    (* trace string, not printed *)
                    val _ =(String.concat["\n zeta[",Int.toString(cnt),"]=",Int.toString(z1),"   alpha[",Int.toString(f1),"]=",Int.toString(x)])
                in if(x=z1)
                    then ((getSol(cnt,f1));l(fs,f1::sol))
                    else  l(fs,sol)
              end
           val sol = l(e1,[])
           in k(es,cnt+1,sol::soln) end
        val _ = print("\n Argument "^Int.toString(id))
        (* trace string, not printed *)
        val _ =("\n round 1: zeta[ui]=ai:")
        val solns= f(alphas,0,[])
        (* trace string, not printed *)
        val _ ="\n round 2: alphas[ui]=zi:"
        (* List.rev restores alphas order before round 2 *)
        val soln =k(List.rev solns,0,[])
           val _ =print"\n Possible solutions:"
        (* renders one "\n U_<cnt> = E.V a, E.V b," line per position; this local m
         * shadows the int-list printer m defined earlier in the structure *)
        fun m([],_)= ""
        | m(e1::es,cnt) = String.concat["\n U_",Int.toString(cnt)," = E.V ", String.concatWith", E.V " (List.map (fn e=> Int.toString(e)) e1),",",m(es,cnt+1)]
        val _ = print(m(List.rev soln,0))
        in 1 end
  
    (* einchecks (alpha, beta, zeta, ns1, ns2)
     * Runs the argument checks in a fixed order: checkA, checkB, then
     * checkbound for the first argument (ns1 against alpha) and for the
     * second (ns2 against beta).
     * Returns true if none of them raises; any exception a check raises
     * propagates to the caller.
     *)
    fun einchecks(alpha,beta,zeta,ns1,ns2) = (
        checkA(alpha,beta,zeta,ns1,ns2);
        checkB(alpha,beta,zeta,ns1,ns2);
        checkbound(ns1,0,alpha,zeta,1);
        checkbound(ns2,0,beta,zeta,2);
        true)
        
    (* einMultTT (alpha, beta, zeta, ns1, ns2)
     * Builds the EIN operator whose body is the product
     *     E.Tensor(0, ns1 as E.V indices) * E.Tensor(1, ns2 as E.V indices)
     * with parameters E.TEN(1,alpha), E.TEN(1,beta) and index zeta.
     * The operator is validated with einchecks (which may raise) and then
     * printed before it is returned.
     *)
    fun einMultTT(alpha,beta,zeta,ns1,ns2) = let
        fun toVars ns = List.map E.V ns
        val product = E.Opn(E.Prod,[E.Tensor(0,toVars ns1),E.Tensor(1,toVars ns2)])
        val e = E.EIN{
            params = [E.TEN(1,alpha),E.TEN(1,beta)],
            index = zeta,
            body = product
            }
        in
            einchecks(alpha,beta,zeta,ns1,ns2);
            print("\n"^P.printerE(e)^"\n");
            e
        end


    (* einAddTT (alpha, beta, zeta, ns1, ns2)
     * Builds the EIN operator whose body is the sum
     *     E.Tensor(0, ns1 as E.V indices) + E.Tensor(1, ns2 as E.V indices)
     * with parameters E.TEN(1,alpha), E.TEN(1,beta) and index zeta.
     * The operator is printed first. Then the iter diagnostics run for
     * both arguments, followed by checkTshape for both and einchecks;
     * any of the checks may raise. The operator is returned at the end.
     *)
    fun einAddTT(alpha,beta,zeta,ns1,ns2) = let
        fun toVars ns = List.map E.V ns
        val sum = E.Opn(E.Add,[E.Tensor(0,toVars ns1),E.Tensor(1,toVars ns2)])
        val e = E.EIN{
            params = [E.TEN(1,alpha),E.TEN(1,beta)],
            index = zeta,
            body = sum
        }
        in
            print("\n"^P.printerE(e)^"\n");
            iter(alpha,zeta,1);
            iter(beta,zeta,2);
            checkTshape(alpha,zeta,ns1,1);
            checkTshape(beta,zeta,ns2,2);
            einchecks(alpha,beta,zeta,ns1,ns2);
            e
        end



  end; (* local *)

    end (* structure einSurface *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0