x86: Use VMM API in memcmp-evex-movbe.S and minor changes

The only change to the existing generated code is `tzcnt` -> `bsf` to
save a byte of code size here and there.
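
For reference, the byte comes from dropping tzcnt's F3 prefix; on these
paths the input is already known to be non-zero, so bsf and tzcnt produce
the same result. A rough sketch of the encodings (not part of the patch):

    tzcnt   %eax, %eax      # f3 0f bc c0  (4 bytes)
    bsf     %eax, %eax      #    0f bc c0  (3 bytes)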

Rewriting with the VMM API allows memcmp-evex-movbe to be used with
evex512 by including "x86-evex512-vecs.h" at the top.
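
As a hypothetical illustration (no such file is added by this commit; the
file contents and symbol name below are made up), an evex512 build could
follow the usual pattern of selecting the vector size before including the
generic code:

    #define MEMCMP __memcmp_evex512_movbe   /* hypothetical symbol */
    #include "x86-evex512-vecs.h"           /* sets VEC_SIZE to 64 / zmm VMMs */
    #include "memcmp-evex-movbe.S"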

Complete check passes on x86-64.
Noah Goldstein 2022-10-29 15:19:58 -05:00
parent ca7d181b62
commit 419c832aba


@@ -62,44 +62,38 @@ Latency:
# define MEMCMP __memcmp_evex_movbe
# endif
# define VMOVU vmovdqu64
# ifndef VEC_SIZE
# include "x86-evex256-vecs.h"
# endif
# ifdef USE_AS_WMEMCMP
# define VMOVU_MASK vmovdqu32
# define CHAR_SIZE 4
# define VPCMP vpcmpd
# define VPCMPEQ vpcmpeqd
# define VPTEST vptestmd
# define USE_WIDE_CHAR
# else
# define VMOVU_MASK vmovdqu8
# define CHAR_SIZE 1
# define VPCMP vpcmpub
# define VPCMPEQ vpcmpeqb
# define VPTEST vptestmb
# endif
# include "reg-macros.h"
# define VEC_SIZE 32
# define PAGE_SIZE 4096
# define CHAR_PER_VEC (VEC_SIZE / CHAR_SIZE)
# define XMM0 xmm16
# define XMM1 xmm17
# define XMM2 xmm18
# define YMM0 ymm16
# define YMM1 ymm17
# define YMM2 ymm18
# define YMM3 ymm19
# define YMM4 ymm20
# define YMM5 ymm21
# define YMM6 ymm22
/* Warning!
wmemcmp has to use SIGNED comparison for elements.
memcmp has to use UNSIGNED comparison for elements.
*/
.section .text.evex,"ax",@progbits
.section SECTION(.text), "ax", @progbits
/* Cache align memcmp entry. This allows for much more thorough
frontend optimization. */
ENTRY_P2ALIGN (MEMCMP, 6)
@@ -111,23 +105,40 @@ ENTRY_P2ALIGN (MEMCMP, 6)
/* Fall through for [0, VEC_SIZE] as it's the hottest. */
ja L(more_1x_vec)
/* Create mask for CHAR's we want to compare. This allows us to
avoid having to include page cross logic. */
movl $-1, %ecx
bzhil %edx, %ecx, %ecx
kmovd %ecx, %k2
/* Create mask of bytes that are guaranteed to be valid because
of length (edx). Using masked movs allows us to skip checks
for page crosses/zero size. */
mov $-1, %VRAX
bzhi %VRDX, %VRAX, %VRAX
/* NB: A `jz` might be useful here. Page-faults that are
invalidated by predicate execution (the evex mask) can be
very slow. The expectation is that this is not the norm, so
"most" code will not regularly call 'memcmp' with length = 0
and memory that is not wired up. */
KMOV %VRAX, %k2
/* Safe to load full ymm with mask. */
VMOVU_MASK (%rsi), %YMM2{%k2}
VPCMP $4,(%rdi), %YMM2, %k1{%k2}
kmovd %k1, %eax
testl %eax, %eax
VMOVU_MASK (%rsi), %VMM(2){%k2}{z}
/* Slightly different method for VEC_SIZE == 64 to save a bit of
code size. This allows us to fit L(return_vec_0) entirely in
the first cache line. */
# if VEC_SIZE == 64
VPCMPEQ (%rdi), %VMM(2), %k1{%k2}
KMOV %k1, %VRCX
sub %VRCX, %VRAX
# else
VPCMP $4, (%rdi), %VMM(2), %k1{%k2}
KMOV %k1, %VRAX
test %VRAX, %VRAX
# endif
jnz L(return_vec_0)
ret
.p2align 4
.p2align 4,, 11
L(return_vec_0):
tzcntl %eax, %eax
bsf %VRAX, %VRAX
# ifdef USE_AS_WMEMCMP
movl (%rdi, %rax, CHAR_SIZE), %ecx
xorl %edx, %edx
@@ -138,33 +149,36 @@ L(return_vec_0):
leal -1(%rdx, %rdx), %eax
# else
movzbl (%rsi, %rax), %ecx
# if VEC_SIZE == 64
movb (%rdi, %rax), %al
# else
movzbl (%rdi, %rax), %eax
# endif
subl %ecx, %eax
# endif
ret
.p2align 4
.p2align 4,, 11
L(more_1x_vec):
/* From VEC to 2 * VEC. No branch when size == VEC_SIZE. */
VMOVU (%rsi), %YMM1
VMOVU (%rsi), %VMM(1)
/* Use compare not equals to directly check for mismatch. */
VPCMP $4,(%rdi), %YMM1, %k1
kmovd %k1, %eax
VPCMP $4, (%rdi), %VMM(1), %k1
KMOV %k1, %VRAX
/* NB: eax must be destination register if going to
L(return_vec_[0,2]). For L(return_vec_3) destination register
must be ecx. */
testl %eax, %eax
L(return_vec_[0,2]). For L(return_vec_3) destination
register must be ecx. */
test %VRAX, %VRAX
jnz L(return_vec_0)
cmpq $(CHAR_PER_VEC * 2), %rdx
jbe L(last_1x_vec)
/* Check second VEC no matter what. */
VMOVU VEC_SIZE(%rsi), %YMM2
VPCMP $4, VEC_SIZE(%rdi), %YMM2, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU VEC_SIZE(%rsi), %VMM(2)
VPCMP $4, VEC_SIZE(%rdi), %VMM(2), %k1
KMOV %k1, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_1)
/* Less than 4 * VEC. */
@@ -172,16 +186,16 @@ L(more_1x_vec):
jbe L(last_2x_vec)
/* Check third and fourth VEC no matter what. */
VMOVU (VEC_SIZE * 2)(%rsi), %YMM3
VPCMP $4,(VEC_SIZE * 2)(%rdi), %YMM3, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU (VEC_SIZE * 2)(%rsi), %VMM(3)
VPCMP $4, (VEC_SIZE * 2)(%rdi), %VMM(3), %k1
KMOV %k1, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_2)
VMOVU (VEC_SIZE * 3)(%rsi), %YMM4
VPCMP $4,(VEC_SIZE * 3)(%rdi), %YMM4, %k1
kmovd %k1, %ecx
testl %ecx, %ecx
VMOVU (VEC_SIZE * 3)(%rsi), %VMM(4)
VPCMP $4, (VEC_SIZE * 3)(%rdi), %VMM(4), %k1
KMOV %k1, %VRCX
test %VRCX, %VRCX
jnz L(return_vec_3)
/* Go to 4x VEC loop. */
@@ -192,8 +206,8 @@ L(more_1x_vec):
branches. */
/* Load first two VEC from s2 before adjusting addresses. */
VMOVU -(VEC_SIZE * 4)(%rsi, %rdx, CHAR_SIZE), %YMM1
VMOVU -(VEC_SIZE * 3)(%rsi, %rdx, CHAR_SIZE), %YMM2
VMOVU -(VEC_SIZE * 4)(%rsi, %rdx, CHAR_SIZE), %VMM(1)
VMOVU -(VEC_SIZE * 3)(%rsi, %rdx, CHAR_SIZE), %VMM(2)
leaq -(4 * VEC_SIZE)(%rdi, %rdx, CHAR_SIZE), %rdi
leaq -(4 * VEC_SIZE)(%rsi, %rdx, CHAR_SIZE), %rsi
@@ -202,56 +216,61 @@ L(more_1x_vec):
/* vpxor will be all 0s if s1 and s2 are equal. Otherwise it
will have some 1s. */
vpxorq (%rdi), %YMM1, %YMM1
vpxorq (VEC_SIZE)(%rdi), %YMM2, %YMM2
vpxorq (%rdi), %VMM(1), %VMM(1)
vpxorq (VEC_SIZE)(%rdi), %VMM(2), %VMM(2)
VMOVU (VEC_SIZE * 2)(%rsi), %YMM3
vpxorq (VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
VMOVU (VEC_SIZE * 2)(%rsi), %VMM(3)
vpxorq (VEC_SIZE * 2)(%rdi), %VMM(3), %VMM(3)
VMOVU (VEC_SIZE * 3)(%rsi), %YMM4
/* Ternary logic to xor (VEC_SIZE * 3)(%rdi) with YMM4 while
oring with YMM1. Result is stored in YMM4. */
vpternlogd $0xde,(VEC_SIZE * 3)(%rdi), %YMM1, %YMM4
VMOVU (VEC_SIZE * 3)(%rsi), %VMM(4)
/* Ternary logic to xor (VEC_SIZE * 3)(%rdi) with VEC(4) while
oring with VEC(1). Result is stored in VEC(4). */
vpternlogd $0xde, (VEC_SIZE * 3)(%rdi), %VMM(1), %VMM(4)
/* Or together YMM2, YMM3, and YMM4 into YMM4. */
vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
/* Or together VEC(2), VEC(3), and VEC(4) into VEC(4). */
vpternlogd $0xfe, %VMM(2), %VMM(3), %VMM(4)
/* Test YMM4 against itself. Store any CHAR mismatches in k1.
/* Test VEC(4) against itself. Store any CHAR mismatches in k1.
*/
VPTEST %YMM4, %YMM4, %k1
VPTEST %VMM(4), %VMM(4), %k1
/* k1 must go to ecx for L(return_vec_0_1_2_3). */
kmovd %k1, %ecx
testl %ecx, %ecx
KMOV %k1, %VRCX
test %VRCX, %VRCX
jnz L(return_vec_0_1_2_3)
/* NB: eax must be zero to reach here. */
ret
.p2align 4,, 8
.p2align 4,, 9
L(8x_end_return_vec_0_1_2_3):
movq %rdx, %rdi
L(8x_return_vec_0_1_2_3):
/* L(loop_4x_vec) leaves result in `k1` for VEC_SIZE == 64. */
# if VEC_SIZE == 64
KMOV %k1, %VRCX
# endif
addq %rdi, %rsi
L(return_vec_0_1_2_3):
VPTEST %YMM1, %YMM1, %k0
kmovd %k0, %eax
testl %eax, %eax
VPTEST %VMM(1), %VMM(1), %k0
KMOV %k0, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_0)
VPTEST %YMM2, %YMM2, %k0
kmovd %k0, %eax
testl %eax, %eax
VPTEST %VMM(2), %VMM(2), %k0
KMOV %k0, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_1)
VPTEST %YMM3, %YMM3, %k0
kmovd %k0, %eax
testl %eax, %eax
VPTEST %VMM(3), %VMM(3), %k0
KMOV %k0, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_2)
.p2align 4,, 2
L(return_vec_3):
/* bsf saves 1 byte from tzcnt. This keeps L(return_vec_3) in one
fetch block and the entire L(*return_vec_0_1_2_3) in 1 cache
line. */
bsfl %ecx, %ecx
bsf %VRCX, %VRCX
# ifdef USE_AS_WMEMCMP
movl (VEC_SIZE * 3)(%rdi, %rcx, CHAR_SIZE), %eax
xorl %edx, %edx
@@ -266,11 +285,11 @@ L(return_vec_3):
ret
.p2align 4
.p2align 4,, 8
L(return_vec_1):
/* bsf saves 1 byte over tzcnt and keeps L(return_vec_1) in one
fetch block. */
bsfl %eax, %eax
bsf %VRAX, %VRAX
# ifdef USE_AS_WMEMCMP
movl VEC_SIZE(%rdi, %rax, CHAR_SIZE), %ecx
xorl %edx, %edx
@@ -284,11 +303,11 @@ L(return_vec_1):
# endif
ret
.p2align 4,, 10
.p2align 4,, 7
L(return_vec_2):
/* bsf saves 1 byte over tzcnt and keeps L(return_vec_2) in one
fetch block. */
bsfl %eax, %eax
bsf %VRAX, %VRAX
# ifdef USE_AS_WMEMCMP
movl (VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %ecx
xorl %edx, %edx
@@ -302,7 +321,7 @@ L(return_vec_2):
# endif
ret
.p2align 4
.p2align 4,, 8
L(more_8x_vec):
/* Set end of s1 in rdx. */
leaq -(VEC_SIZE * 4)(%rdi, %rdx, CHAR_SIZE), %rdx
@@ -316,62 +335,82 @@ L(more_8x_vec):
.p2align 4
L(loop_4x_vec):
VMOVU (%rsi, %rdi), %YMM1
vpxorq (%rdi), %YMM1, %YMM1
VMOVU VEC_SIZE(%rsi, %rdi), %YMM2
vpxorq VEC_SIZE(%rdi), %YMM2, %YMM2
VMOVU (VEC_SIZE * 2)(%rsi, %rdi), %YMM3
vpxorq (VEC_SIZE * 2)(%rdi), %YMM3, %YMM3
VMOVU (VEC_SIZE * 3)(%rsi, %rdi), %YMM4
vpternlogd $0xde,(VEC_SIZE * 3)(%rdi), %YMM1, %YMM4
vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
VPTEST %YMM4, %YMM4, %k1
kmovd %k1, %ecx
testl %ecx, %ecx
VMOVU (%rsi, %rdi), %VMM(1)
vpxorq (%rdi), %VMM(1), %VMM(1)
VMOVU VEC_SIZE(%rsi, %rdi), %VMM(2)
vpxorq VEC_SIZE(%rdi), %VMM(2), %VMM(2)
VMOVU (VEC_SIZE * 2)(%rsi, %rdi), %VMM(3)
vpxorq (VEC_SIZE * 2)(%rdi), %VMM(3), %VMM(3)
VMOVU (VEC_SIZE * 3)(%rsi, %rdi), %VMM(4)
vpternlogd $0xde, (VEC_SIZE * 3)(%rdi), %VMM(1), %VMM(4)
vpternlogd $0xfe, %VMM(2), %VMM(3), %VMM(4)
VPTEST %VMM(4), %VMM(4), %k1
/* If VEC_SIZE == 64 just branch with KTEST. We have free port0
space and it allows the loop to fit in 2x cache lines
instead of 3. */
# if VEC_SIZE == 64
KTEST %k1, %k1
# else
KMOV %k1, %VRCX
test %VRCX, %VRCX
# endif
jnz L(8x_return_vec_0_1_2_3)
subq $-(VEC_SIZE * 4), %rdi
cmpq %rdx, %rdi
jb L(loop_4x_vec)
subq %rdx, %rdi
/* rdi has 4 * VEC_SIZE - remaining length. */
cmpl $(VEC_SIZE * 3), %edi
jae L(8x_last_1x_vec)
jge L(8x_last_1x_vec)
/* Load regardless of branch. */
VMOVU (VEC_SIZE * 2)(%rsi, %rdx), %YMM3
VMOVU (VEC_SIZE * 2)(%rsi, %rdx), %VMM(3)
/* Separate logic as we can only use testb for VEC_SIZE == 64.
*/
# if VEC_SIZE == 64
testb %dil, %dil
js L(8x_last_2x_vec)
# else
cmpl $(VEC_SIZE * 2), %edi
jae L(8x_last_2x_vec)
jge L(8x_last_2x_vec)
# endif
vpxorq (VEC_SIZE * 2)(%rdx), %YMM3, %YMM3
vpxorq (VEC_SIZE * 2)(%rdx), %VMM(3), %VMM(3)
VMOVU (%rsi, %rdx), %YMM1
vpxorq (%rdx), %YMM1, %YMM1
VMOVU (%rsi, %rdx), %VMM(1)
vpxorq (%rdx), %VMM(1), %VMM(1)
VMOVU VEC_SIZE(%rsi, %rdx), %YMM2
vpxorq VEC_SIZE(%rdx), %YMM2, %YMM2
VMOVU (VEC_SIZE * 3)(%rsi, %rdx), %YMM4
vpternlogd $0xde,(VEC_SIZE * 3)(%rdx), %YMM1, %YMM4
vpternlogd $0xfe, %YMM2, %YMM3, %YMM4
VPTEST %YMM4, %YMM4, %k1
kmovd %k1, %ecx
testl %ecx, %ecx
VMOVU VEC_SIZE(%rsi, %rdx), %VMM(2)
vpxorq VEC_SIZE(%rdx), %VMM(2), %VMM(2)
VMOVU (VEC_SIZE * 3)(%rsi, %rdx), %VMM(4)
vpternlogd $0xde, (VEC_SIZE * 3)(%rdx), %VMM(1), %VMM(4)
vpternlogd $0xfe, %VMM(2), %VMM(3), %VMM(4)
VPTEST %VMM(4), %VMM(4), %k1
/* L(8x_end_return_vec_0_1_2_3) expects bitmask to still be in
`k1` if VEC_SIZE == 64. */
# if VEC_SIZE == 64
KTEST %k1, %k1
# else
KMOV %k1, %VRCX
test %VRCX, %VRCX
# endif
jnz L(8x_end_return_vec_0_1_2_3)
/* NB: eax must be zero to reach here. */
ret
/* Only entry is from L(more_8x_vec). */
.p2align 4,, 10
.p2align 4,, 6
L(8x_last_2x_vec):
VPCMP $4,(VEC_SIZE * 2)(%rdx), %YMM3, %k1
kmovd %k1, %eax
testl %eax, %eax
VPCMP $4, (VEC_SIZE * 2)(%rdx), %VMM(3), %k1
KMOV %k1, %VRAX
test %VRAX, %VRAX
jnz L(8x_return_vec_2)
/* Naturally aligned to 16 bytes. */
.p2align 4,, 5
L(8x_last_1x_vec):
VMOVU (VEC_SIZE * 3)(%rsi, %rdx), %YMM1
VPCMP $4,(VEC_SIZE * 3)(%rdx), %YMM1, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU (VEC_SIZE * 3)(%rsi, %rdx), %VMM(1)
VPCMP $4, (VEC_SIZE * 3)(%rdx), %VMM(1), %k1
KMOV %k1, %VRAX
test %VRAX, %VRAX
jnz L(8x_return_vec_3)
ret
@@ -383,7 +422,7 @@ L(8x_last_1x_vec):
L(8x_return_vec_2):
subq $VEC_SIZE, %rdx
L(8x_return_vec_3):
bsfl %eax, %eax
bsf %VRAX, %VRAX
# ifdef USE_AS_WMEMCMP
leaq (%rdx, %rax, CHAR_SIZE), %rax
movl (VEC_SIZE * 3)(%rax), %ecx
@@ -399,32 +438,34 @@ L(8x_return_vec_3):
# endif
ret
.p2align 4,, 10
.p2align 4,, 8
L(last_2x_vec):
/* Check second to last VEC. */
VMOVU -(VEC_SIZE * 2)(%rsi, %rdx, CHAR_SIZE), %YMM1
VPCMP $4, -(VEC_SIZE * 2)(%rdi, %rdx, CHAR_SIZE), %YMM1, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU -(VEC_SIZE * 2)(%rsi, %rdx, CHAR_SIZE), %VMM(1)
VPCMP $4, -(VEC_SIZE * 2)(%rdi, %rdx, CHAR_SIZE), %VMM(1), %k1
KMOV %k1, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_1_end)
/* Check last VEC. */
.p2align 4
.p2align 4,, 8
L(last_1x_vec):
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx, CHAR_SIZE), %YMM1
VPCMP $4, -(VEC_SIZE * 1)(%rdi, %rdx, CHAR_SIZE), %YMM1, %k1
kmovd %k1, %eax
testl %eax, %eax
VMOVU -(VEC_SIZE * 1)(%rsi, %rdx, CHAR_SIZE), %VMM(1)
VPCMP $4, -(VEC_SIZE * 1)(%rdi, %rdx, CHAR_SIZE), %VMM(1), %k1
KMOV %k1, %VRAX
test %VRAX, %VRAX
jnz L(return_vec_0_end)
ret
/* Don't align. Takes 2-fetch blocks either way and aligning
will cause code to spill into another cacheline. */
/* Don't fully align. Takes 2-fetch blocks either way and
aligning will cause code to spill into another cacheline.
*/
.p2align 4,, 3
L(return_vec_1_end):
/* Use bsf to save code size. This is necessary to have
L(one_or_less) fit in aligning bytes between. */
bsfl %eax, %eax
bsf %VRAX, %VRAX
addl %edx, %eax
# ifdef USE_AS_WMEMCMP
movl -(VEC_SIZE * 2)(%rdi, %rax, CHAR_SIZE), %ecx
@@ -439,10 +480,11 @@ L(return_vec_1_end):
# endif
ret
.p2align 4,, 2
/* Don't align. Takes 2-fetch blocks either way and aligning
will cause code to spill into another cacheline. */
L(return_vec_0_end):
tzcntl %eax, %eax
bsf %VRAX, %VRAX
addl %edx, %eax
# ifdef USE_AS_WMEMCMP
movl -VEC_SIZE(%rdi, %rax, CHAR_SIZE), %ecx
@@ -456,7 +498,7 @@ L(return_vec_0_end):
subl %ecx, %eax
# endif
ret
/* 1-byte until next cache line. */
/* evex256: 2-byte until next cache line. evex512: 46-bytes
until next cache line. */
END (MEMCMP)
#endif