Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/mid-to-low/step2.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/mid-to-low/step2.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2612, Wed May 7 02:58:55 2014 UTC revision 2838, Tue Nov 25 03:40:24 2014 UTC
# Line 1  Line 1 
1  (*hashs Ein Function after substitution*)  (*general function for scalars*)
2  structure step2 = struct  structure step2 = struct
3      local      local
4      structure DstIL = LowIL      structure DstIL = LowIL
# Line 8  Line 8 
8      structure E = Ein      structure E = Ein
9      structure S3=step3      structure S3=step3
10      structure genKrn=genKrn      structure genKrn=genKrn
11        structure tS= toStringEin
12    
13      in      in
14    
# Line 23  Line 24 
24  fun errS str=raise Fail(str)  fun errS str=raise Fail(str)
25    
26  (*Helpers for scalars*)  (*Helpers for scalars*)
27  fun mkCons(shape,rest)=S3.aaV(DstOp.cons(DstTy.TensorTy shape,0),rest,"Cons",DstTy.TensorTy(shape))  
28    fun mkCons(shape, rest)=let
29        val ty=DstTy.TensorTy shape
30        val a=DstIL.Var.new("Cons" ,ty)
31        val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
32        (*val _=print("###"^tS.toStringAll(ty,code))*)
33        in (a, [code])
34        end
35    
36    
37  val Sca=DstTy.TensorTy([])  val Sca=DstTy.TensorTy([])
38  fun mkProdSca rest=S3.aaV(DstOp.prodSca,rest,"prodSca",Sca)  val addR=DstOp.addSca
39  fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)  
40  fun mkDivSca rest= S3.aaV(DstOp.divSca,rest,"divSca",Sca)  
41  fun mkMultipleSca(ids,rator)=S3.mkMultiple(ids,rator,Sca)  fun mkProdSca(lhs,rest)=S3.aaV(DstOp.prodSca,rest,lhs^"prodSca",Sca)
42  fun mkInt n= S3.aaV(DstOp.C(n),[],"Int",Sca)  fun mkSubSca(lhs,rest)= S3.aaV(DstOp.subSca,rest,lhs^"subSca",Sca)
43    fun mkDivSca(lhs,rest)= S3.aaV(DstOp.divSca,rest,lhs^"divSca",Sca)
44    fun mkMultipleSca(info,ids,rator)=S3.mkMultiple(info,ids,rator,Sca)
45    fun mkReal n=S3.mkReal n
46    
47    
48    
49    
50    
51  fun prodIter(origIndex,index,nextfn,args)=(let  fun prodIter(origIndex,index,nextfn,args)=(let
52      val index'=List.map (fn (e)=>(e-1)) index      val index'=List.map (fn (e)=>(e-1)) index
# Line 76  Line 93 
93          in          in
94              (vA,code'@ccode)              (vA,code'@ccode)
95          end          end
96        | Iter _=raise Fail"index' is larger than origIndex"
97      in      in
98          Iter(empty,index',[],[],origIndex,0)          Iter(empty,index',[],[],origIndex,0)
99      end)      end)
100    
101  (*Get constant *)  (*Get constant *)
102  fun skeleton A=(case A  fun skeleton A=(case A
103      of [DstIL.ASSGN(_,DstIL.OP(DstOp.C c,_))]=>c      of [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 0))]=>0
104        | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 1))]=> 1
105        | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int ~1))]=> ~1
106      | _ => 9      | _ => 9
107      (*end case*))      (*end case*))
108    
# Line 93  Line 113 
113      (*end case*))      (*end case*))
114    
115    
116  fun NegCheckO(vA,A)=(case skeleton A  fun NegCheckO(lhs,(vA,A))=(case skeleton A
117      of 0 => mkInt 0      of 0 => mkReal 0
118      | ~1 => mkInt 1      | ~1 => mkReal 1
119      | 1  => mkInt ~1      | 1  => mkReal ~1
120      |  _=> let      |  _=> let
121          val (vB,B)=mkInt ~1          val (vB,B)=mkReal ~1
122          val (vD,D)=mkProdSca [vB,vA]          val (vD,D)=mkProdSca (lhs,[vB,vA])
123          in (vD,A@B@D) end          in (vD,A@B@D) end
124      (*end case*))      (*end case*))
125    
126    
127  fun SubcheckO((vA,A),(vB,B))=(case((skeleton A),(skeleton B))  fun SubcheckO(lhs,(vA,A),(vB,B))=(case((skeleton A),(skeleton B))
128      of (0,0)=> mkInt 0      of (0,0)=> mkReal 0
129      |(0,_)=> let      |(0,_)=> let
130          val (vD,D)= mkInt ~1          val (vD,D)= mkReal ~1
131          val (vE,E)= mkProdSca [vD,vB]          val (vE,E)= mkProdSca(lhs, [vD,vB])
132          in (vE,B@D@E) end          in (vE,B@D@E) end
133      | (_,0)=> (vA,A)      | (_,0)=> (vA,A)
134      | _ => let      | _ => let
135          val (vD,D)= mkSubSca [vA,vB]          val (vD,D)= mkSubSca(lhs,[vA,vB])
136          in (vD, A@B@D) end          in (vD, A@B@D) end
137      (*end case*))      (*end case*))
 (*  
 fun printMapp mapp=(case (lookup 0 mapp)  
 of NONE=>print(String.concat["\n No zero"])  
     |SOME s => print(String.concat["\n Found 0 =>",Int.toString(s)])  
 (*end case*))  
138    
139  *)  
 (*  val info=(params,args)*)  
140    
141  (* general expressions-removes zeros*)  (* general expressions-removes zeros*)
142  fun generalfn(dict,(body,origargs,info))=let  fun generalfn(dict,(body,origargs,info as (lhs,_,_)))=let
143      val mapp=ref dict      val mapp=ref dict
144      (*val _=printMapp(!mapp)*)  
     val _=print "in General FUNCTIOn"  
145    
146    
147      fun gen body=let      fun gen body=let
148          fun AddcheckO ([],[],[])=let val (vA,A)=mkInt 1 in ([vA],A) end          fun AddcheckO ([],[],[])=let val (vA,A)=mkReal 0 in ([vA],A) end
149            | AddcheckO([],ids,code)=(ids,code)            | AddcheckO([],ids,code)=(ids,code)
150            | AddcheckO(e1::es,ids,code)=let            | AddcheckO(e1::es,ids,code)=let
151              val (a,b)=gen e1              val (a,b)=gen e1
# Line 141  Line 154 
154                  |  _ => AddcheckO(es,ids@[a],code@b)                  |  _ => AddcheckO(es,ids@[a],code@b)
155                  (*end case*))                  (*end case*))
156              end              end
157          fun ProdcheckO ([],[],[])=let val (vA,A)=mkInt 1 in ([vA],A) end          fun ProdcheckO ([],[],[])=let val (vA,A)=mkReal 1 in ([vA],A) end
158            | ProdcheckO([],ids,code)=(ids,code)            | ProdcheckO([],ids,code)=(ids,code)
159            | ProdcheckO(e1::es,ids,code)=let            | ProdcheckO(e1::es,ids,code)=let
160               val (a,b)=gen e1               val (a,b)=gen e1
# Line 153  Line 166 
166              end              end
167    
168          fun Sumcheck(sumx,e)=let          fun Sumcheck(sumx,e)=let
             val _=print "\n Found Sum"  
169              fun sumloop mapsum=let              fun sumloop mapsum=let
170                  val _ = mapp:=mapsum                  val _ = mapp:=mapsum
171                  val(vA,A)=gen e                  val(vA,A)=gen e
172                  in ([vA],A) end                  in (case (skeleton A)
173                        of 0 => ([],A)
174                        |  _ => ([vA],A)
175                        (*end case*))
176                    end
177    
178                    (*in ([vA],A) end*)
179              fun sumI1(left,(v,0,lb1),[],rest,code)=let              fun sumI1(left,(v,0,lb1),[],rest,code)=let
180                  val dict=insert(v, lb1) left                  val dict=insert(v, lb1) left
181                  val (vD,pre)= sumloop dict                  val (vD,pre)= sumloop dict
# Line 173  Line 191 
191                  val dict=insert(v, (s+lb1)) left                  val dict=insert(v, (s+lb1)) left
192                  val (rest',code')=sumI1(dict,(a,ub-lb2,lb2),sx,rest,code)                  val (rest',code')=sumI1(dict,(a,ub-lb2,lb2),sx,rest,code)
193                  in sumI1(dict,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end                  in sumI1(dict,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end
194                 | sumI1 _ =raise Fail"None Variable-index in summation"
195              val (E.V v,lb,ub)=hd(sumx)              val (E.V v,lb,ub)=hd(sumx)
196              in              in
197                  sumI1(!mapp,(v,ub-lb,lb),tl(sumx),[],[])                  sumI1(!mapp,(v,ub-lb,lb),tl(sumx),[],[])
198              end              end
199    
200          fun iterList(e,rator)= (case e          fun iterList(e, DstOp.addSca)=(case e
201                of ([],code)=>let val (vA,A)=mkReal 0 in (vA,A) end
202                | ([id1],code) => (id1,code)
203                | (ids,code)    => let
204                    val (vB,B)= mkMultipleSca(info,ids,addR)
205                    in (vB,code@B) end
206                (*end case*))
207    
208            | iterList(e,rator)= (case e
209              of ([id1],code) => (id1,code)              of ([id1],code) => (id1,code)
210              | (ids,code)    => let              | (ids,code)    => let
211                  val (vB,B)= mkMultipleSca(ids, rator)                  val (vB,B)= mkMultipleSca(info,ids, rator)
212                  in (vB,code@B) end                  in (vB,code@B) end
213              (*end case*))              (*end case*))
214    
215      in (case body      in (case body
216          of  E.Field _           => err 1          of  E.Field _           => err 1
217          | E.Partial _           => err 1          | E.Partial _           => err 1
# Line 193  Line 221 
221          | E.Krn _               => err 1          | E.Krn _               => err 1
222          | E.Img _               => err 1          | E.Img _               => err 1
223          | E.Lift _              => err 1          | E.Lift _              => err 1
224          | E.Value v             => mkInt(findIX(v,!mapp))          | E.Value v             => mkReal(findIX(v,!mapp))
225          | E.Const c             => mkInt c          | E.Const c             => mkReal c
226          | E.Epsilon(i,j,k)      => S3.evalEps(!mapp,i,j,k)          | E.Epsilon(i,j,k)      => S3.evalEps(!mapp,i,j,k)
227          | E.Delta(i,j)          => S3.evalDelta2(!mapp,i,j)          | E.Delta(i,j)          => S3.evalDelta2(!mapp,i,j)
228          | E.Tensor(id,ix)       => (print("\n Tensor: "^Int.toString(id));          | E.Tensor(id,ix)       => S3.mkSca(!mapp,(id,ix,info))
229              S3.mkSca(!mapp,(id,ix,info)))          | E.Neg e               => NegCheckO(lhs,gen e)
230          | E.Neg e               => NegCheckO(gen e)          | E.Sub (e1,e2)         => SubcheckO(lhs,gen e1,gen e2)
         | E.Sub (e1,e2)         => SubcheckO(gen e1,gen e2)  
231          | E.Div(e1,e2)          => let          | E.Div(e1,e2)          => let
232              val (vA,A)=gen e1              val (vA,A)=gen e1
233              in (case (skeleton A)              in (case (skeleton A)
234                  of 0=> mkInt 0                  of 0=> mkReal 0
235                  | _=> let                  | _=> let
236                      val (vB,B)=gen e2                      val (vB,B)=gen e2
237                      val (vD,D)= mkDivSca [vA,vB]                      val (vD,D)= mkDivSca(lhs, [vA,vB])
238                      in (vD, A@B@D) end                      in (vD, A@B@D) end
239                  (*end case*))                  (*end case*))
240              end              end
241          | E.Add e               => (print "\n in Add";iterList(AddcheckO(e,[],[]),DstOp.addSca))          | E.Add e               => (iterList(AddcheckO(e,[],[]),addR))
242          | E.Prod e              => iterList(ProdcheckO(e,[],[]),DstOp.prodSca)          | E.Prod e              => iterList(ProdcheckO(e,[],[]),DstOp.prodSca)
243          | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let          | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let
244              val harg=List.nth(origargs,Hid)              val harg=List.nth(origargs,Hid)
245              val imgarg=List.nth(origargs,Vid)              val imgarg=List.nth(origargs,Vid)
246              val h=S3.getKernel(harg)              val h=S3.getKernel harg
247              val v=S3.getImage(imgarg)              val v=S3.getImageSrc imgarg
248                val(_,_,args)=info
249                val imgargnew=List.nth(args,Vid)
250              in              in
251                  genKrn.evalField(!mapp,(body,v,h,info))                  genKrn.evalField(!mapp,(body,(v,imgargnew),h,info))
252              end              end
253    
254          | E.Sum(sumx, e)=>iterList(Sumcheck(sumx,e),DstOp.addSca)          | E.Sum(sumx, e)=>iterList(Sumcheck(sumx,e),addR)
255          (*end case*))          (*end case*))
256          end          end
257    

Legend:
Removed from v.2612  
changed lines
  Added in v.2838

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0