microsoft/compiler: Implement wave reduce/exclusive scan ops that are supported

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21996>
This commit is contained in:
Jesse Natalie 2023-03-17 14:52:11 -07:00 committed by Marge Bot
parent 082368cd84
commit 981fe2bf42
3 changed files with 115 additions and 0 deletions

View File

@ -354,6 +354,19 @@ enum dxil_quad_op_kind {
QUAD_READ_ACROSS_DIAGONAL = 2,
};
enum dxil_wave_op_kind {
DXIL_WAVE_OP_SUM = 0,
DXIL_WAVE_OP_PRODUCT = 1,
DXIL_WAVE_OP_MIN = 2,
DXIL_WAVE_OP_MAX = 3,
};
enum dxil_wave_bit_op_kind {
DXIL_WAVE_BIT_OP_AND = 0,
DXIL_WAVE_BIT_OP_OR = 1,
DXIL_WAVE_BIT_OP_XOR = 2,
};
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -104,6 +104,9 @@ static struct predefined_func_descr predefined_funcs[] = {
{"dx.op.waveAllTrue", "b", "ib", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveActiveAllEqual", "b", "iO", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveActiveBallot", "F", "ib", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveActiveOp", "O", "iOcc", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.waveActiveBit", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.wavePrefixOp", "O", "iOcc", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
{"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
};

View File

@ -339,6 +339,9 @@ enum dxil_intr {
DXIL_INTR_WAVE_ACTIVE_BALLOT = 116,
DXIL_INTR_WAVE_READ_LANE_AT = 117,
DXIL_INTR_WAVE_READ_LANE_FIRST = 118,
DXIL_INTR_WAVE_ACTIVE_OP = 119,
DXIL_INTR_WAVE_ACTIVE_BIT = 120,
DXIL_INTR_WAVE_PREFIX_OP = 121,
DXIL_INTR_QUAD_READ_LANE_AT = 122,
DXIL_INTR_QUAD_OP = 123,
@ -4623,6 +4626,98 @@ emit_quad_op(struct ntd_context *ctx, nir_intrinsic_instr *intr, enum dxil_quad_
return true;
}
static enum dxil_wave_bit_op_kind
get_reduce_bit_op(nir_op op)
{
switch (op) {
case nir_op_ior: return DXIL_WAVE_BIT_OP_OR;
case nir_op_ixor: return DXIL_WAVE_BIT_OP_XOR;
case nir_op_iand: return DXIL_WAVE_BIT_OP_AND;
default:
unreachable("Invalid bit op");
}
}
static bool
emit_reduce_bitwise(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
enum dxil_wave_bit_op_kind wave_bit_op = get_reduce_bit_op(nir_intrinsic_reduction_op(intr));
const struct dxil_func *func = dxil_get_function(&ctx->mod, "dx.op.waveActiveBit",
get_overload(nir_type_uint, intr->dest.ssa.bit_size));
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, DXIL_INTR_WAVE_ACTIVE_BIT),
get_src(ctx, intr->src, 0, nir_type_uint),
dxil_module_get_int8_const(&ctx->mod, wave_bit_op),
};
if (!func || !args[0] || !args[1] || !args[2])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
store_dest(ctx, &intr->dest, 0, ret, nir_type_uint);
return true;
}
static enum dxil_wave_op_kind
get_reduce_op(nir_op op)
{
switch (op) {
case nir_op_iadd:
case nir_op_fadd:
return DXIL_WAVE_OP_SUM;
case nir_op_imul:
case nir_op_fmul:
return DXIL_WAVE_OP_PRODUCT;
case nir_op_imax:
case nir_op_umax:
case nir_op_fmax:
return DXIL_WAVE_OP_MAX;
case nir_op_imin:
case nir_op_umin:
case nir_op_fmin:
return DXIL_WAVE_OP_MIN;
default:
unreachable("Unexpected reduction op");
}
}
static bool
emit_reduce(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
ctx->mod.feats.wave_ops = 1;
bool is_prefix = intr->intrinsic == nir_intrinsic_exclusive_scan;
nir_op reduction_op = (nir_op)nir_intrinsic_reduction_op(intr);
switch (reduction_op) {
case nir_op_ior:
case nir_op_ixor:
case nir_op_iand:
assert(!is_prefix);
return emit_reduce_bitwise(ctx, intr);
default:
break;
}
nir_alu_type alu_type = nir_op_infos[reduction_op].input_types[0];
enum dxil_wave_op_kind wave_op = get_reduce_op(reduction_op);
const struct dxil_func *func = dxil_get_function(&ctx->mod, is_prefix ? "dx.op.wavePrefixOp" : "dx.op.waveActiveOp",
get_overload(alu_type, intr->dest.ssa.bit_size));
bool is_unsigned = alu_type == nir_type_uint;
const struct dxil_value *args[] = {
dxil_module_get_int32_const(&ctx->mod, is_prefix ? DXIL_INTR_WAVE_PREFIX_OP : DXIL_INTR_WAVE_ACTIVE_OP),
get_src(ctx, intr->src, 0, alu_type),
dxil_module_get_int8_const(&ctx->mod, wave_op),
dxil_module_get_int8_const(&ctx->mod, is_unsigned),
};
if (!func || !args[0] || !args[1] || !args[2] || !args[3])
return false;
const struct dxil_value *ret = dxil_emit_call(&ctx->mod, func, args, ARRAY_SIZE(args));
if (!ret)
return false;
store_dest(ctx, &intr->dest, 0, ret, alu_type);
return true;
}
static bool
emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
{
@ -4854,6 +4949,10 @@ emit_intrinsic(struct ntd_context *ctx, nir_intrinsic_instr *intr)
case nir_intrinsic_quad_swap_diagonal:
return emit_quad_op(ctx, intr, QUAD_READ_ACROSS_DIAGONAL);
case nir_intrinsic_reduce:
case nir_intrinsic_exclusive_scan:
return emit_reduce(ctx, intr);
case nir_intrinsic_load_num_workgroups:
case nir_intrinsic_load_workgroup_size:
default: