[llvm] [RISCV][TTI] Add checks for invalid cast operations (PR #88854)
Shih-Po Hung via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 16 18:37:47 PDT 2024
https://github.com/arcbbb updated https://github.com/llvm/llvm-project/pull/88854
From a5278c4275b0eaf2d8b2b331eca18a095d2eb628 Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Tue, 16 Apr 2024 00:59:45 -0700
Subject: [PATCH 1/2] [RISCV][TTI] Add checks for invalid cast operations
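In issue #88802, the LoopVectorize (LV) cost model queries the cost of a
TRUNC with source type 2xi1 and destination type 2xi32, which is not a valid
truncation because the destination element type is wider than the source.
This patch adds early-exit checks for TRUNCATE, FP_ROUND and FP_EXTEND so
that such invalid cast operations are not costed as if they were valid.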
---
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 38304ff90252f0..c4f1c275f63b65 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -956,6 +956,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
return getRISCVInstructionCost(Op, DstLT.second, CostKind);
}
case ISD::TRUNCATE:
+ // Early return for invalid operation
+ if (Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+ break;
if (Dst->getScalarSizeInBits() == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
@@ -968,6 +971,13 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
[[fallthrough]];
case ISD::FP_EXTEND:
case ISD::FP_ROUND: {
+ // Early return for invalid operation
+ if ((ISD == ISD::FP_ROUND) &&
+ Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+ break;
+ if ((ISD == ISD::FP_EXTEND) &&
+ Src->getScalarSizeInBits() >= Dst->getScalarSizeInBits())
+ break;
// Counts of narrow/widen instructions.
unsigned SrcEltSize = Src->getScalarSizeInBits();
unsigned DstEltSize = Dst->getScalarSizeInBits();
From 145b59aa35216fab59158a2dde74ed92f29bc08c Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Tue, 16 Apr 2024 18:23:07 -0700
Subject: [PATCH 2/2] Address review comments: add a test and hoist the
 getScalarSizeInBits() results into local variables
---
.../Target/RISCV/RISCVTargetTransformInfo.cpp | 24 ++++----
.../RISCV/cost-on-invalid-cast.ll | 55 +++++++++++++++++++
2 files changed, 65 insertions(+), 14 deletions(-)
create mode 100644 llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index c4f1c275f63b65..7362bdeab4e0ee 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -919,9 +919,11 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
if (!IsVectorType)
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+ unsigned SrcEltSize = Src->getScalarSizeInBits();
+ unsigned DstEltSize = Dst->getScalarSizeInBits();
bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
- (Src->getScalarSizeInBits() <= ST->getELen()) &&
- (Dst->getScalarSizeInBits() <= ST->getELen());
+ (SrcEltSize <= ST->getELen()) &&
+ (DstEltSize <= ST->getELen());
// FIXME: Need to compute legalizing cost for illegal types.
if (!IsTypeLegal)
@@ -933,12 +935,10 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
int ISD = TLI->InstructionOpcodeToISD(Opcode);
assert(ISD && "Invalid opcode");
- int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
- (int)Log2_32(Src->getScalarSizeInBits());
+ int PowDiff = (int)Log2_32(DstEltSize) - (int)Log2_32(SrcEltSize);
switch (ISD) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND: {
- const unsigned SrcEltSize = Src->getScalarSizeInBits();
if (SrcEltSize == 1) {
// We do not use vsext/vzext to extend from mask vector.
// Instead we use the following instructions to extend from mask vector:
@@ -957,9 +957,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
}
case ISD::TRUNCATE:
// Early return for invalid operation
- if (Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+ if (DstEltSize >= SrcEltSize)
break;
- if (Dst->getScalarSizeInBits() == 1) {
+ if (DstEltSize == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
// Instead we use the following instructions to truncate to mask vector:
@@ -972,15 +972,11 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
case ISD::FP_EXTEND:
case ISD::FP_ROUND: {
// Early return for invalid operation
- if ((ISD == ISD::FP_ROUND) &&
- Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+ if ((ISD == ISD::FP_ROUND) && DstEltSize >= SrcEltSize)
break;
- if ((ISD == ISD::FP_EXTEND) &&
- Src->getScalarSizeInBits() >= Dst->getScalarSizeInBits())
+ if ((ISD == ISD::FP_EXTEND) && SrcEltSize >= DstEltSize)
break;
// Counts of narrow/widen instructions.
- unsigned SrcEltSize = Src->getScalarSizeInBits();
- unsigned DstEltSize = Dst->getScalarSizeInBits();
unsigned Op = (ISD == ISD::TRUNCATE) ? RISCV::VNSRL_WI
: (ISD == ISD::FP_EXTEND) ? RISCV::VFWCVT_F_F_V
@@ -1001,7 +997,7 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
- if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+ if (SrcEltSize == 1 || DstEltSize == 1) {
// The cost of convert from or to mask vector is different from other
// cases. We could not use PowDiff to calculate it.
// For mask vector to fp, we should use the following instructions:
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll b/llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll
new file mode 100644
index 00000000000000..16da04cdf7dc93
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll
@@ -0,0 +1,55 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=loop-vectorize -mtriple=riscv64 -mattr=+v -S 2>&1 | FileCheck %s
+
+define void @c() {
+; CHECK-LABEL: define void @c(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[FOR_COND:%.*]]
+; CHECK: for.cond:
+; CHECK-NEXT: [[F_0:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[ADD:%.*]], [[COND_END:%.*]] ]
+; CHECK-NEXT: [[ADD]] = add i32 [[F_0]], 1
+; CHECK-NEXT: br i1 false, label [[COND_FALSE:%.*]], label [[COND_TRUE:%.*]]
+; CHECK: cond.true:
+; CHECK-NEXT: [[CONV10:%.*]] = trunc i64 0 to i32
+; CHECK-NEXT: br label [[COND_END]]
+; CHECK: cond.false:
+; CHECK-NEXT: [[TOBOOL15:%.*]] = zext i8 0 to i32
+; CHECK-NEXT: br label [[COND_END]]
+; CHECK: cond.end:
+; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[CONV10]], [[COND_TRUE]] ], [ 0, [[COND_FALSE]] ]
+; CHECK-NEXT: [[TMP0:%.*]] = trunc i32 [[COND]] to i8
+; CHECK-NEXT: [[CONV17:%.*]] = and i8 [[TMP0]], 0
+; CHECK-NEXT: store i8 [[CONV17]], ptr null, align 1
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[F_0]], 1
+; CHECK-NEXT: br i1 [[CMP]], label [[FOR_COND]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK: for.cond.cleanup:
+; CHECK-NEXT: ret void
+;
+entry:
+ br label %for.cond
+
+for.cond:
+ %f.0 = phi i32 [ 0, %entry ], [ %add, %cond.end ]
+ %add = add i32 %f.0, 1
+ br i1 false, label %cond.false, label %cond.true
+
+cond.true:
+ %conv10 = trunc i64 0 to i32
+ br label %cond.end
+
+cond.false:
+ %tobool15 = zext i8 0 to i32
+ br label %cond.end
+
+cond.end:
+ %cond = phi i32 [ %conv10, %cond.true ], [ 0, %cond.false ]
+ %0 = trunc i32 %cond to i8
+ %conv17 = and i8 %0, 0
+ store i8 %conv17, ptr null, align 1
+ %cmp = icmp slt i32 %f.0, 1
+ br i1 %cmp, label %for.cond, label %for.cond.cleanup
+
+for.cond.cleanup:
+ ret void
+}
More information about the llvm-commits mailing list