[llvm] [RISCV][TTI] Add checks for invalid cast operations (PR #88854)

Shih-Po Hung via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 16 18:37:47 PDT 2024


https://github.com/arcbbb updated https://github.com/llvm/llvm-project/pull/88854

>From a5278c4275b0eaf2d8b2b331eca18a095d2eb628 Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Tue, 16 Apr 2024 00:59:45 -0700
Subject: [PATCH 1/2] [RISCV][TTI] Add checks for invalid cast operations

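In issue #88802, the loop vectorizer (LV) cost model queried the cost
of a TRUNC from source type <2 x i1> to destination type <2 x i32>,
which is not a valid truncation because the destination element type
is wider than the source. This patch adds early-exit checks so that
such invalid casts (a TRUNCATE or FP_ROUND that does not narrow, or an
FP_EXTEND that does not widen) are not costed as narrowing/widening
instruction sequences.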
---
 llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 38304ff90252f0..c4f1c275f63b65 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -956,6 +956,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     return getRISCVInstructionCost(Op, DstLT.second, CostKind);
   }
   case ISD::TRUNCATE:
+    // Early return for invalid operation
+    if (Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+      break;
     if (Dst->getScalarSizeInBits() == 1) {
       // We do not use several vncvt to truncate to mask vector. So we could
       // not use PowDiff to calculate it.
@@ -968,6 +971,13 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     [[fallthrough]];
   case ISD::FP_EXTEND:
   case ISD::FP_ROUND: {
+    // Early return for invalid operation
+    if ((ISD == ISD::FP_ROUND) &&
+        Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+      break;
+    if ((ISD == ISD::FP_EXTEND) &&
+        Src->getScalarSizeInBits() >= Dst->getScalarSizeInBits())
+      break;
     // Counts of narrow/widen instructions.
     unsigned SrcEltSize = Src->getScalarSizeInBits();
     unsigned DstEltSize = Dst->getScalarSizeInBits();

>From 145b59aa35216fab59158a2dde74ed92f29bc08c Mon Sep 17 00:00:00 2001
From: ShihPo Hung <shihpo.hung at sifive.com>
Date: Tue, 16 Apr 2024 18:23:07 -0700
Subject: [PATCH 2/2] Address review comments: add a test and hoist
 getScalarSizeInBits() results into local variables

---
 .../Target/RISCV/RISCVTargetTransformInfo.cpp | 24 ++++----
 .../RISCV/cost-on-invalid-cast.ll             | 55 +++++++++++++++++++
 2 files changed, 65 insertions(+), 14 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll

diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index c4f1c275f63b65..7362bdeab4e0ee 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -919,9 +919,11 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (!IsVectorType)
     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
 
+  unsigned SrcEltSize = Src->getScalarSizeInBits();
+  unsigned DstEltSize = Dst->getScalarSizeInBits();
   bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
-                     (Src->getScalarSizeInBits() <= ST->getELen()) &&
-                     (Dst->getScalarSizeInBits() <= ST->getELen());
+                     (SrcEltSize <= ST->getELen()) &&
+                     (DstEltSize <= ST->getELen());
 
   // FIXME: Need to compute legalizing cost for illegal types.
   if (!IsTypeLegal)
@@ -933,12 +935,10 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
-  int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
-                (int)Log2_32(Src->getScalarSizeInBits());
+  int PowDiff = (int)Log2_32(DstEltSize) - (int)Log2_32(SrcEltSize);
   switch (ISD) {
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND: {
-    const unsigned SrcEltSize = Src->getScalarSizeInBits();
     if (SrcEltSize == 1) {
       // We do not use vsext/vzext to extend from mask vector.
       // Instead we use the following instructions to extend from mask vector:
@@ -957,9 +957,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   }
   case ISD::TRUNCATE:
     // Early return for invalid operation
-    if (Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+    if (DstEltSize >= SrcEltSize)
       break;
-    if (Dst->getScalarSizeInBits() == 1) {
+    if (DstEltSize == 1) {
       // We do not use several vncvt to truncate to mask vector. So we could
       // not use PowDiff to calculate it.
       // Instead we use the following instructions to truncate to mask vector:
@@ -972,15 +972,11 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   case ISD::FP_EXTEND:
   case ISD::FP_ROUND: {
     // Early return for invalid operation
-    if ((ISD == ISD::FP_ROUND) &&
-        Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+    if ((ISD == ISD::FP_ROUND) && DstEltSize >= SrcEltSize)
       break;
-    if ((ISD == ISD::FP_EXTEND) &&
-        Src->getScalarSizeInBits() >= Dst->getScalarSizeInBits())
+    if ((ISD == ISD::FP_EXTEND) && SrcEltSize >= DstEltSize)
       break;
     // Counts of narrow/widen instructions.
-    unsigned SrcEltSize = Src->getScalarSizeInBits();
-    unsigned DstEltSize = Dst->getScalarSizeInBits();
 
     unsigned Op = (ISD == ISD::TRUNCATE)    ? RISCV::VNSRL_WI
                   : (ISD == ISD::FP_EXTEND) ? RISCV::VFWCVT_F_F_V
@@ -1001,7 +997,7 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   case ISD::FP_TO_UINT:
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
-    if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+    if (SrcEltSize == 1 || DstEltSize == 1) {
       // The cost of convert from or to mask vector is different from other
       // cases. We could not use PowDiff to calculate it.
       // For mask vector to fp, we should use the following instructions:
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll b/llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll
new file mode 100644
index 00000000000000..16da04cdf7dc93
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/cost-on-invalid-cast.ll
@@ -0,0 +1,55 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=loop-vectorize -mtriple=riscv64 -mattr=+v -S 2>&1 | FileCheck %s
+
+define void @c() {
+; CHECK-LABEL: define void @c(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FOR_COND:%.*]]
+; CHECK:       for.cond:
+; CHECK-NEXT:    [[F_0:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[ADD:%.*]], [[COND_END:%.*]] ]
+; CHECK-NEXT:    [[ADD]] = add i32 [[F_0]], 1
+; CHECK-NEXT:    br i1 false, label [[COND_FALSE:%.*]], label [[COND_TRUE:%.*]]
+; CHECK:       cond.true:
+; CHECK-NEXT:    [[CONV10:%.*]] = trunc i64 0 to i32
+; CHECK-NEXT:    br label [[COND_END]]
+; CHECK:       cond.false:
+; CHECK-NEXT:    [[TOBOOL15:%.*]] = zext i8 0 to i32
+; CHECK-NEXT:    br label [[COND_END]]
+; CHECK:       cond.end:
+; CHECK-NEXT:    [[COND:%.*]] = phi i32 [ [[CONV10]], [[COND_TRUE]] ], [ 0, [[COND_FALSE]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i32 [[COND]] to i8
+; CHECK-NEXT:    [[CONV17:%.*]] = and i8 [[TMP0]], 0
+; CHECK-NEXT:    store i8 [[CONV17]], ptr null, align 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[F_0]], 1
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_COND]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.cond
+
+for.cond:
+  %f.0 = phi i32 [ 0, %entry ], [ %add, %cond.end ]
+  %add = add i32 %f.0, 1
+  br i1 false, label %cond.false, label %cond.true
+
+cond.true:
+  %conv10 = trunc i64 0 to i32
+  br label %cond.end
+
+cond.false:
+  %tobool15 = zext i8 0 to i32
+  br label %cond.end
+
+cond.end:
+  %cond = phi i32 [ %conv10, %cond.true ], [ 0, %cond.false ]
+  %0 = trunc i32 %cond to i8
+  %conv17 = and i8 %0, 0
+  store i8 %conv17, ptr null, align 1
+  %cmp = icmp slt i32 %f.0, 1
+  br i1 %cmp, label %for.cond, label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+}



More information about the llvm-commits mailing list