mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2024-12-12 19:54:33 +08:00
nak: Use F2FP for nir_op_pack_half_2x16_split on SM86+
On Ampere and later, this instruction handles the packing of F32x2 to F16x2. Signed-off-by: Mary Guillemard <mary.guillemard@collabora.com> Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30765>
This commit is contained in:
parent
6a292c2699
commit
e19871bd6a
@ -1466,8 +1466,6 @@ impl<'a> ShaderFromNir<'a> {
|
||||
nir_op_ixor => b.lop2(LogicOp2::Xor, srcs[0], srcs[1]),
|
||||
nir_op_pack_half_2x16_split | nir_op_pack_half_2x16_rtz_split => {
|
||||
assert!(alu.get_src(0).bit_size() == 32);
|
||||
let low = b.alloc_ssa(RegFile::GPR, 1);
|
||||
let high = b.alloc_ssa(RegFile::GPR, 1);
|
||||
|
||||
let rnd_mode = match alu.op {
|
||||
nir_op_pack_half_2x16_split => FRndMode::NearestEven,
|
||||
@ -1475,32 +1473,46 @@ impl<'a> ShaderFromNir<'a> {
|
||||
_ => panic!("Unhandled fp16 pack op"),
|
||||
};
|
||||
|
||||
b.push_op(OpF2F {
|
||||
dst: low.into(),
|
||||
src: srcs[0],
|
||||
src_type: FloatType::F32,
|
||||
dst_type: FloatType::F16,
|
||||
rnd_mode: rnd_mode,
|
||||
ftz: false,
|
||||
high: false,
|
||||
integer_rnd: false,
|
||||
});
|
||||
if self.sm.sm() >= 86 {
|
||||
let result: SSARef = b.alloc_ssa(RegFile::GPR, 1);
|
||||
b.push_op(OpF2FP {
|
||||
dst: result.into(),
|
||||
srcs: [srcs[1], srcs[0]],
|
||||
rnd_mode: rnd_mode,
|
||||
});
|
||||
|
||||
let src_bits = usize::from(alu.get_src(1).bit_size());
|
||||
let src_type = FloatType::from_bits(src_bits);
|
||||
assert!(matches!(src_type, FloatType::F32));
|
||||
b.push_op(OpF2F {
|
||||
dst: high.into(),
|
||||
src: srcs[1],
|
||||
src_type: FloatType::F32,
|
||||
dst_type: FloatType::F16,
|
||||
rnd_mode: rnd_mode,
|
||||
ftz: false,
|
||||
high: false,
|
||||
integer_rnd: false,
|
||||
});
|
||||
result
|
||||
} else {
|
||||
let low = b.alloc_ssa(RegFile::GPR, 1);
|
||||
let high = b.alloc_ssa(RegFile::GPR, 1);
|
||||
|
||||
b.prmt(low.into(), high.into(), [0, 1, 4, 5])
|
||||
b.push_op(OpF2F {
|
||||
dst: low.into(),
|
||||
src: srcs[0],
|
||||
src_type: FloatType::F32,
|
||||
dst_type: FloatType::F16,
|
||||
rnd_mode: rnd_mode,
|
||||
ftz: false,
|
||||
high: false,
|
||||
integer_rnd: false,
|
||||
});
|
||||
|
||||
let src_bits = usize::from(alu.get_src(1).bit_size());
|
||||
let src_type = FloatType::from_bits(src_bits);
|
||||
assert!(matches!(src_type, FloatType::F32));
|
||||
b.push_op(OpF2F {
|
||||
dst: high.into(),
|
||||
src: srcs[1],
|
||||
src_type: FloatType::F32,
|
||||
dst_type: FloatType::F16,
|
||||
rnd_mode: rnd_mode,
|
||||
ftz: false,
|
||||
high: false,
|
||||
integer_rnd: false,
|
||||
});
|
||||
|
||||
b.prmt(low.into(), high.into(), [0, 1, 4, 5])
|
||||
}
|
||||
}
|
||||
nir_op_prmt_nv => {
|
||||
let dst = b.alloc_ssa(RegFile::GPR, 1);
|
||||
|
@ -1135,3 +1135,60 @@ fn test_shr64() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks that `OpF2FP` converts two F32 values to F16 and packs them into a
/// single 32-bit word, once with a register second source using
/// round-to-nearest-even and once with an immediate second source using
/// round-toward-zero.
#[test]
fn test_f2fp_pack_ab() {
    let run = RunSingleton::get();
    let mut b = TestShaderBuilder::new(run.sm.as_ref());

    // Offsets 0 and 4 of each data row hold the two F32 inputs.
    let srcs = SSARef::from([
        b.ld_test_data(0, MemType::B32)[0],
        b.ld_test_data(4, MemType::B32)[0],
    ]);

    // Offset 8: both loaded inputs packed with round-to-nearest-even.
    let dst = b.alloc_ssa(RegFile::GPR, 1);
    b.push_op(OpF2FP {
        dst: dst.into(),
        srcs: [srcs[0].into(), srcs[1].into()],
        rnd_mode: FRndMode::NearestEven,
    });
    b.st_test_data(8, MemType::B32, dst[0].into());

    // Offset 12: the first input packed with the immediate 2.0, rounding
    // toward zero. The second loaded input is ignored by this op.
    let dst = b.alloc_ssa(RegFile::GPR, 1);
    b.push_op(OpF2FP {
        dst: dst.into(),
        srcs: [srcs[0].into(), 2.0.into()],
        rnd_mode: FRndMode::Zero,
    });
    b.st_test_data(12, MemType::B32, dst[0].into());

    let bin = b.compile();

    let zero = 0.0_f32.to_bits();
    let one = 1.0_f32.to_bits();
    let two = 2.0_f32.to_bits();
    // Not exactly representable in F16, so the two rounding modes disagree.
    let complex = 1.4556_f32.to_bits();

    let mut data = vec![
        [one, two, 0, 0],
        [one, zero, 0, 0],
        [complex, zero, 0, 0],
    ];
    run.run.run(&bin, &mut data).unwrap();

    // Expected words are written as { upper half, lower half }.

    // { 1.0fp16, 2.0fp16 }
    assert_eq!(data[0][2], 0x3c00_4000);
    // { 1.0fp16, 2.0fp16 } (lower half is the immediate 2.0)
    assert_eq!(data[0][3], 0x3c00_4000);
    // { 1.0fp16, 0.0fp16 }
    assert_eq!(data[1][2], 0x3c00_0000);
    // { 1.0fp16, 2.0fp16 } (lower half is the immediate 2.0)
    assert_eq!(data[1][3], 0x3c00_4000);
    // { 1.456fp16, 0.0fp16 } (rounded to nearest even)
    assert_eq!(data[2][2], 0x3dd3_0000);
    // { 1.455fp16, 2.0fp16 } (rounded toward zero, immediate 2.0)
    assert_eq!(data[2][3], 0x3dd2_4000);
}
|
||||
|
@ -4058,6 +4058,29 @@ impl DisplayOp for OpF2F {
|
||||
}
|
||||
impl_display_for_op!(OpF2F);
|
||||
|
||||
/// Packed float-to-float conversion, printed as `f2fp.pack_ab`.
///
/// Takes two ALU sources and writes a single GPR destination. The
/// `test_f2fp_pack_ab` expectations show F32 inputs being converted to F16
/// with `srcs[0]` landing in the upper 16 bits of the result and `srcs[1]`
/// in the lower 16 bits.
#[repr(C)]
|
||||
#[derive(DstsAsSlice, SrcsAsSlice)]
|
||||
pub struct OpF2FP {
|
||||
/// Destination register receiving the packed result.
#[dst_type(GPR)]
|
||||
pub dst: Dst,
|
||||
|
||||
/// The two values to convert and pack; order matters (see above).
#[src_type(ALU)]
|
||||
pub srcs: [Src; 2],
|
||||
|
||||
/// Rounding mode applied to the conversion.
pub rnd_mode: FRndMode,
|
||||
}
|
||||
|
||||
impl DisplayOp for OpF2FP {
|
||||
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "f2fp.pack_ab")?;
|
||||
if self.rnd_mode != FRndMode::NearestEven {
|
||||
write!(f, "{}", self.rnd_mode)?;
|
||||
}
|
||||
write!(f, " {}, {}", self.srcs[0], self.srcs[1],)
|
||||
}
|
||||
}
|
||||
impl_display_for_op!(OpF2FP);
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(DstsAsSlice)]
|
||||
pub struct OpF2I {
|
||||
@ -6159,6 +6182,7 @@ pub enum Op {
|
||||
Shl(OpShl),
|
||||
Shr(OpShr),
|
||||
F2F(OpF2F),
|
||||
F2FP(OpF2FP),
|
||||
F2I(OpF2I),
|
||||
I2F(OpI2F),
|
||||
I2I(OpI2I),
|
||||
@ -6606,7 +6630,8 @@ impl Instr {
|
||||
pub fn has_fixed_latency(&self, sm: u8) -> bool {
|
||||
match &self.op {
|
||||
// Float ALU
|
||||
Op::FAdd(_)
|
||||
Op::F2FP(_)
|
||||
| Op::FAdd(_)
|
||||
| Op::FFma(_)
|
||||
| Op::FMnMx(_)
|
||||
| Op::FMul(_)
|
||||
|
@ -1921,6 +1921,44 @@ impl SM70Op for OpF2F {
|
||||
}
|
||||
}
|
||||
|
||||
impl SM70Op for OpF2FP {
|
||||
fn legalize(&mut self, b: &mut LegalizeBuilder) {
|
||||
let gpr = op_gpr(self);
|
||||
let [src0, src1] = &mut self.srcs;
|
||||
swap_srcs_if_not_reg(src0, src1, gpr);
|
||||
|
||||
b.copy_alu_src_if_not_reg(src0, gpr, SrcType::ALU);
|
||||
}
|
||||
|
||||
fn encode(&self, e: &mut SM70Encoder<'_>) {
|
||||
if src_is_zero_or_gpr(&self.srcs[1]) {
|
||||
e.encode_alu(
|
||||
0x03e,
|
||||
Some(&self.dst),
|
||||
Some(&self.srcs[0]),
|
||||
Some(&self.srcs[1]),
|
||||
Some(&Src::new_zero()),
|
||||
)
|
||||
} else {
|
||||
e.encode_alu(
|
||||
0x03e,
|
||||
Some(&self.dst),
|
||||
None,
|
||||
Some(&self.srcs[1]),
|
||||
Some(&self.srcs[0]),
|
||||
)
|
||||
};
|
||||
|
||||
// .MERGE_C behavior
|
||||
// Use src1 and src2, src0 is unused
|
||||
// src1 get converted and packed in the lower 16 bits of dest.
|
||||
// src2 lower or high 16 bits (decided by .H1 flag) get packed in the upper of dest.
|
||||
e.set_bit(78, false); // TODO: .MERGE_C
|
||||
e.set_bit(72, false); // .H1 (MERGE_C only)
|
||||
e.set_rnd_mode(79..81, self.rnd_mode);
|
||||
}
|
||||
}
|
||||
|
||||
impl SM70Op for OpF2I {
|
||||
fn legalize(&mut self, _b: &mut LegalizeBuilder) {
|
||||
// Nothing to do
|
||||
@ -3397,6 +3435,7 @@ macro_rules! as_sm70_op_match {
|
||||
Op::PopC(op) => op,
|
||||
Op::Shf(op) => op,
|
||||
Op::F2F(op) => op,
|
||||
Op::F2FP(op) => op,
|
||||
Op::F2I(op) => op,
|
||||
Op::I2F(op) => op,
|
||||
Op::FRnd(op) => op,
|
||||
|
Loading…
Reference in New Issue
Block a user