PTW Hypervisor bug fixes: check GPA bits higher than HGATP.Mode (#3591)
* util/package: VecToAugmentedVec extract

* traverse: check GPA bits higher than the HGATP mode permits (see the sketch below)

* check upper bits of the VSATP address when starting a PTW

* remove pte_cache_addr, just use pte_addr
ingallsj authored Mar 20, 2024
1 parent 7b9d44f commit dbcb06a
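The common thread in these fixes is a width check on guest-physical addresses: GPA bits above what the current HGATP mode can translate must be zero, and a nonzero upper bit must raise a guest-page fault rather than be silently truncated. A minimal sketch of the arithmetic, assuming Sv39x4 parameters (pgLevels = 3, pgLevelBits = 9, 2 extra hypervisor address bits); the function name and signature are illustrative, not part of the commit:

```scala
// Sketch: an Sv39x4 guest-physical page number fits in 3*9 + 2 = 29 bits,
// so any set bit at position 29 or above makes the GPA untranslatable.
def gppnTooWide(gppn: BigInt,
                pgLevels: Int = 3,
                pgLevelBits: Int = 9,
                hypervisorExtraAddrBits: Int = 2): Boolean =
  (gppn >> (pgLevels * pgLevelBits + hypervisorExtraAddrBits)) != 0

gppnTooWide(BigInt(1) << 29) // true: GPA bit 41 set, beyond Sv39x4's 41-bit range
```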
Showing 3 changed files with 33 additions and 29 deletions.
src/main/scala/rocket/PTW.scala (22 additions, 26 deletions)
@@ -278,7 +278,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
val aux_count = Reg(UInt(log2Ceil(pgLevels).W))
/** pte for 2-stage translation */
val aux_pte = Reg(new PTE)
- val aux_ppn_hi = (pgLevels > 4 && r_req.addr.getWidth > aux_pte.ppn.getWidth).option(Reg(UInt((r_req.addr.getWidth - aux_pte.ppn.getWidth).W)))
val gpa_pgoff = Reg(UInt(pgIdxBits.W)) // only valid in resp_gf case
val stage2 = Reg(Bool())
val stage2_final = Reg(Bool())
@@ -301,7 +300,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
}
}
// construct pte from mem.resp
- val (pte, invalid_paddr) = {
+ val (pte, invalid_paddr, invalid_gpa) = {
val tmp = mem_resp_data.asTypeOf(new PTE())
val res = WireDefault(tmp)
res.ppn := Mux(do_both_stages && !stage2, tmp.ppn(vpnBits.min(tmp.ppn.getWidth)-1, 0), tmp.ppn(ppnBits-1, 0))
@@ -310,10 +309,12 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
for (i <- 0 until pgLevels-1)
when (count <= i.U && tmp.ppn((pgLevels-1-i)*pgLevelBits-1, (pgLevels-2-i)*pgLevelBits) =/= 0.U) { res.v := false.B }
}
- (res, Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0.U, (tmp.ppn >> ppnBits) =/= 0.U))
+ (res,
+ Mux(do_both_stages && !stage2, (tmp.ppn >> vpnBits) =/= 0.U, (tmp.ppn >> ppnBits) =/= 0.U),
+ do_both_stages && !stage2 && checkInvalidHypervisorGPA(r_hgatp, tmp.ppn))
}
// find non-leaf PTE, need traverse
- val traverse = pte.table() && !invalid_paddr && count < (pgLevels-1).U
+ val traverse = pte.table() && !invalid_paddr && !invalid_gpa && count < (pgLevels-1).U
/** address sent to mem for enquiry */
val pte_addr = if (!usingVM) 0.U else {
val vpn_idxs = (0 until pgLevels).map { i =>
@@ -328,19 +329,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
//use vpn slice as offset
raw_pte_addr.apply(size.min(raw_pte_addr.getWidth) - 1, 0)
}
- /** pte_cache input addr */
- val pte_cache_addr = if (!usingHypervisor) pte_addr else {
- val vpn_idxs = (0 until pgLevels-1).map { i =>
- val ext_aux_pte_ppn = aux_ppn_hi match {
- case None => aux_pte.ppn
- case Some(hi) => Cat(hi, aux_pte.ppn)
- }
- (ext_aux_pte_ppn >> (pgLevels - i - 1) * pgLevelBits)(pgLevelBits - 1, 0)
- }
- val vpn_idx = vpn_idxs(count)
- val raw_pte_cache_addr = Cat(r_pte.ppn, vpn_idx) << log2Ceil(xLen/8)
- raw_pte_cache_addr(vaddrBits.min(raw_pte_cache_addr.getWidth)-1, 0)
- }
/** stage2_pte_cache input addr */
val stage2_pte_cache_addr = if (!usingHypervisor) 0.U else {
val vpn_idxs = (0 until pgLevels - 1).map { i =>
@@ -373,7 +361,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
else can_hit
val tag =
if (s2) Cat(true.B, stage2_pte_cache_addr.padTo(vaddrBits))
- else Cat(r_req.vstage1, pte_cache_addr.padTo(if (usingHypervisor) vaddrBits else paddrBits))
+ else Cat(r_req.vstage1, pte_addr.padTo(if (usingHypervisor) vaddrBits else paddrBits))

val hits = tags.map(_ === tag).asUInt & valid
val hit = hits.orR && can_hit
@@ -539,7 +527,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
io.mem.req.bits.data := DontCare
io.mem.req.bits.mask := DontCare

- io.mem.s1_kill := l2_hit || state =/= s_wait1
+ io.mem.s1_kill := l2_hit || (state =/= s_wait1) || resp_gf
io.mem.s1_data := DontCare
io.mem.s2_kill := false.B

@@ -607,12 +595,11 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
count := Mux(arb.io.out.bits.bits.stage2, hgatp_initial_count, satp_initial_count)
aux_count := Mux(arb.io.out.bits.bits.vstage1, vsatp_initial_count, 0.U)
aux_pte.ppn := aux_ppn
- aux_ppn_hi.foreach { _ := aux_ppn >> aux_pte.ppn.getWidth }
aux_pte.reserved_for_future := 0.U
resp_ae_ptw := false.B
resp_ae_final := false.B
resp_pf := false.B
- resp_gf := false.B
+ resp_gf := checkInvalidHypervisorGPA(io.dpath.hgatp, aux_ppn) && arb.io.out.bits.bits.stage2
resp_hr := true.B
resp_hw := true.B
resp_hx := true.B
@@ -630,7 +617,6 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
when (stage2_pte_cache_hit) {
aux_count := aux_count + 1.U
aux_pte.ppn := stage2_pte_cache_data
- aux_ppn_hi.foreach { _ := 0.U }
aux_pte.reserved_for_future := 0.U
pte_hit := true.B
}.elsewhen (pte_cache_hit) {
@@ -639,6 +625,10 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
}.otherwise {
next_state := Mux(io.mem.req.ready, s_wait1, s_req)
}
+ when(resp_gf) {
+ next_state := s_ready
+ resp_valid(r_req_dest) := true.B
+ }
}
is (s_wait1) {
// This Mux is for the l2_error case; the l2_hit && !l2_error case is overridden below
@@ -676,7 +666,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(

r_pte := OptimizationBarrier(
// l2tlb hit->find a leaf PTE(l2_pte), respond to L1TLB
- Mux(l2_hit && !l2_error, l2_pte,
+ Mux(l2_hit && !l2_error && !resp_gf, l2_pte,
// S2 PTE cache hit -> proceed to the next level of walking, update the r_pte with hgatp
Mux(state === s_req && stage2_pte_cache_hit, makeHypervisorRootPTE(r_hgatp, stage2_pte_cache_data, l2_pte),
// pte cache hit->find a non-leaf PTE(pte_cache),continue to request mem
@@ -691,7 +681,7 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
Mux(arb.io.out.fire, Mux(arb.io.out.bits.bits.stage2, makeHypervisorRootPTE(io.dpath.hgatp, io.dpath.vsatp.ppn, r_pte), makePTE(satp.ppn, r_pte)),
r_pte))))))))

- when (l2_hit && !l2_error) {
+ when (l2_hit && !l2_error && !resp_gf) {
assert(state === s_req || state === s_wait1)
next_state := s_ready
resp_valid(r_req_dest) := true.B
@@ -753,14 +743,14 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
val s1_ppns = (0 until pgLevels-1).map(i => Cat(pte.ppn(pte.ppn.getWidth-1, (pgLevels-i-1)*pgLevelBits), r_req.addr(((pgLevels-i-1)*pgLevelBits min vpnBits)-1,0).padTo((pgLevels-i-1)*pgLevelBits))) :+ pte.ppn
makePTE(s1_ppns(count), pte)
})
- aux_ppn_hi.foreach { _ := 0.U }
stage2 := true.B
}

for (i <- 0 until pgLevels) {
val leaf = mem_resp_valid && !traverse && count === i.U
- ccover(leaf && pte.v && !invalid_paddr && pte.reserved_for_future === 0.U, s"L$i", s"successful page-table access, level $i")
+ ccover(leaf && pte.v && !invalid_paddr && !invalid_gpa && pte.reserved_for_future === 0.U, s"L$i", s"successful page-table access, level $i")
ccover(leaf && pte.v && invalid_paddr, s"L${i}_BAD_PPN_MSB", s"PPN too large, level $i")
+ ccover(leaf && pte.v && invalid_gpa, s"L${i}_BAD_GPA_MSB", s"GPA too large, level $i")
ccover(leaf && pte.v && pte.reserved_for_future =/= 0.U, s"L${i}_BAD_RSV_MSB", s"reserved MSBs set, level $i")
ccover(leaf && !mem_resp_data(0), s"L${i}_INVALID_PTE", s"page not present, level $i")
if (i != pgLevels-1)
@@ -790,6 +780,12 @@ class PTW(n: Int)(implicit edge: TLEdgeOut, p: Parameters) extends CoreModule()(
pte.ppn := Cat(hgatp.ppn >> maxHypervisorExtraAddrBits, lsbs)
pte
}
+ /** use hgatp and vpn to check for gpa out of range */
+ private def checkInvalidHypervisorGPA(hgatp: PTBR, vpn: UInt) = {
+ val count = pgLevels.U - minPgLevels.U - hgatp.additionalPgLevels
+ val idxs = (0 to pgLevels-minPgLevels).map(i => (vpn >> ((pgLevels-i)*pgLevelBits)+maxHypervisorExtraAddrBits))
+ idxs.extract(count) =/= 0.U
+ }
}

/** Mix-ins for constructing tiles that might have a PTW */
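The new checkInvalidHypervisorGPA helper (end of the PTW.scala diff above) picks which upper slice of the guest VPN to test based on how many page levels HGATP actually enables. A hedged walk-through of the index arithmetic, assuming a pgLevels = 4 (Sv48-capable) core with minPgLevels = 3, pgLevelBits = 9, and maxHypervisorExtraAddrBits = 2:

```scala
// Shift amounts the helper builds, one per possible HGATP mode
// (values illustrative for pgLevels = 4, minPgLevels = 3).
val shifts = (0 to 4 - 3).map(i => (4 - i) * 9 + 2) // Vector(38, 29)

// count = pgLevels - minPgLevels - hgatp.additionalPgLevels selects the slice:
//   Sv48x4 (additionalPgLevels = 1) -> count = 0 -> test gppn >> 38
//   Sv39x4 (additionalPgLevels = 0) -> count = 1 -> test gppn >> 29
// A nonzero remainder means the GPA is wider than the mode can translate,
// which the walker reports as a guest-page fault (resp_gf / invalid_gpa).
```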
src/main/scala/rocket/TLB.scala (3 additions, 3 deletions)
@@ -589,9 +589,9 @@ class TLB(instruction: Boolean, lgMaxSize: Int, cfg: TLBConfig)(implicit edge: T
val pf_ld_array = Mux(cmd_read, ((~Mux(cmd_readx, x_array, r_array) & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array, 0.U)
val pf_st_array = Mux(cmd_write_perms, ((~w_array & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array, 0.U)
val pf_inst_array = ((~x_array & ~ptw_ae_array) | ptw_pf_array) & ~ptw_gf_array
- val gf_ld_array = Mux(priv_v && cmd_read, ~Mux(cmd_readx, hx_array, hr_array) & ~ptw_ae_array, 0.U)
- val gf_st_array = Mux(priv_v && cmd_write_perms, ~hw_array & ~ptw_ae_array, 0.U)
- val gf_inst_array = Mux(priv_v, ~hx_array & ~ptw_ae_array, 0.U)
+ val gf_ld_array = Mux(priv_v && cmd_read, (~Mux(cmd_readx, hx_array, hr_array) | ptw_gf_array) & ~ptw_ae_array, 0.U)
+ val gf_st_array = Mux(priv_v && cmd_write_perms, (~hw_array | ptw_gf_array) & ~ptw_ae_array, 0.U)
+ val gf_inst_array = Mux(priv_v, (~hx_array | ptw_gf_array) & ~ptw_ae_array, 0.U)

val gpa_hits = {
val need_gpa_mask = if (instruction) gf_inst_array else gf_ld_array | gf_st_array
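Folding ptw_gf_array into the gf_* terms means a guest-page fault recorded during the walk (such as the new GPA width check) is reported as a guest fault even when the hypervisor permission bits would otherwise allow the access, while access exceptions still take priority. A plain boolean restatement of the load case, mirroring gf_ld_array above; all names are illustrative, not the TLB's actual signals:

```scala
// Sketch of the guest-load-fault condition after this change.
def guestLoadFault(privV: Boolean, cmdRead: Boolean, cmdReadx: Boolean,
                   hr: Boolean, hx: Boolean,
                   ptwGf: Boolean, ptwAe: Boolean): Boolean = {
  val permDenied = !(if (cmdReadx) hx else hr) // hypervisor R (or X-as-R) permission
  privV && cmdRead && (permDenied || ptwGf) && !ptwAe
}
```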
src/main/scala/util/package.scala (8 additions, 0 deletions)
@@ -18,6 +18,12 @@ package object util {
def isOneOf(u1: UInt, u2: UInt*): Bool = isOneOf(u1 +: u2.toSeq)
}

+ implicit class VecToAugmentedVec[T <: Data](private val x: Vec[T]) extends AnyVal {
+
+ /** Like Vec.apply(idx), but tolerates indices of mismatched width */
+ def extract(idx: UInt): T = x((idx | 0.U(log2Ceil(x.size).W)).extract(log2Ceil(x.size) - 1, 0))
+ }
+
implicit class SeqToAugmentedSeq[T <: Data](private val x: Seq[T]) extends AnyVal {
def apply(idx: UInt): T = {
if (x.size <= 1) {
@@ -34,6 +40,8 @@
}
}

+ def extract(idx: UInt): T = VecInit(x).extract(idx)
+
def asUInt: UInt = Cat(x.map(_.asUInt).reverse)

def rotate(n: Int): Seq[T] = x.drop(n) ++ x.take(n)
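A minimal usage sketch of the new Vec extract helper (the Seq overload simply wraps its input in a Vec and delegates); the module and signal names here are hypothetical:

```scala
import chisel3._
import freechips.rocketchip.util._ // brings VecToAugmentedVec into scope

class ExtractDemo extends Module {
  val io = IO(new Bundle {
    val idx = Input(UInt(8.W)) // wider than the 2 bits a 4-entry Vec needs
    val out = Output(UInt(4.W))
  })
  val table = VecInit(Seq(1, 2, 4, 8).map(_.U(4.W)))
  // Per the helper's doc comment, unlike Vec.apply it tolerates an index of
  // mismatched width: it pads the index if needed, then selects with io.idx(1, 0).
  io.out := table.extract(io.idx)
}
```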
