(* ---- ViewVC page header (web-page residue, preserved as a comment) ----
 * Home My Page Projects Code Snippets Project Openings SML/NJ
 * Summary Activity Forums Tracker Lists Tasks Docs Surveys News SCM Files
 *
 * SCM Repository
 *
 * [smlnj] View of /MLRISC/trunk/amd64/staged-allocation/amd64-svid-fn.sml
 * ViewVC logotype
 *
 * View of /MLRISC/trunk/amd64/staged-allocation/amd64-svid-fn.sml
 *
 * Parent Directory | Revision Log
 *
 * Revision 3037 - (download) (annotate)
 * Tue May 27 06:30:07 2008 UTC (12 years, 4 months ago) by mrainey
 * File size: 11943 byte(s)
 * Commit message: Cleaning up AMD64 C calls.
 *)
(* amd64-svid-fn.sml
 *
 * C calling conventions for the AMD64. We use the technique of Staged Allocation (see
 * MLRISC/staged-allocation).
 *
 * Mike Rainey (mrainey@cs.uchicago.edu)
 *)

functor AMD64SVIDFn (
    structure T : MLTREE
  ) = struct

    (* Local abbreviations: MLTree, AMD64 cells, the cell basis, and C types. *)
    structure T = T
    structure C = AMD64Cells
    structure CB = CellsBasis
    structure CTy = CTypes

    (* width used for GPRs and integer literals, plus the MLTree memory regions *)
    val wordTy = 64
    val mem = T.Region.memory
    val stack = T.Region.stack

    (* integer literal expression of width wordTy *)
    fun lit n = T.LI (T.I.fromInt (wordTy, n))

    (* wrap a register as a wordTy-wide GPR argument, or a typed FPR argument *)
    fun gpr reg = T.GPR (T.REG (wordTy, reg))
    fun fpr (fty, freg) = T.FPR (T.FREG (fty, freg))

    (* sum of a list of integers *)
    fun sum [] = 0
      | sum (x :: xs) = x + sum xs

    (* size in bytes of a C type, as reported by CSizes.sizeOfTy *)
    fun szBOfCTy cTy = let
          val {sz, ...} = CSizes.sizeOfTy cTy
          in
            sz
          end

    (* pair registers with their widths: GPRs are wordTy wide, FPRs 64 bits *)
    fun toGpr r = (wordTy, r)
    fun toGprs regs = List.map (fn r => toGpr r) regs
    fun toFpr r = (64, r)
    fun toFprs regs = List.map (fn r => toFpr r) regs


    (* The kinds of location a C value (or one eightbyte of it) can be assigned. *)
    datatype location_kind
      = K_GPR   (* general-purpose registers *)
      | K_FPR   (* floating-point registers *)
      | K_MEM   (* memory locations *)

    (* Staged-allocation engine specialized to the location kinds above,
     * with memSize = 8 bytes. *)
    structure SA = StagedAllocationFn (
        structure T = T
        structure TargetLang = struct
            datatype location_kind = datatype location_kind
          end
        val memSize = 8 (* bytes *)
      )

    (* CCs: register sets and staged-allocation stage lists describing the AMD64
     * SVID C calling convention; layoutReturn / layoutCall run these stages. *)
    structure CCs =
      struct
      
      (* GPRs listed as callee-save: rbx, r12-r15 *)
	val calleeSaveRegs = toGprs [C.rbx, C.r12, C.r13, C.r14, C.r15]
      (* GPRs listed as caller-save; a "ccall" call in genCall defines all of them *)
	val callerSaveRegs = toGprs [C.rax, C.rcx, C.rdx, C.rsi, C.rdi, C.r8, C.r9, C.r10, C.r11]
      (* all sixteen floating-point cells (FP 0-15), paired with width 64 *)
	val callerSaveFRegs = toFprs (C.Regs CB.FP {from=0, to=15, step=1})
	val calleeSaveFRegs = []
      (* MLTree expression for the stack pointer %rsp *)
	val spReg = T.REG (wordTy, C.rsp)

      (* maxAlign argument given to every SA.OVERFLOW stage below *)
	val maxAlign = 16

      (* conventions for returning arguments *)
      (* Integer results come back in rax then rdx, floating-point results in xmm0
       * then xmm1; each slot is first widened to at least 64 bits.
       * NOTE(review): the counters produced by SA.useRegs are discarded and are not
       * among those given to SA.init for str0 -- presumably the useRegs stages keep
       * their own state; confirm against StagedAllocationFn. *)
	val gprRets = toGprs [C.rax, C.rdx]
	val fprRets = toFprs [C.xmm0, C.xmm1]
	val (_, ssFloat) = SA.useRegs fprRets
	val (_, ssGpr) = SA.useRegs gprRets
      (* counter for stack space; shared by the return and the call stages *)
	val cCallStk = SA.freshCounter ()
	val returnStages = [
	    SA.CHOICE [
	        (* return in general-purpose register *)
	          (fn (w, k, str) => k = K_GPR,
	           SA.SEQ [SA.WIDEN (fn w => Int.max (wordTy, w)), ssGpr]),
		(* return in floating-point register *)
		  (fn (w, k, str) => k = K_FPR,
	           SA.SEQ [SA.WIDEN (fn w => Int.max (64, w)), ssFloat]),
		(* return in a memory location *)
		  (fn (w, k, str) => k = K_MEM,
		 (* FIXME! K_MEM results are given space from the cCallStk counter.
		  * NOTE(review): the SVID ABI returns MEMORY-class values through a
		  * caller-supplied buffer whose address is passed in %rdi -- confirm
		  * how this stage is meant to model that. *)
		   SA.OVERFLOW {counter=cCallStk, blockDirection=SA.UP, maxAlign=maxAlign}) ]
	     ]

      (* conventions for passing arguments *)
      (* integer parameters: rdi, rsi, rdx, rcx, r8, r9; floating point: xmm0-xmm7 *)
	val gprParams = toGprs [C.rdi, C.rsi, C.rdx, C.rcx, C.r8, C.r9]
	val fprParams = toFprs [C.xmm0, C.xmm1, C.xmm2, C.xmm3, C.xmm4, C.xmm5, C.xmm6, C.xmm7]
      (* counters driving REGS_BY_BITS for the GPR and FPR parameter lists
       * (presumably the number of register bits consumed so far) *)
	val cCallGpr = SA.freshCounter ()
	val cCallFpr = SA.freshCounter ()
      (* initial store *)
	val str0 = SA.init [cCallStk, cCallGpr, cCallFpr]

      (* Stages for one parameter slot: the CHOICE dispatches on the slot's kind.
       * The trailing OVERFLOW is presumably the fallback used once the register
       * stage can no longer allocate (e.g. registers exhausted) -- confirm against
       * StagedAllocationFn's semantics for a stage list. *)
	val callStages = [ 
	      SA.CHOICE [
	      (* pass in general-purpose register *)
	      (fn (w, k, str) => k = K_GPR, SA.SEQ [
					    SA.WIDEN (fn w => Int.max (wordTy, w)),
					    SA.BITCOUNTER cCallGpr,
					    SA.REGS_BY_BITS (cCallGpr, gprParams) ]),
	      (* pass in floating point register *)
	      (fn (w, k, str) => k = K_FPR, SA.SEQ [
					    SA.WIDEN (fn w => Int.max (64, w)),
					    SA.BITCOUNTER cCallFpr,
					    SA.REGS_BY_BITS (cCallFpr, fprParams) ]),
	      (* pass on the stack *)
	      (fn (w, k, str) => k = K_MEM,
	       SA.OVERFLOW {counter=cCallStk, blockDirection=SA.UP, maxAlign=maxAlign}) 
	      ],
	      SA.OVERFLOW {counter=cCallStk, blockDirection=SA.UP, maxAlign=maxAlign}
	]

      end  (* CCs *)

    (* Generic C-call support instantiated for AMD64.  offSp builds the address
     * expression %rsp + offset (just %rsp when the offset is zero). *)
    structure CCall = CCallFn (
        structure T = T
        structure C = C
        val wordTy = wordTy
        fun offSp offset = (case offset
               of 0 => CCs.spReg
                | _ => T.ADD (wordTy, CCs.spReg, T.LI offset)
              (* end case *)))

    datatype c_arg = datatype CCall.c_arg

  (* eightBytesOfCTys (cTys, eb, ebs)
   * Greedily group the flattened C types cTys, in order, into "eightbytes":
   * runs of consecutive types whose sizes sum to at most 8 bytes.  eb is the
   * eightbyte being filled and ebs the completed ones, both kept in reverse.
   * Returns the eightbytes in field order, each listing its types in field order.
   * Empty groups are dropped.  One arises whenever the last type exactly fills
   * an eightbyte, and kindOfEightByte rejects empty eightbytes.
   * NOTE(review): positions are found by summing sizes without alignment padding,
   * so e.g. {char, int, char} lands in one eightbyte although C lays it out over
   * 12 bytes -- confirm callers never depend on such padded layouts.
   *)
    fun eightBytesOfCTys ([], eb, ebs) =
	  List.rev (List.map List.rev (List.filter (not o List.null) (eb :: ebs)))
      | eightBytesOfCTys (cTy :: cTys, eb, ebs) = let
	  val szTy = szBOfCTy cTy
	  val szEb = sum (List.map szBOfCTy eb)
	  in
	    if szTy + szEb = 8
	      then eightBytesOfCTys (cTys, [], (cTy :: eb) :: ebs)
	    else if szTy + szEb < 8
	      then eightBytesOfCTys (cTys, cTy :: eb, ebs)
	      else eightBytesOfCTys (cTys, [cTy], eb :: ebs)
	  end

  (* convert a C type into its eightbytes (see eightBytesOfCTys) *)
    fun eightBytesOfCTy cTy = eightBytesOfCTys (CTypes.flattenCTy cTy, [], [])

  (* Classify a non-aggregate C type into its location kind.  Arrays go to
   * memory; structs and unions must be flattened or handled by kindsOfCTy first.
   * NOTE(review): the SVID ABI classifies long double as X87 (passed in memory),
   * not SSE -- confirm that K_FPR is intended here. *)
    fun kindOfCTy (CTy.C_float | CTy.C_double | CTy.C_long_double) = K_FPR
      | kindOfCTy (CTy.C_ARRAY _) = K_MEM
      | kindOfCTy (CTy.C_STRUCT _ | CTy.C_UNION _) = raise Fail "impossible"
      | kindOfCTy (CTy.C_unsigned _ | CTy.C_signed _ | CTy.C_PTR) = K_GPR

  (* Combine the kinds of two values sharing an eightbyte.  The result is the
   * greater of the two in the order K_FPR < K_GPR < K_MEM. *)
    fun combineKinds (k1, k2) = if (k1 = k2)
	then k1
	else (case (k1, k2)
	       of (K_MEM, _) => K_MEM
		| (_, K_MEM) => K_MEM
		| (K_GPR, _) => K_GPR
		| (_, K_GPR) => K_GPR
		| _ => K_FPR
	      (* end case*))

  (* This part of the ABI is tricky.  An eightbyte whose types are all floating
   * point uses an FPR.  Any K_MEM type makes it K_MEM.  Otherwise it uses a GPR.
   * Since combineKinds is a maximum, the kinds are folded in any order. *)
    fun kindOfEightByte [] = raise Fail "impossible"
      | kindOfEightByte (cTy :: cTys) =
	  List.foldl (fn (ty, k) => combineKinds (kindOfCTy ty, k)) (kindOfCTy cTy) cTys

  (* containsUnalignedFields cTy
   * Should hold when the aggregate cTy has a field that is not at its natural
   * alignment (the ABI passes such aggregates in memory).  Aggregates here are
   * given only by their field types.  Presumably no packed layout can be
   * expressed (NOTE(review): confirm against CTypes), so fields are always
   * naturally aligned and the predicate is false.
   *)
    fun containsUnalignedFields _ = false

  (* Classify a C type into the location kinds of its eightbytes.
   * A struct goes to memory, with one K_MEM per started eightbyte, when any of
   * these holds: it is larger than two eightbytes; it has a field wider than one
   * eightbyte (e.g. long double), which kindOfEightByte cannot classify; or it
   * contains unaligned fields.  Otherwise each eightbyte is classified
   * separately.  Unions and arrays are not supported yet.  Any other type yields
   * its single kind. *)
    fun kindsOfCTy (cTy as CTy.C_STRUCT _) = let
	  val szB = szBOfCTy cTy
	  val hasWideField =
		List.exists (fn fTy => szBOfCTy fTy > 8) (CTypes.flattenCTy cTy)
	  in
	    if szB > 2*8 orelse hasWideField orelse containsUnalignedFields cTy
	    (* round up so a size that is not a multiple of 8 still gets enough space *)
	      then List.tabulate ((szB + 7) div 8, fn _ => K_MEM)
	      else List.map kindOfEightByte (eightBytesOfCTy cTy)
	  end
      | kindsOfCTy (CTy.C_UNION _) = raise Fail "todo"
      | kindsOfCTy (CTy.C_ARRAY _) = raise Fail "todo"
      | kindsOfCTy cTy = [kindOfCTy cTy]

    (* Staged-allocation slots for a C type: one triple (64, kind, 8) for each
     * kind returned by kindsOfCTy.  The first component is the width in bits.
     * The last is presumably the alignment in bytes -- confirm against SA.slot. *)
    fun slotsOfCTy cTy = let
	  fun mkSlot kind = (64, kind, 8)
	  in
	    List.map mkSlot (kindsOfCTy cTy)
	  end

    (* The single slot of a C type that occupies exactly one eightbyte.
     * Raises Fail "malformed C type" otherwise. *)
    fun slotOfCTy cTy = let
	  val slots = slotsOfCTy cTy
	  in
	    if List.length slots = 1
	      then List.hd slots
	      else raise Fail "malformed C type"
	  end

  (* cLocOfStagedAlloc (w, loc, kind)
   * Translate a staged-allocation location of width w and the given kind into a
   * CCall location.
   *   - A register becomes C_GPR or C_FPR according to its kind.
   *   - A block offset becomes a stack slot (C_STK), whatever the kind.
   *   - A NARROW wrapper is unwrapped using its own width and kind.
   *   - Anything else (e.g. a register of kind K_MEM) raises Fail.
   *)
    fun cLocOfStagedAlloc (w, loc, kind) = (case (loc, kind)
	   of (SA.REG (_, r), K_GPR) => CCall.C_GPR (w, r)
	    | (SA.REG (_, r), K_FPR) => CCall.C_FPR (w, r)
	    | (SA.BLOCK_OFFSET offB, _) => CCall.C_STK (w, T.I.fromInt (wordTy, offB))
	    | (SA.NARROW (inner, w', k'), _) => cLocOfStagedAlloc (w', inner, k')
	    | _ => raise Fail "impossible"
	  (* end case *))

  (* layoutReturn retTy
   * Run the return-value stages over retTy's slots, starting from CCs.str0.
   * Returns (result locations, SOME size of retTy, resulting store).
   * A void return yields ([], NONE, CCs.str0).
   *)
    fun layoutReturn retTy = let
	  val step = SA.mkStep CCs.returnStages
	  in
	    case retTy
	     of CTy.C_void => ([], NONE, CCs.str0)
	      | _ => let
		  val (store, saLocs) =
			SA.doStagedAllocation (CCs.str0, step, slotsOfCTy retTy)
		  val cLocs = List.map cLocOfStagedAlloc saLocs
		  in
		    (cLocs, SOME (CSizes.sizeOfTy retTy), store)
		  end
	  end

  (* layoutCall (str, paramTys)
   * Allocate each parameter type, left to right, with the call stages, threading
   * the store through.  Returns (one list of C locations per parameter, in
   * parameter order, final store).
   *)
    fun layoutCall (str, paramTys) = let
	  val step = SA.mkStep CCs.callStages
	  fun loop (store, [], acc) = (List.rev acc, store)
	    | loop (store, ty :: rest, acc) = let
		val (store', saLocs) = SA.doStagedAllocation (store, step, slotsOfCTy ty)
		in
		  loop (store', rest, List.map cLocOfStagedAlloc saLocs :: acc)
		end
	  in
	    loop (str, paramTys, [])
	  end

    (* layout {conv, retTy, paramTys}
     * Compute the full C-call layout.  The return value is laid out first, and its
     * store then seeds the parameter layout.  argMem.szB is the final value of the
     * cCallStk counter, i.e. the bytes of stack the call needs; align is fixed at 8.
     * conv is not consulted here. *)
    fun layout {conv = _, retTy, paramTys} = let
	  val (resLocs, structRetLoc, retStore) = layoutReturn retTy
	  val (argLocs, callStore) = layoutCall (retStore, paramTys)
	  val stackBytes = SA.find (callStore, CCs.cCallStk)
	  in
	    { argLocs = argLocs,
	      argMem = {szB = stackBytes, align = 8},
	      structRetLoc = structRetLoc,
	      resLocs = resLocs }
	  end

  (* returnVals resLocs
   * Copy a call's result out of its C location into a fresh pseudo-register.
   * Returns (result registers, copy statements):
   *   - []           => no result, no copies;
   *   - [C_GPR (ty, r)] => a new GPR and a T.COPY from r;
   *   - [C_FPR (ty, r)] => a new FPR and a T.FCOPY from r.
   * Any other layout, e.g. a struct returned in several registers or on the
   * stack, is not supported and raises Fail (instead of an anonymous Match).
   *)
    fun returnVals resLocs = (case resLocs
	   of [] => ([], [])
	    | [CCall.C_GPR (ty, r)] => let
		val resReg = C.newReg ()
		in
		  ([T.GPR (T.REG (ty, resReg))], [T.COPY (ty, [resReg], [r])])
		end
	    | [CCall.C_FPR (ty, r)] => let
		val resReg = C.newFreg ()
		in
		  ([T.FPR (T.FREG (ty, resReg))], [T.FCOPY (ty, [resReg], [r])])
		end
	    | _ => raise Fail "returnVals: unsupported result location(s)"
	  (* end case *))

    (* genCall {name, proto, paramAlloc, structRet, saveRestoreDedicated, callComment, args}
     * Generate the MLTree statements for a call to the C function `name` with
     * prototype `proto`.  Returns {callseq, result}: the statement list and the
     * registers holding the return value.
     *   - If the call needs stack space (argMem.szB > 0) and paramAlloc argMem
     *     returns false, %rsp is lowered by szB before the arguments are copied
     *     and raised by the same amount right after the call (the call is
     *     emitted with pops=0, so nothing else releases that space).
     *   - #conv proto selects the registers the call defines: "ccall" gives every
     *     caller-save GPR and FPR, "ccall-bare" none; anything else raises Fail.
     *   - structRet, saveRestoreDedicated and callComment are currently ignored.
     * NOTE(review): argMem.align is 8, but the SVID ABI expects %rsp to be 16-byte
     * aligned at the call -- confirm that callers or paramAlloc guarantee this.
     *)
    fun genCall {name, proto, paramAlloc, structRet, saveRestoreDedicated, callComment, args} = let
	  val {argLocs, argMem, resLocs, structRetLoc} = layout proto
	  val argSzB = #szB argMem
	(* reserve (and afterwards release) the outgoing-argument area ourselves,
	 * unless it is empty or the client allocated it *)
	  val (argAlloc, argDealloc) =
		if (argSzB = 0) orelse paramAlloc argMem
		  then ([], [])
		  else ([T.MV (wordTy, C.rsp, T.SUB (wordTy, CCs.spReg, lit argSzB))],
			[T.MV (wordTy, C.rsp, T.ADD (wordTy, CCs.spReg, lit argSzB))])
	  val (copyArgs, gprUses, fprUses) = CCall.copyArgs (args, argLocs)
	(* the defined registers of the call depend on the calling convention *)
	  val defs = (case #conv proto
		 of "ccall" => List.map (gpr o #2) CCs.callerSaveRegs @ List.map fpr CCs.callerSaveFRegs
		  | "ccall-bare" => []
		  | conv => raise Fail (concat [
			"unknown calling convention \"", String.toString conv, "\""
		      ])
		(* end case *))
	  val uses = List.map gpr gprUses @ List.map fpr fprUses
	  val callStm = T.CALL {funct=name, targets=[], defs=defs, uses=uses, region=mem, pops=0}
	  val (resultRegs, copyResult) = returnVals resLocs
	  val callSeq = argAlloc @ copyArgs @ [callStm] @ argDealloc @ copyResult
	  in
	    {callseq=callSeq, result=resultRegs}
	  end (* genCall *)


    (* unit testing code
     * kindOfEB () checks kindOfEightByte on the first/second eightbyte of several
     * struct types.  slots () checks the location kinds chosen by slotsOfCTy.
     * Each returns true or raises Fail "failed test".
     * NOTE(review): most slots () cases are commented out.  They record the
     * intended classifications.  ty9's expected [K_MEM] cannot match, because
     * kindsOfCTy yields more than one K_MEM for a struct larger than 16 bytes.
     *)
    structure Test = struct
      val ty1 = CTy.C_STRUCT [CTy.C_STRUCT [CTy.C_unsigned CTy.I_char, CTy.C_unsigned CTy.I_int]]
      val ty2 = CTy.C_STRUCT [CTy.C_signed CTy.I_short]
      val ty3 = CTy.C_STRUCT [CTy.C_signed CTy.I_short, CTy.C_PTR]
      val ty4 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_unsigned CTy.I_int], CTy.C_PTR]
      val ty5 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_float]]
      val ty6 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_float,CTy.C_float,CTy.C_float,CTy.C_float]]
      val ty7 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_STRUCT[CTy.C_float,CTy.C_float],CTy.C_float,CTy.C_float]]
      val ty8 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_STRUCT[CTy.C_float,CTy.C_unsigned CTy.I_int],CTy.C_float,CTy.C_float]]
      val ty9 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_float,CTy.C_float,CTy.C_float,CTy.C_float,CTy.C_float]]
      val ty10 = CTy.C_STRUCT [CTy.C_STRUCT[CTy.C_float,CTy.C_float, CTy.C_STRUCT[CTy.C_float,CTy.C_unsigned CTy.I_int]]]
      val ty11 = CTy.C_STRUCT [CTy.C_PTR, CTy.C_float, CTy.C_float, CTy.C_float]
    (* two pointers: each fills one eightbyte on its own *)
      val ty12 = CTy.C_STRUCT [CTy.C_PTR, CTy.C_PTR]

      fun kindOfEB () = let
	  fun test (eb, k) = (kindOfEightByte eb = k) orelse raise Fail "failed test"
	  fun eb1 ty = hd (eightBytesOfCTy ty)
	  fun eb2 ty = hd(tl (eightBytesOfCTy ty))
          in
	      List.all test [(eb1 ty1, K_GPR), (eb1 ty2, K_GPR), (eb2 ty3, K_GPR),
			     (eb1 ty5, K_FPR), (eb1 ty6, K_FPR), (eb2 ty6, K_FPR),
			     (eb1 ty7, K_FPR), (eb2 ty7, K_FPR),
			     (eb1 ty8, K_GPR), (eb2 ty8, K_FPR),
			     (eb1 ty12, K_GPR), (eb2 ty12, K_GPR)]
	  end

    (* kind component of a staged-allocation slot *)
      fun li2k (_, k, _) = k

      fun slots () = let
	  fun test (lis : SA.slot list, ks2 : location_kind list) = let
	      val ks1 = List.map li2k lis
              in
	         (List.length ks1 = List.length ks2) andalso (ListPair.all (op =) (ks1, ks2))
	      end
	  val tests = [
(*	               (ty2, [K_GPR]), (ty1, [K_GPR]), (ty3, [K_GPR, K_GPR]), (ty4, [K_GPR, K_GPR]), 
		       (ty5, [K_FPR]), (ty6, [K_FPR, K_FPR]),
		       (ty7, [K_FPR, K_FPR]), (ty8, [K_GPR, K_FPR]),
		       (ty9, [K_MEM]), (ty10, [K_FPR, K_GPR]),
*)
		       (ty11, [K_MEM, K_MEM, K_MEM])
				       ]
	  val (ts, anss) = ListPair.unzip tests
          in
	     ListPair.all test (List.map slotsOfCTy ts, anss) orelse raise Fail "failed test"
          end
    end

  end (* AMD64SVIDFn *)

(* ---- ViewVC page footer (web-page residue, preserved as a comment) ----
 * root@smlnj-gforge.cs.uchicago.edu
 * ViewVC Help
 * Powered by ViewVC 1.0.0
 *)