[llvm] 07a5667 - [SLP]Fix PR87477: fix alternate node cast cost/codegen.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 3 10:14:39 PDT 2024


Author: Alexey Bataev
Date: 2024-04-03T10:00:03-07:00
New Revision: 07a566793b2f94d0de6b95b7e6d1146b0d7ffe49

URL: https://github.com/llvm/llvm-project/commit/07a566793b2f94d0de6b95b7e6d1146b0d7ffe49
DIFF: https://github.com/llvm/llvm-project/commit/07a566793b2f94d0de6b95b7e6d1146b0d7ffe49.diff

LOG: [SLP]Fix PR87477: fix alternate node cast cost/codegen.

The actual source and destination type sizes have to be compared to pick
the proper cast opcode.

Added: 
    llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index cb55992051ebf0..7928d29d6dfa7d 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9063,25 +9063,35 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
             cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind,
             E->getAltOp());
       } else {
-        Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType();
-        Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType();
-        auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size());
-        auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size());
-        if (It != MinBWs.end()) {
-          if (!MinBWs.contains(getOperandEntry(E, 0)))
-            VecCost =
-                TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, Src0Ty,
-                                        TTI::CastContextHint::None, CostKind);
-          LLVM_DEBUG({
-            dbgs() << "SLP: alternate extension, which should be truncated.\n";
-            E->dump();
-          });
-          return VecCost;
+        Type *SrcSclTy = E->getMainOp()->getOperand(0)->getType();
+        auto *SrcTy = FixedVectorType::get(SrcSclTy, VL.size());
+        if (SrcSclTy->isIntegerTy() && ScalarTy->isIntegerTy()) {
+          auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
+          unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+          unsigned SrcBWSz =
+              DL->getTypeSizeInBits(E->getMainOp()->getOperand(0)->getType());
+          if (SrcIt != MinBWs.end()) {
+            SrcBWSz = SrcIt->second.first;
+            SrcSclTy = IntegerType::get(SrcSclTy->getContext(), SrcBWSz);
+            SrcTy = FixedVectorType::get(SrcSclTy, VL.size());
+          }
+          if (BWSz <= SrcBWSz) {
+            if (BWSz < SrcBWSz)
+              VecCost =
+                  TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, SrcTy,
+                                          TTI::CastContextHint::None, CostKind);
+            LLVM_DEBUG({
+              dbgs()
+                  << "SLP: alternate extension, which should be truncated.\n";
+              E->dump();
+            });
+            return VecCost;
+          }
         }
-        VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, Src0Ty,
+        VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, SrcTy,
                                           TTI::CastContextHint::None, CostKind);
         VecCost +=
-            TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
+            TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, SrcTy,
                                     TTI::CastContextHint::None, CostKind);
       }
       SmallVector<int> Mask;
@@ -12591,15 +12601,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         CmpInst::Predicate AltPred = AltCI->getPredicate();
         V1 = Builder.CreateCmp(AltPred, LHS, RHS);
       } else {
-        if (It != MinBWs.end()) {
-          if (!MinBWs.contains(getOperandEntry(E, 0)))
-            LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
-          assert(LHS->getType() == VecTy && "Expected same type as operand.");
-          if (auto *I = dyn_cast<Instruction>(LHS))
-            LHS = propagateMetadata(I, E->Scalars);
-          E->VectorizedValue = LHS;
-          ++NumVectorInstructions;
-          return LHS;
+        if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) {
+          unsigned SrcBWSz = DL->getTypeSizeInBits(
+              cast<VectorType>(LHS->getType())->getElementType());
+          unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+          if (BWSz <= SrcBWSz) {
+            if (BWSz < SrcBWSz)
+              LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
+            assert(LHS->getType() == VecTy && "Expected same type as operand.");
+            if (auto *I = dyn_cast<Instruction>(LHS))
+              LHS = propagateMetadata(I, E->Scalars);
+            E->VectorizedValue = LHS;
+            ++NumVectorInstructions;
+            return LHS;
+          }
         }
         V0 = Builder.CreateCast(
             static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);

diff  --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll
new file mode 100644
index 00000000000000..979d0ea66bac9a
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/ext-alt-node-must-ext.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S --passes=slp-vectorizer -mtriple=systemz-unknown -mcpu=z15 < %s -slp-threshold=-10 | FileCheck %s
+
+define i32 @test(ptr %0, ptr %1) {
+; CHECK-LABEL: define i32 @test(
+; CHECK-SAME: ptr [[TMP0:%.*]], ptr [[TMP1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[TMP3:%.*]] = load i64, ptr inttoptr (i64 32 to ptr), align 32
+; CHECK-NEXT:    [[TMP4:%.*]] = load ptr, ptr [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i64 32
+; CHECK-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x i64> poison, i64 [[TMP6]], i32 0
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <2 x i64> [[TMP7]], i64 [[TMP3]], i32 1
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp ne <2 x i64> [[TMP14]], zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = sext <2 x i1> [[TMP9]] to <2 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <2 x i1> [[TMP9]] to <2 x i8>
+; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x i8> [[TMP16]], <2 x i8> [[TMP11]], <2 x i32> <i32 0, i32 3>
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <2 x i8> [[TMP12]], i32 0
+; CHECK-NEXT:    [[DOTNEG:%.*]] = sext i8 [[TMP13]] to i32
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x i8> [[TMP12]], i32 1
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i8 [[TMP15]] to i32
+; CHECK-NEXT:    [[TMP10:%.*]] = add nsw i32 [[DOTNEG]], [[TMP8]]
+; CHECK-NEXT:    ret i32 [[TMP10]]
+;
+  %3 = load i64, ptr inttoptr (i64 32 to ptr), align 32 ; value for the zext lane
+  %4 = load ptr, ptr %1, align 8
+  %5 = getelementptr inbounds i8, ptr %4, i64 32
+  %6 = load i64, ptr %5, align 8 ; value for the sext lane
+  %7 = icmp ne i64 %3, 0 ; i1 source of the zext
+  %8 = zext i1 %7 to i32 ; zext lane of the alternate sext/zext cast node
+  %9 = icmp ne i64 %6, 0 ; i1 source of the sext
+  %.neg = sext i1 %9 to i32 ; sext lane; CHECK lines expect separate sext and zext of <2 x i1>, blended by a shufflevector
+  %10 = add nsw i32 %.neg, %8 ; PR87477: the i1 source is narrower than the i8 vector type, so the extension must be kept
+  ret i32 %10
+}


        


More information about the llvm-commits mailing list