[llvm] [X86][AVX512] rematerialize smaller predicate masks (PR #166178)

Ahmed Nour via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 3 07:20:31 PST 2025


https://github.com/ahmednoursphinx updated https://github.com/llvm/llvm-project/pull/166178

>From 3cdf74e3eee0e206a3588acb1b4e53fd66eb9517 Mon Sep 17 00:00:00 2001
From: ahmed <ahmednour.mohamed2012 at gmail.com>
Date: Mon, 3 Nov 2025 17:16:29 +0200
Subject: [PATCH 1/2] fix: rematerialize smaller predicate masks

---
 llvm/lib/Target/X86/X86InstrAVX512.td        | 25 ++++++
 llvm/lib/Target/X86/X86InstrInfo.cpp         |  6 ++
 llvm/test/CodeGen/X86/avx512-mask-set-opt.ll | 93 ++++++++++++++++++++
 3 files changed, 124 insertions(+)
 create mode 100644 llvm/test/CodeGen/X86/avx512-mask-set-opt.ll

diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 1b748b7355716..9fae602974242 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -3161,6 +3161,12 @@ multiclass avx512_mask_setop_w<SDPatternOperator Val> {
 defm KSET0 : avx512_mask_setop_w<immAllZerosV>;
 defm KSET1 : avx512_mask_setop_w<immAllOnesV>;
 
+// 8-bit mask set operations for AVX512DQ
+let Predicates = [HasDQI] in {
+  defm KSET0B : avx512_mask_setop<VK8, v8i1, immAllZerosV>;
+  defm KSET1B : avx512_mask_setop<VK8, v8i1, immAllOnesV>;
+}
+
 // With AVX-512 only, 8-bit mask is promoted to 16-bit mask.
 let Predicates = [HasAVX512] in {
   def : Pat<(v8i1 immAllZerosV), (COPY_TO_REGCLASS (KSET0W), VK8)>;
@@ -3173,6 +3179,25 @@ let Predicates = [HasAVX512] in {
   def : Pat<(v1i1 immAllOnesV),  (COPY_TO_REGCLASS (KSET1W), VK1)>;
 }
 
+// With AVX512DQ, use 8-bit operations for 8-bit masks to avoid setting upper bits
+let Predicates = [HasDQI] in {
+  def : Pat<(v8i1 immAllZerosV), (KSET0B)>;
+  def : Pat<(v8i1 immAllOnesV),  (KSET1B)>;
+}
+
+// Optimize bitconvert of all-ones constants to use kxnor instructions
+let Predicates = [HasDQI] in {
+  def : Pat<(v8i1 (bitconvert (i8 255))), (KSET1B)>;
+  def : Pat<(v16i1 (bitconvert (i16 255))), (COPY_TO_REGCLASS (KSET1B), VK16)>;
+}
+let Predicates = [HasAVX512] in {
+  def : Pat<(v16i1 (bitconvert (i16 65535))), (KSET1W)>;
+}
+let Predicates = [HasBWI] in {
+  def : Pat<(v32i1 (bitconvert (i32 -1))), (KSET1D)>;
+  def : Pat<(v64i1 (bitconvert (i64 -1))), (KSET1Q)>;
+}
+
 // Patterns for kmask insert_subvector/extract_subvector to/from index=0
 multiclass operation_subvector_mask_lowering<RegisterClass subRC, ValueType subVT,
                                              RegisterClass RC, ValueType VT> {
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
index 6b2a7a4ec3583..3eadac4f827bc 100644
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -789,9 +789,11 @@ bool X86InstrInfo::isReMaterializableImpl(
   case X86::FsFLD0SS:
   case X86::FsFLD0SH:
   case X86::FsFLD0F128:
+  case X86::KSET0B:
   case X86::KSET0D:
   case X86::KSET0Q:
   case X86::KSET0W:
+  case X86::KSET1B:
   case X86::KSET1D:
   case X86::KSET1Q:
   case X86::KSET1W:
@@ -6352,12 +6354,16 @@ bool X86InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   // registers, since it is not usable as a write mask.
   // FIXME: A more advanced approach would be to choose the best input mask
   // register based on context.
+  case X86::KSET0B:
+    return Expand2AddrKreg(MIB, get(X86::KXORBkk), X86::K0);
   case X86::KSET0W:
     return Expand2AddrKreg(MIB, get(X86::KXORWkk), X86::K0);
   case X86::KSET0D:
     return Expand2AddrKreg(MIB, get(X86::KXORDkk), X86::K0);
   case X86::KSET0Q:
     return Expand2AddrKreg(MIB, get(X86::KXORQkk), X86::K0);
+  case X86::KSET1B:
+    return Expand2AddrKreg(MIB, get(X86::KXNORBkk), X86::K0);
   case X86::KSET1W:
     return Expand2AddrKreg(MIB, get(X86::KXNORWkk), X86::K0);
   case X86::KSET1D:
diff --git a/llvm/test/CodeGen/X86/avx512-mask-set-opt.ll b/llvm/test/CodeGen/X86/avx512-mask-set-opt.ll
new file mode 100644
index 0000000000000..6a1a0af05d05c
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512-mask-set-opt.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f | FileCheck %s --check-prefixes=AVX512F
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512dq | FileCheck %s --check-prefixes=AVX512DQ
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512bw | FileCheck %s --check-prefixes=AVX512BW
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f,+avx512dq,+avx512bw | FileCheck %s --check-prefixes=AVX512DQBW
+
+declare <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr>, i32, <16 x i1>, <16 x float>)
+
+; Test case 1: v16i1 with all bits set (should use kxnorw on all targets)
+define <16 x float> @gather_all(ptr %base, <16 x i32> %ind, i16 %mask) {
+; AVX512F-LABEL: gather_all:
+; AVX512F:       # %bb.0:
+; AVX512F-NEXT:    kxnorw %k0, %k0, %k1
+; AVX512F-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512F-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512F-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512F-NEXT:    retq
+;
+; AVX512DQ-LABEL: gather_all:
+; AVX512DQ:       # %bb.0:
+; AVX512DQ-NEXT:    kxnorw %k0, %k0, %k1
+; AVX512DQ-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512DQ-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512DQ-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512DQ-NEXT:    retq
+;
+; AVX512BW-LABEL: gather_all:
+; AVX512BW:       # %bb.0:
+; AVX512BW-NEXT:    kxnorw %k0, %k0, %k1
+; AVX512BW-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512BW-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512BW-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512BW-NEXT:    retq
+;
+; AVX512DQBW-LABEL: gather_all:
+; AVX512DQBW:       # %bb.0:
+; AVX512DQBW-NEXT:    kxnorw %k0, %k0, %k1
+; AVX512DQBW-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512DQBW-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512DQBW-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512DQBW-NEXT:    retq
+  %broadcast.splatinsert = insertelement <16 x ptr> undef, ptr %base, i32 0
+  %broadcast.splat = shufflevector <16 x ptr> %broadcast.splatinsert, <16 x ptr> undef, <16 x i32> zeroinitializer
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr float, <16 x ptr> %broadcast.splat, <16 x i64> %sext_ind
+  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float>undef)
+  ret <16 x float> %res
+}
+
+; Test case 2: v16i1 with only the lower 8 bits set (should use kxnorb on AVX512DQ targets)
+define <16 x float> @gather_lower(ptr %base, <16 x i32> %ind, i16 %mask) {
+; AVX512F-LABEL: gather_lower:
+; AVX512F:       # %bb.0:
+; AVX512F-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512F-NEXT:    movw $255, %ax
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512F-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512F-NEXT:    retq
+;
+; AVX512DQ-LABEL: gather_lower:
+; AVX512DQ:       # %bb.0:
+; AVX512DQ-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512DQ-NEXT:    kxnorb %k0, %k0, %k1
+; AVX512DQ-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512DQ-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512DQ-NEXT:    retq
+;
+; AVX512BW-LABEL: gather_lower:
+; AVX512BW:       # %bb.0:
+; AVX512BW-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512BW-NEXT:    movw $255, %ax
+; AVX512BW-NEXT:    kmovd %eax, %k1
+; AVX512BW-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512BW-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512BW-NEXT:    retq
+;
+; AVX512DQBW-LABEL: gather_lower:
+; AVX512DQBW:       # %bb.0:
+; AVX512DQBW-NEXT:    vxorps %xmm1, %xmm1, %xmm1
+; AVX512DQBW-NEXT:    kxnorb %k0, %k0, %k1
+; AVX512DQBW-NEXT:    vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; AVX512DQBW-NEXT:    vmovaps %zmm1, %zmm0
+; AVX512DQBW-NEXT:    retq
+  %broadcast.splatinsert = insertelement <16 x ptr> undef, ptr %base, i32 0
+  %broadcast.splat = shufflevector <16 x ptr> %broadcast.splatinsert, <16 x ptr> undef, <16 x i32> zeroinitializer
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr float, <16 x ptr> %broadcast.splat, <16 x i64> %sext_ind
+  %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0(<16 x ptr> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>, <16 x float>undef)
+  ret <16 x float> %res
+}
+
+

>From 4d2cfe334a511eb4e6baa7c98a34ea3e51ecd62d Mon Sep 17 00:00:00 2001
From: ahmed <ahmednour.mohamed2012 at gmail.com>
Date: Mon, 3 Nov 2025 17:20:20 +0200
Subject: [PATCH 2/2] chore: update formatting

---
 llvm/lib/Target/X86/X86InstrAVX512.td | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 9fae602974242..8a06296751f0d 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -3179,23 +3179,24 @@ let Predicates = [HasAVX512] in {
   def : Pat<(v1i1 immAllOnesV),  (COPY_TO_REGCLASS (KSET1W), VK1)>;
 }
 
-// With AVX512DQ, use 8-bit operations for 8-bit masks to avoid setting upper bits
+// With AVX512DQ, use 8-bit operations for 8-bit masks to avoid setting upper
+// bits
 let Predicates = [HasDQI] in {
   def : Pat<(v8i1 immAllZerosV), (KSET0B)>;
-  def : Pat<(v8i1 immAllOnesV),  (KSET1B)>;
+  def : Pat<(v8i1 immAllOnesV), (KSET1B)>;
 }
 
 // Optimize bitconvert of all-ones constants to use kxnor instructions
 let Predicates = [HasDQI] in {
-  def : Pat<(v8i1 (bitconvert (i8 255))), (KSET1B)>;
-  def : Pat<(v16i1 (bitconvert (i16 255))), (COPY_TO_REGCLASS (KSET1B), VK16)>;
+  def : Pat<(v8i1(bitconvert(i8 255))), (KSET1B)>;
+  def : Pat<(v16i1(bitconvert(i16 255))), (COPY_TO_REGCLASS(KSET1B), VK16)>;
 }
 let Predicates = [HasAVX512] in {
-  def : Pat<(v16i1 (bitconvert (i16 65535))), (KSET1W)>;
+  def : Pat<(v16i1(bitconvert(i16 65535))), (KSET1W)>;
 }
 let Predicates = [HasBWI] in {
-  def : Pat<(v32i1 (bitconvert (i32 -1))), (KSET1D)>;
-  def : Pat<(v64i1 (bitconvert (i64 -1))), (KSET1Q)>;
+  def : Pat<(v32i1(bitconvert(i32 -1))), (KSET1D)>;
+  def : Pat<(v64i1(bitconvert(i64 -1))), (KSET1Q)>;
 }
 
 // Patterns for kmask insert_subvector/extract_subvector to/from index=0



More information about the llvm-commits mailing list