diff --git a/src/asahi/compiler/agx_compile.h b/src/asahi/compiler/agx_compile.h
index 5702006525b..0ee8f43b7fe 100644
--- a/src/asahi/compiler/agx_compile.h
+++ b/src/asahi/compiler/agx_compile.h
@@ -206,7 +206,6 @@ static const nir_shader_compiler_options agx_nir_options = {
    .lower_pack_half_2x16 = true,
    .lower_unpack_half_2x16 = true,
    .lower_extract_byte = true,
-   .lower_extract_word = true,
    .lower_insert_byte = true,
    .lower_insert_word = true,
    .lower_cs_local_index_to_id = true,
diff --git a/src/asahi/compiler/agx_nir_algebraic.py b/src/asahi/compiler/agx_nir_algebraic.py
index cf60f136b9c..deef09d94fb 100644
--- a/src/asahi/compiler/agx_nir_algebraic.py
+++ b/src/asahi/compiler/agx_nir_algebraic.py
@@ -21,12 +21,21 @@ for s in [8, 16, 32, 64]:
         lower_sm5_shift += [((shift, f'a@{s}', b),
                              (shift, a, ('iand', b, s - 1)))]
 
-lower_half_pack = [
+lower_pack = [
     (('pack_half_2x16_split', a, b),
      ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),
 
     (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
     (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),
+
+    (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
+    (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
+    (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
+    (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),
+
+    # For optimizing extract->convert sequences for unpack/pack norm
+    (('u2f32', ('u2u32', a)), ('u2f32', a)),
+    (('i2f32', ('i2i32', a)), ('i2f32', a)),
 ]
 
 def main():
@@ -42,7 +51,7 @@ def run():
     print('#include "agx_nir.h"')
 
     print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
-                                      lower_sm5_shift + lower_half_pack).render())
+                                      lower_sm5_shift + lower_pack).render())
 
 
 if __name__ == '__main__':