[llvm] [SLP][REVEC] Fix the mismatch between the result of getAltInstrMask and the VecTy argument of TargetTransformInfo::isLegalAltInstr. (PR #134795)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 7 23:48:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Han-Kuan Chen (HanKuanChen)

<details>
<summary>Changes</summary>

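getAltInstrMask cannot determine ScalarTy from VL by itself, because some
callers derive ScalarTy from VL[0]->getType() while others derive it from
getValueType(VL[0]). ScalarTy is therefore now passed to getAltInstrMask
explicitly, so the returned mask agrees with the VecTy passed to
TargetTransformInfo::isLegalAltInstr.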

This fixes the "Mask and VecTy are incompatible" failure.

---
Full diff: https://github.com/llvm/llvm-project/pull/134795.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+13-11) 
- (added) llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll (+47) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e6559f26be8c2..7e167f238b82e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1264,9 +1264,8 @@ static void fixupOrderingIndices(MutableArrayRef<unsigned> Order) {
 
 /// \returns a bitset for selecting opcodes. false for Opcode0 and true for
 /// Opcode1.
-static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, unsigned Opcode0,
-                                      unsigned Opcode1) {
-  Type *ScalarTy = VL[0]->getType();
+static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, Type *ScalarTy,
+                                      unsigned Opcode0, unsigned Opcode1) {
   unsigned ScalarTyNumElements = getNumElements(ScalarTy);
   SmallBitVector OpcodeMask(VL.size() * ScalarTyNumElements, false);
   for (unsigned Lane : seq<unsigned>(VL.size())) {
@@ -6667,11 +6666,12 @@ void BoUpSLP::reorderTopToBottom() {
     // to take into account their order when looking for the most used order.
     if (TE->hasState() && TE->isAltShuffle() &&
         TE->State != TreeEntry::SplitVectorize) {
-      VectorType *VecTy =
-          getWidenedType(TE->Scalars[0]->getType(), TE->Scalars.size());
+      Type *ScalarTy = TE->Scalars[0]->getType();
+      VectorType *VecTy = getWidenedType(ScalarTy, TE->Scalars.size());
       unsigned Opcode0 = TE->getOpcode();
       unsigned Opcode1 = TE->getAltOpcode();
-      SmallBitVector OpcodeMask(getAltInstrMask(TE->Scalars, Opcode0, Opcode1));
+      SmallBitVector OpcodeMask(
+          getAltInstrMask(TE->Scalars, ScalarTy, Opcode0, Opcode1));
       // If this pattern is supported by the target then we consider the order.
       if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
         VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get());
@@ -8352,12 +8352,13 @@ static bool isAlternateInstruction(const Instruction *I,
 
 bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
                                        ArrayRef<Value *> VL) const {
+  Type *ScalarTy = S.getMainOp()->getType();
   unsigned Opcode0 = S.getOpcode();
   unsigned Opcode1 = S.getAltOpcode();
-  SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
+  SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
   // If this pattern is supported by the target then consider it profitable.
-  if (TTI->isLegalAltInstr(getWidenedType(S.getMainOp()->getType(), VL.size()),
-                           Opcode0, Opcode1, OpcodeMask))
+  if (TTI->isLegalAltInstr(getWidenedType(ScalarTy, VL.size()), Opcode0,
+                           Opcode1, OpcodeMask))
     return true;
   SmallVector<ValueList> Operands;
   for (unsigned I : seq<unsigned>(S.getMainOp()->getNumOperands())) {
@@ -9270,7 +9271,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
     unsigned Opcode0 = LocalState.getOpcode();
     unsigned Opcode1 = LocalState.getAltOpcode();
-    SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
+    SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
     // Enable split node, only if all nodes do not form legal alternate
     // instruction (like X86 addsub).
     SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
@@ -13200,7 +13201,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       // order.
       unsigned Opcode0 = E->getOpcode();
       unsigned Opcode1 = E->getAltOpcode();
-      SmallBitVector OpcodeMask(getAltInstrMask(E->Scalars, Opcode0, Opcode1));
+      SmallBitVector OpcodeMask(
+          getAltInstrMask(E->Scalars, ScalarTy, Opcode0, Opcode1));
       // If this pattern is supported by the target then we consider the
       // order.
       if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll b/llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll
new file mode 100644
index 0000000000000..8380b1cb5f850
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -passes=slp-vectorizer -S -slp-revec %s | FileCheck %s
+
+define i32 @test() {
+; CHECK-LABEL: @test(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[WIDE_LOAD136:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[WIDE_LOAD137:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[WIDE_LOAD138:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = or <16 x i8> [[WIDE_LOAD]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[WIDE_LOAD136]], zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = or <16 x i8> [[WIDE_LOAD137]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = or <16 x i8> [[WIDE_LOAD138]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp ult <16 x i8> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp ult <16 x i8> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = icmp ult <16 x i8> [[TMP4]], zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp ult <16 x i8> [[TMP5]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = or <16 x i8> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = or <16 x i8> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ult <16 x i8> [[TMP10]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ult <16 x i8> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %wide.load = load <16 x i8>, ptr null, align 1
+  %wide.load136 = load <16 x i8>, ptr null, align 1
+  %wide.load137 = load <16 x i8>, ptr null, align 1
+  %wide.load138 = load <16 x i8>, ptr null, align 1
+  %0 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+  %1 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+  %2 = or <16 x i8> %wide.load, zeroinitializer
+  %3 = or <16 x i8> %wide.load136, zeroinitializer
+  %4 = or <16 x i8> %wide.load137, zeroinitializer
+  %5 = or <16 x i8> %wide.load138, zeroinitializer
+  %6 = icmp ult <16 x i8> %2, zeroinitializer
+  %7 = icmp ult <16 x i8> %3, zeroinitializer
+  %8 = icmp ult <16 x i8> %4, zeroinitializer
+  %9 = icmp ult <16 x i8> %5, zeroinitializer
+  %10 = or <16 x i8> %0, zeroinitializer
+  %11 = or <16 x i8> %1, zeroinitializer
+  %12 = icmp ult <16 x i8> %10, zeroinitializer
+  %13 = icmp ult <16 x i8> %11, zeroinitializer
+  ret i32 0
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/134795


More information about the llvm-commits mailing list