Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/mid-to-low/step2.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/mid-to-low/step2.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2612, Wed May 7 02:58:55 2014 UTC revision 2680, Wed Aug 6 00:51:53 2014 UTC
# Line 1  Line 1 
1  (*hashs Ein Function after substitution*)  (*general function for scalars*)
2  structure step2 = struct  structure step2 = struct
3      local      local
4      structure DstIL = LowIL      structure DstIL = LowIL
# Line 8  Line 8 
8      structure E = Ein      structure E = Ein
9      structure S3=step3      structure S3=step3
10      structure genKrn=genKrn      structure genKrn=genKrn
11        structure tS= toStringEin
12    
13      in      in
14    
# Line 23  Line 24 
24  fun errS str=raise Fail(str)  fun errS str=raise Fail(str)
25    
26  (*Helpers for scalars*)  (*Helpers for scalars*)
27  fun mkCons(shape,rest)=S3.aaV(DstOp.cons(DstTy.TensorTy shape,0),rest,"Cons",DstTy.TensorTy(shape))  
28    fun mkCons(shape, rest)=let
29        val ty=DstTy.TensorTy shape
30        val a=DstIL.Var.new("Cons" ,ty)
31        val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
32        val _=print("###"^tS.toStringAll(ty,code))
33        in (a, [code])
34        end
35    
36    
37  val Sca=DstTy.TensorTy([])  val Sca=DstTy.TensorTy([])
38    val addR=DstOp.addSca
39    
40    
41  fun mkProdSca rest=S3.aaV(DstOp.prodSca,rest,"prodSca",Sca)  fun mkProdSca rest=S3.aaV(DstOp.prodSca,rest,"prodSca",Sca)
42  fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)  fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)
43  fun mkDivSca rest= S3.aaV(DstOp.divSca,rest,"divSca",Sca)  fun mkDivSca rest= S3.aaV(DstOp.divSca,rest,"divSca",Sca)
44  fun mkMultipleSca(ids,rator)=S3.mkMultiple(ids,rator,Sca)  fun mkMultipleSca(ids,rator)=S3.mkMultiple(ids,rator,Sca)
45  fun mkInt n= S3.aaV(DstOp.C(n),[],"Int",Sca)  fun mkInt n=S3.mkInt n
46    
47    
48    
49    
50    
51  fun prodIter(origIndex,index,nextfn,args)=(let  fun prodIter(origIndex,index,nextfn,args)=(let
52      val index'=List.map (fn (e)=>(e-1)) index      val index'=List.map (fn (e)=>(e-1)) index
# Line 76  Line 93 
93          in          in
94              (vA,code'@ccode)              (vA,code'@ccode)
95          end          end
96        | Iter _=raise Fail"index' is larger than origIndex"
97      in      in
98          Iter(empty,index',[],[],origIndex,0)          Iter(empty,index',[],[],origIndex,0)
99      end)      end)
100    
101  (*Get constant *)  (*Get constant *)
102  fun skeleton A=(case A  fun skeleton A=(case A
103      of [DstIL.ASSGN(_,DstIL.OP(DstOp.C c,_))]=>c      of [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 0))]=>0
104        | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 1))]=> 1
105        | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int ~1))]=> ~1
106      | _ => 9      | _ => 9
107      (*end case*))      (*end case*))
108    
# Line 127  Line 147 
147  (* general expressions-removes zeros*)  (* general expressions-removes zeros*)
148  fun generalfn(dict,(body,origargs,info))=let  fun generalfn(dict,(body,origargs,info))=let
149      val mapp=ref dict      val mapp=ref dict
150      (*val _=printMapp(!mapp)*)  
     val _=print "in General FUNCTIOn"  
151    
152    
153      fun gen body=let      fun gen body=let
154          fun AddcheckO ([],[],[])=let val (vA,A)=mkInt 1 in ([vA],A) end          fun AddcheckO ([],[],[])=let val (vA,A)=mkInt 0 in ([vA],A) end
155            | AddcheckO([],ids,code)=(ids,code)            | AddcheckO([],ids,code)=(ids,code)
156            | AddcheckO(e1::es,ids,code)=let            | AddcheckO(e1::es,ids,code)=let
157              val (a,b)=gen e1              val (a,b)=gen e1
# Line 153  Line 172 
172              end              end
173    
174          fun Sumcheck(sumx,e)=let          fun Sumcheck(sumx,e)=let
             val _=print "\n Found Sum"  
175              fun sumloop mapsum=let              fun sumloop mapsum=let
176                  val _ = mapp:=mapsum                  val _ = mapp:=mapsum
177                  val(vA,A)=gen e                  val(vA,A)=gen e
178                  in ([vA],A) end                  in (case (skeleton A)
179                        of 0 => ([],A)
180                        |  _ => ([vA],A)
181                        (*end case*))
182                    end
183    
184                    (*in ([vA],A) end*)
185              fun sumI1(left,(v,0,lb1),[],rest,code)=let              fun sumI1(left,(v,0,lb1),[],rest,code)=let
186                  val dict=insert(v, lb1) left                  val dict=insert(v, lb1) left
187                  val (vD,pre)= sumloop dict                  val (vD,pre)= sumloop dict
# Line 173  Line 197 
197                  val dict=insert(v, (s+lb1)) left                  val dict=insert(v, (s+lb1)) left
198                  val (rest',code')=sumI1(dict,(a,ub-lb2,lb2),sx,rest,code)                  val (rest',code')=sumI1(dict,(a,ub-lb2,lb2),sx,rest,code)
199                  in sumI1(dict,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end                  in sumI1(dict,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end
200                 | sumI1 _ =raise Fail"None Variable-index in summation"
201              val (E.V v,lb,ub)=hd(sumx)              val (E.V v,lb,ub)=hd(sumx)
202              in              in
203                  sumI1(!mapp,(v,ub-lb,lb),tl(sumx),[],[])                  sumI1(!mapp,(v,ub-lb,lb),tl(sumx),[],[])
204              end              end
205    
206          fun iterList(e,rator)= (case e          fun iterList(e, DstOp.addSca)=(case e
207                of ([],code)=>let val (vA,A)=mkInt 0 in (vA,A) end
208                | ([id1],code) => (id1,code)
209                | (ids,code)    => let
210                    val (vB,B)= mkMultipleSca(ids,addR)
211                    in (vB,code@B) end
212                (*end case*))
213    
214            | iterList(e,rator)= (case e
215              of ([id1],code) => (id1,code)              of ([id1],code) => (id1,code)
216              | (ids,code)    => let              | (ids,code)    => let
217                  val (vB,B)= mkMultipleSca(ids, rator)                  val (vB,B)= mkMultipleSca(ids, rator)
218                  in (vB,code@B) end                  in (vB,code@B) end
219              (*end case*))              (*end case*))
220    
221      in (case body      in (case body
222          of  E.Field _           => err 1          of  E.Field _           => err 1
223          | E.Partial _           => err 1          | E.Partial _           => err 1
# Line 197  Line 231 
231          | E.Const c             => mkInt c          | E.Const c             => mkInt c
232          | E.Epsilon(i,j,k)      => S3.evalEps(!mapp,i,j,k)          | E.Epsilon(i,j,k)      => S3.evalEps(!mapp,i,j,k)
233          | E.Delta(i,j)          => S3.evalDelta2(!mapp,i,j)          | E.Delta(i,j)          => S3.evalDelta2(!mapp,i,j)
234          | E.Tensor(id,ix)       => (print("\n Tensor: "^Int.toString(id));          | E.Tensor(id,ix)       => S3.mkSca(!mapp,(id,ix,info))
             S3.mkSca(!mapp,(id,ix,info)))  
235          | E.Neg e               => NegCheckO(gen e)          | E.Neg e               => NegCheckO(gen e)
236          | E.Sub (e1,e2)         => SubcheckO(gen e1,gen e2)          | E.Sub (e1,e2)         => SubcheckO(gen e1,gen e2)
237          | E.Div(e1,e2)          => let          | E.Div(e1,e2)          => let
# Line 211  Line 244 
244                      in (vD, A@B@D) end                      in (vD, A@B@D) end
245                  (*end case*))                  (*end case*))
246              end              end
247          | E.Add e               => (print "\n in Add";iterList(AddcheckO(e,[],[]),DstOp.addSca))          | E.Add e               => (iterList(AddcheckO(e,[],[]),addR))
248          | E.Prod e              => iterList(ProdcheckO(e,[],[]),DstOp.prodSca)          | E.Prod e              => iterList(ProdcheckO(e,[],[]),DstOp.prodSca)
249          | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let         (* | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let
250              val harg=List.nth(origargs,Hid)              val harg=List.nth(origargs,Hid)
251              val imgarg=List.nth(origargs,Vid)              val imgarg=List.nth(origargs,Vid)
252              val h=S3.getKernel(harg)              val h=S3.getKernel(harg)
253              val v=S3.getImage(imgarg)              val v=S3.getImage(imgarg)
254    val imgargnew=List.nth(args,Vid)
255    val v=S3.getImage(imgarg,imgargnew)
256              in              in
257                  genKrn.evalField(!mapp,(body,v,h,info))                  genKrn.evalField(!mapp,(body,v,h,info))
258              end              end
259    *)
260          | E.Sum(sumx, e)=>iterList(Sumcheck(sumx,e),DstOp.addSca)          | E.Sum(sumx, e)=>iterList(Sumcheck(sumx,e),addR)
261          (*end case*))          (*end case*))
262          end          end
263    

Legend:
Removed from v.2612  
changed lines
  Added in v.2680

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0