nak: Use F2FP for nir_op_pack_half_2x16_split on SM86+

On Ampere and later, this instruction allows to handle packing of F32x2
to F16x2.

Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30765>
This commit is contained in:
Mary Guillemard 2024-08-21 16:14:20 +02:00 committed by Marge Bot
parent 6a292c2699
commit e19871bd6a
4 changed files with 160 additions and 27 deletions

View File

@ -1466,8 +1466,6 @@ impl<'a> ShaderFromNir<'a> {
nir_op_ixor => b.lop2(LogicOp2::Xor, srcs[0], srcs[1]),
nir_op_pack_half_2x16_split | nir_op_pack_half_2x16_rtz_split => {
assert!(alu.get_src(0).bit_size() == 32);
let low = b.alloc_ssa(RegFile::GPR, 1);
let high = b.alloc_ssa(RegFile::GPR, 1);
let rnd_mode = match alu.op {
nir_op_pack_half_2x16_split => FRndMode::NearestEven,
@ -1475,32 +1473,46 @@ impl<'a> ShaderFromNir<'a> {
_ => panic!("Unhandled fp16 pack op"),
};
b.push_op(OpF2F {
dst: low.into(),
src: srcs[0],
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: rnd_mode,
ftz: false,
high: false,
integer_rnd: false,
});
if self.sm.sm() >= 86 {
let result: SSARef = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpF2FP {
dst: result.into(),
srcs: [srcs[1], srcs[0]],
rnd_mode: rnd_mode,
});
let src_bits = usize::from(alu.get_src(1).bit_size());
let src_type = FloatType::from_bits(src_bits);
assert!(matches!(src_type, FloatType::F32));
b.push_op(OpF2F {
dst: high.into(),
src: srcs[1],
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: rnd_mode,
ftz: false,
high: false,
integer_rnd: false,
});
result
} else {
let low = b.alloc_ssa(RegFile::GPR, 1);
let high = b.alloc_ssa(RegFile::GPR, 1);
b.prmt(low.into(), high.into(), [0, 1, 4, 5])
b.push_op(OpF2F {
dst: low.into(),
src: srcs[0],
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: rnd_mode,
ftz: false,
high: false,
integer_rnd: false,
});
let src_bits = usize::from(alu.get_src(1).bit_size());
let src_type = FloatType::from_bits(src_bits);
assert!(matches!(src_type, FloatType::F32));
b.push_op(OpF2F {
dst: high.into(),
src: srcs[1],
src_type: FloatType::F32,
dst_type: FloatType::F16,
rnd_mode: rnd_mode,
ftz: false,
high: false,
integer_rnd: false,
});
b.prmt(low.into(), high.into(), [0, 1, 4, 5])
}
}
nir_op_prmt_nv => {
let dst = b.alloc_ssa(RegFile::GPR, 1);

View File

@ -1135,3 +1135,60 @@ fn test_shr64() {
}
}
}
#[test]
fn test_f2fp_pack_ab() {
let run = RunSingleton::get();
let mut b = TestShaderBuilder::new(run.sm.as_ref());
let srcs = SSARef::from([
b.ld_test_data(0, MemType::B32)[0],
b.ld_test_data(4, MemType::B32)[0],
]);
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpF2FP {
dst: dst.into(),
srcs: [srcs[0].into(), srcs[1].into()],
rnd_mode: FRndMode::NearestEven,
});
b.st_test_data(8, MemType::B32, dst[0].into());
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpF2FP {
dst: dst.into(),
srcs: [srcs[0].into(), 2.0.into()],
rnd_mode: FRndMode::Zero,
});
b.st_test_data(12, MemType::B32, dst[0].into());
let bin = b.compile();
fn f32_to_u32(val: f32) -> u32 {
u32::from_le_bytes(val.to_le_bytes())
}
let zero = f32_to_u32(0.0);
let one = f32_to_u32(1.0);
let two = f32_to_u32(2.0);
let complex = f32_to_u32(1.4556);
let mut data = Vec::new();
data.push([one, two, 0, 0]);
data.push([one, zero, 0, 0]);
data.push([complex, zero, 0, 0]);
run.run.run(&bin, &mut data).unwrap();
// { 1.0fp16, 2.0fp16 }
assert_eq!(data[0][2], 0x3c004000);
// { 1.0fp16, 2.0fp16 }
assert_eq!(data[0][3], 0x3c004000);
// { 1.0fp16, 0.0fp16 }
assert_eq!(data[1][2], 0x3c000000);
// { 1.0fp16, 0.0fp16 }
assert_eq!(data[1][3], 0x3c004000);
// { 1.456fp16, 0.0fp16 }
assert_eq!(data[2][2], 0x3dd30000);
// { 1.455fp16, 0.0fp16 }
assert_eq!(data[2][3], 0x3dd24000);
}

View File

@ -4058,6 +4058,29 @@ impl DisplayOp for OpF2F {
}
impl_display_for_op!(OpF2F);
#[repr(C)]
#[derive(DstsAsSlice, SrcsAsSlice)]
pub struct OpF2FP {
#[dst_type(GPR)]
pub dst: Dst,
#[src_type(ALU)]
pub srcs: [Src; 2],
pub rnd_mode: FRndMode,
}
impl DisplayOp for OpF2FP {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "f2fp.pack_ab")?;
if self.rnd_mode != FRndMode::NearestEven {
write!(f, "{}", self.rnd_mode)?;
}
write!(f, " {}, {}", self.srcs[0], self.srcs[1],)
}
}
impl_display_for_op!(OpF2FP);
#[repr(C)]
#[derive(DstsAsSlice)]
pub struct OpF2I {
@ -6159,6 +6182,7 @@ pub enum Op {
Shl(OpShl),
Shr(OpShr),
F2F(OpF2F),
F2FP(OpF2FP),
F2I(OpF2I),
I2F(OpI2F),
I2I(OpI2I),
@ -6606,7 +6630,8 @@ impl Instr {
pub fn has_fixed_latency(&self, sm: u8) -> bool {
match &self.op {
// Float ALU
Op::FAdd(_)
Op::F2FP(_)
| Op::FAdd(_)
| Op::FFma(_)
| Op::FMnMx(_)
| Op::FMul(_)

View File

@ -1921,6 +1921,44 @@ impl SM70Op for OpF2F {
}
}
impl SM70Op for OpF2FP {
fn legalize(&mut self, b: &mut LegalizeBuilder) {
let gpr = op_gpr(self);
let [src0, src1] = &mut self.srcs;
swap_srcs_if_not_reg(src0, src1, gpr);
b.copy_alu_src_if_not_reg(src0, gpr, SrcType::ALU);
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
if src_is_zero_or_gpr(&self.srcs[1]) {
e.encode_alu(
0x03e,
Some(&self.dst),
Some(&self.srcs[0]),
Some(&self.srcs[1]),
Some(&Src::new_zero()),
)
} else {
e.encode_alu(
0x03e,
Some(&self.dst),
None,
Some(&self.srcs[1]),
Some(&self.srcs[0]),
)
};
// .MERGE_C behavior
// Use src1 and src2, src0 is unused
// src1 get converted and packed in the lower 16 bits of dest.
// src2 lower or high 16 bits (decided by .H1 flag) get packed in the upper of dest.
e.set_bit(78, false); // TODO: .MERGE_C
e.set_bit(72, false); // .H1 (MERGE_C only)
e.set_rnd_mode(79..81, self.rnd_mode);
}
}
impl SM70Op for OpF2I {
fn legalize(&mut self, _b: &mut LegalizeBuilder) {
// Nothing to do
@ -3397,6 +3435,7 @@ macro_rules! as_sm70_op_match {
Op::PopC(op) => op,
Op::Shf(op) => op,
Op::F2F(op) => op,
Op::F2FP(op) => op,
Op::F2I(op) => op,
Op::I2F(op) => op,
Op::FRnd(op) => op,