[llvm] 3fe6a06 - [LV] Check if compare is truncated directly in getInstructionCost.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 4 12:52:46 PDT 2024
Author: Florian Hahn
Date: 2024-09-04T20:50:06+01:00
New Revision: 3fe6a064f15cd854fd497594cc20e8b680cd2133
URL: https://github.com/llvm/llvm-project/commit/3fe6a064f15cd854fd497594cc20e8b680cd2133
DIFF: https://github.com/llvm/llvm-project/commit/3fe6a064f15cd854fd497594cc20e8b680cd2133.diff
LOG: [LV] Check if compare is truncated directly in getInstructionCost.
The current check for truncated compares in getInstructionCost misses
cases where either the first or both operands are constants.
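Check directly if the compare is marked for truncation. In that case, the
compare's minimal bitwidth is the bitwidth to use for its operands.
The patch also adds asserts to ensure that truncating the first operand
implies truncating the compare, and that both then use the same bitwidth.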
This fixes a divergence between the legacy and VPlan-based cost models, where
the legacy cost model incorrectly estimated the cost of compares with
truncated operands.
Fixes https://github.com/llvm/llvm-project/issues/107171.
Added:
Modified:
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-cost.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0200525a718d5f..0ccf442dac9993 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6606,9 +6606,20 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I,
case Instruction::ICmp:
case Instruction::FCmp: {
Type *ValTy = I->getOperand(0)->getType();
+
Instruction *Op0AsInstruction = dyn_cast<Instruction>(I->getOperand(0));
- if (canTruncateToMinimalBitwidth(Op0AsInstruction, VF))
- ValTy = IntegerType::get(ValTy->getContext(), MinBWs[Op0AsInstruction]);
+ (void)Op0AsInstruction;
+ assert((!canTruncateToMinimalBitwidth(Op0AsInstruction, VF) ||
+ canTruncateToMinimalBitwidth(I, VF)) &&
+ "truncating Op0 must imply truncating the compare");
+ if (canTruncateToMinimalBitwidth(I, VF)) {
+ assert(!canTruncateToMinimalBitwidth(Op0AsInstruction, VF) ||
+ MinBWs[I] == MinBWs[Op0AsInstruction] &&
+ "if both the operand and the compare are marked for "
+ "truncation, they must have the same bitwidth");
+ ValTy = IntegerType::get(ValTy->getContext(), MinBWs[I]);
+ }
+
VectorTy = ToVectorTy(ValTy, VF);
return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, nullptr,
cast<CmpInst>(I)->getPredicate(), CostKind,
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-cost.ll b/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-cost.ll
index 3e2f290a497db1..9fe5a2a6a3ecc2 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-cost.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/truncate-to-minimal-bitwidth-cost.ll
@@ -221,7 +221,56 @@ exit:
ret void
}
+; Test case for https://github.com/llvm/llvm-project/issues/107171: both icmp operands are constants, so the compare's truncation must be detected on the compare itself.
+define i8 @icmp_ops_narrowed_to_i1() #1 {
+; CHECK-LABEL: define i8 @icmp_ops_narrowed_to_i1(
+; CHECK-SAME: ) #[[ATTR2:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK: [[VECTOR_PH]]:
+; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK: [[VECTOR_BODY]]:
+; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 32
+; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[INDEX_NEXT]], 96
+; CHECK-NEXT: br i1 [[TMP0]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
+; CHECK: [[MIDDLE_BLOCK]]:
+; CHECK-NEXT: br i1 false, label %[[EXIT:.*]], label %[[SCALAR_PH]]
+; CHECK: [[SCALAR_PH]]:
+; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i16 [ 96, %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
+; CHECK-NEXT: br label %[[LOOP:.*]]
+; CHECK: [[LOOP]]:
+; CHECK-NEXT: [[IV:%.*]] = phi i16 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT: [[C:%.*]] = icmp eq i8 0, 0
+; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[C]] to i64
+; CHECK-NEXT: [[SHR:%.*]] = lshr i64 [[EXT]], 1
+; CHECK-NEXT: [[TRUNC:%.*]] = trunc i64 [[SHR]] to i8
+; CHECK-NEXT: [[IV_NEXT]] = add i16 [[IV]], 1
+; CHECK-NEXT: [[EC:%.*]] = icmp eq i16 [[IV_NEXT]], 100
+; CHECK-NEXT: br i1 [[EC]], label %[[EXIT]], label %[[LOOP]], !llvm.loop [[LOOP9:![0-9]+]]
+; CHECK: [[EXIT]]:
+; CHECK-NEXT: [[TRUNC_LCSSA:%.*]] = phi i8 [ [[TRUNC]], %[[LOOP]] ], [ 0, %[[MIDDLE_BLOCK]] ]
+; CHECK-NEXT: ret i8 [[TRUNC_LCSSA]]
+;
+entry:
+ br label %loop
+
+loop:
+ %iv = phi i16 [ 0, %entry ], [ %iv.next, %loop ]
+ %c = icmp eq i8 0, 0
+ %ext = zext i1 %c to i64
+ %shr = lshr i64 %ext, 1
+ %trunc = trunc i64 %shr to i8
+ %iv.next = add i16 %iv, 1
+ %ec = icmp eq i16 %iv.next, 100
+ br i1 %ec, label %exit, label %loop
+
+exit:
+ ret i8 %trunc
+}
+
attributes #0 = { "target-features"="+64bit,+v,+zvl256b" }
+attributes #1 = { "target-features"="+64bit,+v" }
;.
; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
@@ -232,4 +281,6 @@ attributes #0 = { "target-features"="+64bit,+v,+zvl256b" }
; CHECK: [[LOOP5]] = distinct !{[[LOOP5]], [[META2]], [[META1]]}
; CHECK: [[LOOP6]] = distinct !{[[LOOP6]], [[META1]], [[META2]]}
; CHECK: [[LOOP7]] = distinct !{[[LOOP7]], [[META2]], [[META1]]}
+; CHECK: [[LOOP8]] = distinct !{[[LOOP8]], [[META1]], [[META2]]}
+; CHECK: [[LOOP9]] = distinct !{[[LOOP9]], [[META2]], [[META1]]}
;.
More information about the llvm-commits
mailing list