Skip to content

Commit

Permalink
Rework predicate/vector slot spilling
Browse files Browse the repository at this point in the history
It should be a bit clearer now, and capable of spilling more than one
predicate/vector register when we get round to needing that
functionality.

Change-Id: Ibd3a372c0b08aa2d4cff94308367e1962da51c31
  • Loading branch information
jackgallagher-arm committed Oct 18, 2023
1 parent 456416a commit c3161d7
Showing 1 changed file with 93 additions and 55 deletions.
148 changes: 93 additions & 55 deletions ext/drx/scatter_gather_aarch64.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,21 @@
/* Control printing of verbose debugging messages. */
#define VERBOSE 0

#define SVE_MAX_VECTOR_LENGTH_BITS 2048
#define SVE_MAX_VECTOR_LENGTH_BYTES (SVE_MAX_VECTOR_LENGTH_BITS / 8)
#define SVE_VECTOR_ALIGNMENT_BYTES 16
#define SVE_VECTOR_SPILL_SLOT_SIZE \
SVE_MAX_VECTOR_LENGTH_BYTES + (SVE_VECTOR_ALIGNMENT_BYTES - 1)

#define SVE_MAX_PREDICATE_LENGTH_BITS (SVE_MAX_VECTOR_LENGTH_BITS / 8)
#define SVE_MAX_PREDICATE_LENGTH_BYTES (SVE_MAX_PREDICATE_LENGTH_BITS / 8)
#define SVE_PREDICATE_ALIGNMENT_BYTES 2
#define SVE_PREDICATE_SPILL_SLOT_SIZE SVE_MAX_PREDICATE_LENGTH_BYTES

typedef struct _per_thread_t {
/* TODO i#3844: drreg does not support spilling predicate/vector regs yet,
* so we do it ourselves.
*/
void *scratch_pred_spill_slot; /* Storage for spilled predicate register. */
void *scratch_vector_spill_slot; /* Storage for spilled vector register. */
void *scratch_vector_spill_slot_aligned; /* Aligned ptr inside
scratch_vector_spill_slot to save/restore
spilled Z vector register. */

void *scratch_pred_spill_slots; /* Storage for spilled predicate registers. */
size_t scratch_pred_spill_slots_size; /* Size of scratch_pred_spill_slots in bytes. */

void *scratch_vector_spill_slots; /* Storage for spilled vector registers. */
size_t scratch_vector_spill_slots_size; /* Size of scratch_vector_spill_slots in
bytes. */

void *scratch_vector_spill_slots_aligned; /* Aligned ptr inside
scratch_vector_spill_slots to save/restore
spilled Z vector registers. */
} per_thread_t;

/* Track the state of manual spill slots for SVE registers.
Expand Down Expand Up @@ -96,6 +91,9 @@ drx_scatter_gather_thread_init(void *drcontext)
{
per_thread_t *pt = (per_thread_t *)dr_thread_alloc(drcontext, sizeof(*pt));

const uint vl_bytes = proc_get_vector_length_bytes();
const uint pl_bytes = vl_bytes / 8; /* Predicate register size */

/*
* The instructions we use to load/store the spilled predicate register require
* the base address to be aligned to 2 bytes:
Expand All @@ -104,19 +102,26 @@ drx_scatter_gather_thread_init(void *drcontext)
* and dr_thread_alloc() guarantees allocated memory is aligned to the pointer size
* (8 bytes) so we shouldn't have to do any further alignment.
*/
pt->scratch_pred_spill_slot =
dr_thread_alloc(drcontext, SVE_PREDICATE_SPILL_SLOT_SIZE);
DR_ASSERT_MSG(ALIGNED(pt->scratch_pred_spill_slot, SVE_PREDICATE_ALIGNMENT_BYTES),
"scratch_pred_spill_slot is misaligned");
static const size_t predicate_alignment_bytes = 2;
pt->scratch_pred_spill_slots_size = pl_bytes * NUM_PRED_SLOTS;

pt->scratch_pred_spill_slots =
dr_thread_alloc(drcontext, pt->scratch_pred_spill_slots_size);
DR_ASSERT_MSG(ALIGNED(pt->scratch_pred_spill_slots, predicate_alignment_bytes),
"scratch_pred_spill_slots is misaligned");

/*
* The scalable vector versions of LDR/STR require 16 byte alignment so we have to
* over-allocate and get an aligned pointer inside the allocated memory.
*/
pt->scratch_vector_spill_slot =
dr_thread_alloc(drcontext, SVE_VECTOR_SPILL_SLOT_SIZE);
pt->scratch_vector_spill_slot_aligned =
(void *)ALIGN_FORWARD(pt->scratch_vector_spill_slot, SVE_VECTOR_ALIGNMENT_BYTES);
static const size_t vector_alignment_bytes = 16;
pt->scratch_vector_spill_slots_size =
(vl_bytes * NUM_VECTOR_SLOTS) + (vector_alignment_bytes - 1);

pt->scratch_vector_spill_slots =
dr_thread_alloc(drcontext, pt->scratch_vector_spill_slots_size);
pt->scratch_vector_spill_slots_aligned =
(void *)ALIGN_FORWARD(pt->scratch_vector_spill_slots, vector_alignment_bytes);

drmgr_set_tls_field(drcontext, drx_scatter_gather_tls_idx, (void *)pt);
}
Expand All @@ -126,8 +131,10 @@ drx_scatter_gather_thread_exit(void *drcontext)
{
per_thread_t *pt =
(per_thread_t *)drmgr_get_tls_field(drcontext, drx_scatter_gather_tls_idx);
dr_thread_free(drcontext, pt->scratch_pred_spill_slot, SVE_PREDICATE_SPILL_SLOT_SIZE);
dr_thread_free(drcontext, pt->scratch_vector_spill_slot, SVE_VECTOR_SPILL_SLOT_SIZE);
dr_thread_free(drcontext, pt->scratch_pred_spill_slots,
pt->scratch_pred_spill_slots_size);
dr_thread_free(drcontext, pt->scratch_vector_spill_slots,
pt->scratch_vector_spill_slots_size);
dr_thread_free(drcontext, pt, sizeof(*pt));
}

Expand Down Expand Up @@ -633,7 +640,7 @@ expand_contiguous(void *drcontext, instrlist_t *bb, instr_t *sg_instr,
reg_id_t
reserve_sve_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, reg_id_t min_register, reg_id_t max_register,
size_t slot_offset, opnd_size_t reg_size)
size_t slot_tls_offset, opnd_size_t reg_size, uint slot_num)
{
/* Search the instruction for an unused register we will use as a temp. */
reg_id_t reg;
Expand All @@ -646,17 +653,19 @@ reserve_sve_register(void *drcontext, instrlist_t *bb, instr_t *where,
drmgr_insert_read_tls_field(drcontext, drx_scatter_gather_tls_idx, bb, where,
scratch_gpr0);

/* ldr scratch_gpr0, [scratch_gpr0, #slot_offset] */
/* ldr scratch_gpr0, [scratch_gpr0, #slot_tls_offset] */
instrlist_meta_preinsert(
bb, where,
INSTR_CREATE_ldr(drcontext, opnd_create_reg(scratch_gpr0),
OPND_CREATE_MEMPTR(scratch_gpr0, slot_offset)));
OPND_CREATE_MEMPTR(scratch_gpr0, slot_tls_offset)));

/* str reg, [scratch_gpr0] */
/* str reg, [scratch_gpr0, #slot_num, mul vl] */
instrlist_meta_preinsert(
bb, where,
INSTR_CREATE_str(drcontext,
opnd_create_base_disp(scratch_gpr0, DR_REG_NULL, 0, 0, reg_size),
opnd_create_base_disp(
scratch_gpr0, DR_REG_NULL, /*scale=*/0,
/*disp=*/slot_num * opnd_size_in_bytes(reg_size), reg_size),
opnd_create_reg(reg)));

return reg;
Expand All @@ -666,32 +675,44 @@ reg_id_t
reserve_pred_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, spill_slot_state_t *slot_state)
{
DR_ASSERT(slot_state->pred_slots[0] == DR_REG_NULL);
uint slot;
for (slot = 0; slot < NUM_PRED_SLOTS; slot++) {
if (slot_state->pred_slots[slot] == DR_REG_NULL) {
break;
}
}
DR_ASSERT(slot_state->pred_slots[slot] == DR_REG_NULL);

/* Some instructions require the predicate to be in the range p0 - p7. This includes
* LASTB which we use to extract elements from the vector register.
*/
const reg_id_t reg =
reserve_sve_register(drcontext, bb, where, scratch_gpr0, DR_REG_P0, DR_REG_P7,
offsetof(per_thread_t, scratch_pred_spill_slot),
opnd_size_from_bytes(proc_get_vector_length_bytes() / 8));
const reg_id_t reg = reserve_sve_register(
drcontext, bb, where, scratch_gpr0, DR_REG_P0, DR_REG_P7,
offsetof(per_thread_t, scratch_pred_spill_slots),
opnd_size_from_bytes(proc_get_vector_length_bytes() / 8), slot);

slot_state->pred_slots[0] = reg;
slot_state->pred_slots[slot] = reg;
return reg;
}

reg_id_t
reserve_vector_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, spill_slot_state_t *slot_state)
{
DR_ASSERT(slot_state->vector_slots[0] == DR_REG_NULL);
uint slot;
for (slot = 0; slot < NUM_VECTOR_SLOTS; slot++) {
if (slot_state->vector_slots[slot] == DR_REG_NULL) {
break;
}
}
DR_ASSERT(slot_state->vector_slots[slot] == DR_REG_NULL);

const reg_id_t reg =
reserve_sve_register(drcontext, bb, where, scratch_gpr0, DR_REG_Z0, DR_REG_Z31,
offsetof(per_thread_t, scratch_vector_spill_slot_aligned),
opnd_size_from_bytes(proc_get_vector_length_bytes()));
offsetof(per_thread_t, scratch_vector_spill_slots_aligned),
opnd_size_from_bytes(proc_get_vector_length_bytes()), slot);

slot_state->vector_slots[0] = reg;
slot_state->vector_slots[slot] = reg;
return reg;
}

Expand All @@ -702,50 +723,67 @@ reserve_vector_register(void *drcontext, instrlist_t *bb, instr_t *where,
*/
void
unreserve_sve_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, reg_id_t reg, size_t slot_offset,
opnd_size_t reg_size)
reg_id_t scratch_gpr0, reg_id_t reg, size_t slot_tls_offset,
opnd_size_t reg_size, uint slot_num)
{
drmgr_insert_read_tls_field(drcontext, drx_scatter_gather_tls_idx, bb, where,
scratch_gpr0);

/* ldr scratch_gpr0, [scratch_gpr0, #slot_offset] */
/* ldr scratch_gpr0, [scratch_gpr0, #slot_tls_offset] */
instrlist_meta_preinsert(
bb, where,
INSTR_CREATE_ldr(drcontext, opnd_create_reg(scratch_gpr0),
OPND_CREATE_MEMPTR(scratch_gpr0, slot_offset)));
OPND_CREATE_MEMPTR(scratch_gpr0, slot_tls_offset)));

/* ldr reg, [scratch_gpr0] */
/* ldr reg, [scratch_gpr0, #slot_num, mul vl] */
instrlist_meta_preinsert(
bb, where,
INSTR_CREATE_ldr(
drcontext, opnd_create_reg(reg),
opnd_create_base_disp(scratch_gpr0, DR_REG_NULL, 0, 0, reg_size)));
opnd_create_base_disp(scratch_gpr0, DR_REG_NULL, /*scale=*/0,
/*disp=*/slot_num * opnd_size_in_bytes(reg_size),
reg_size)));
}

void
unreserve_pred_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, reg_id_t scratch_pred,
spill_slot_state_t *slot_state)
{
DR_ASSERT(slot_state->pred_slots[0] == scratch_pred);
slot_state->pred_slots[0] = DR_REG_NULL;
uint slot;
for (slot = 0; slot < NUM_PRED_SLOTS; slot++) {
if (slot_state->pred_slots[slot] == scratch_pred) {
break;
}
}
DR_ASSERT(slot_state->pred_slots[slot] == scratch_pred);

unreserve_sve_register(drcontext, bb, where, scratch_gpr0, scratch_pred,
offsetof(per_thread_t, scratch_pred_spill_slot),
opnd_size_from_bytes(proc_get_vector_length_bytes() / 8));
offsetof(per_thread_t, scratch_pred_spill_slots),
opnd_size_from_bytes(proc_get_vector_length_bytes() / 8),
slot);

slot_state->pred_slots[slot] = DR_REG_NULL;
}

void
unreserve_vector_register(void *drcontext, instrlist_t *bb, instr_t *where,
reg_id_t scratch_gpr0, reg_id_t scratch_vec,
spill_slot_state_t *slot_state)
{
DR_ASSERT(slot_state->vector_slots[0] == scratch_vec);
slot_state->vector_slots[0] = DR_REG_NULL;
uint slot;
for (slot = 0; slot < NUM_VECTOR_SLOTS; slot++) {
if (slot_state->vector_slots[slot] == scratch_vec) {
break;
}
}
DR_ASSERT(slot_state->vector_slots[slot] == scratch_vec);

unreserve_sve_register(drcontext, bb, where, scratch_gpr0, scratch_vec,
offsetof(per_thread_t, scratch_vector_spill_slot_aligned),
opnd_size_from_bytes(proc_get_vector_length_bytes()));
offsetof(per_thread_t, scratch_vector_spill_slots_aligned),
opnd_size_from_bytes(proc_get_vector_length_bytes()), slot);

slot_state->vector_slots[slot] = DR_REG_NULL;
}

/*****************************************************************************************
Expand Down

0 comments on commit c3161d7

Please sign in to comment.