[llvm] [SystemZ] Fix Operand Retrieval for Vector Reduction Intrinsic in `shouldExpandReduction` (PR #88874)

Dominik Steenken via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 17 04:11:37 PDT 2024


https://github.com/dominik-steenken updated https://github.com/llvm/llvm-project/pull/88874

>From 5bb6fbeb52884de54aad380caaecf64f782c2fb8 Mon Sep 17 00:00:00 2001
From: Dominik Steenken <dost at de.ibm.com>
Date: Mon, 15 Apr 2024 22:03:53 +0200
Subject: [PATCH 1/4] [SystemZ] Fix Operand Retrieval for Vector Reduction
 Intrinsic

In the existing version, SystemZTTIImpl::shouldExpandReduction fails a
`cast<FixedVectorType>` assertion when handling vector reduction
intrinsics that do not have the vector to reduce as their first
operand, such as `llvm.vector.reduce.fadd` and `llvm.vector.reduce.fmul`,
whose first operand is the scalar start value.
This commit fixes that problem by introducing a short loop to find
the vector operand instead of assuming that it is the first operand.
---
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 39 ++++++++++++++-----
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 4c9e78c05dbcac..5da42d7cccd3c1 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/CostTable.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/Support/Debug.h"
@@ -1323,25 +1324,43 @@ SystemZTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
 }
 
 bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
-  // Always expand on Subtargets without vector instructions
+  // Always expand on Subtargets without vector instructions.
   if (!ST->hasVector())
     return true;
 
-  // Always expand for operands that do not fill one vector reg
-  auto *Type = cast<FixedVectorType>(II->getOperand(0)->getType());
-  unsigned NumElts = Type->getNumElements();
-  unsigned ScalarSize = Type->getScalarSizeInBits();
+  // Find the type of the vector operand of the intrinsic
+  // This assumes that each vector reduction intrinsic only
+  // has one vector operand.
+  FixedVectorType *VType = 0x0;
+  for (unsigned I = 0; I < II->getNumOperands(); ++I) {
+    auto *T = II->getOperand(I)->getType();
+    if (T->isVectorTy()) {
+      VType = cast<FixedVectorType>(T);
+      break;
+    }
+  }
+
+  // If we did not find a vector operand, do not continue.
+  if (VType == 0x0)
+    return true;
+
+  // If the vector operand is not a full vector, the reduction
+  // should be expanded.
+  unsigned NumElts = VType->getNumElements();
+  unsigned ScalarSize = VType->getScalarSizeInBits();
   unsigned MaxElts = SystemZ::VectorBits / ScalarSize;
   if (NumElts < MaxElts)
     return true;
 
-  // Otherwise
+  // Handling of full vector operands depends on the
+  // individual intrinsic.
   switch (II->getIntrinsicID()) {
-  // Do not expand vector.reduce.add
-  case Intrinsic::vector_reduce_add:
-    // Except for i64, since the performance benefit is dubious there
-    return ScalarSize >= 64;
   default:
     return true;
+  // Do not expand vector.reduce.add...
+  case Intrinsic::vector_reduce_add:
+    // ...unless the scalar size is i64 or larger, since the
+    // performance benefit is dubious there
+    return ScalarSize >= 64;
   }
 }

>From 2c47c58396a0beafacc90648cef9cc6ad089f2e8 Mon Sep 17 00:00:00 2001
From: Dominik Steenken <dost at de.ibm.com>
Date: Tue, 16 Apr 2024 13:52:39 +0200
Subject: [PATCH 2/4] Use nullptr instead of 0x0

---
 llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 5da42d7cccd3c1..b930a304eb8a2e 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -1331,7 +1331,7 @@ bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
   // Find the type of the vector operand of the intrinsic
   // This assumes that each vector reduction intrinsic only
   // has one vector operand.
-  FixedVectorType *VType = 0x0;
+  FixedVectorType *VType = nullptr;
   for (unsigned I = 0; I < II->getNumOperands(); ++I) {
     auto *T = II->getOperand(I)->getType();
     if (T->isVectorTy()) {
@@ -1341,7 +1341,7 @@ bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
   }
 
   // If we did not find a vector operand, do not continue.
-  if (VType == 0x0)
+  if (VType == nullptr)
     return true;
 
   // If the vector operand is not a full vector, the reduction

>From 8661e9a947ca9304264cacc94d65864d930d53ba Mon Sep 17 00:00:00 2001
From: Dominik Steenken <dost at de.ibm.com>
Date: Wed, 17 Apr 2024 11:09:21 +0200
Subject: [PATCH 3/4] Restructure and decide only by intrinsic ID

---
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 46 +++++++++----------
 1 file changed, 21 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index b930a304eb8a2e..74fc00cb9dfe9e 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -1323,44 +1323,40 @@ SystemZTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
 }
 
-bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
-  // Always expand on Subtargets without vector instructions.
-  if (!ST->hasVector())
-    return true;
-
-  // Find the type of the vector operand of the intrinsic
-  // This assumes that each vector reduction intrinsic only
-  // has one vector operand.
-  FixedVectorType *VType = nullptr;
+// Find the type of the first vector operand of the intrinsic.
+// Returns nullptr in case no vector operand was found.
+FixedVectorType *getFirstOperandVectorType(const IntrinsicInst *II) {
   for (unsigned I = 0; I < II->getNumOperands(); ++I) {
     auto *T = II->getOperand(I)->getType();
     if (T->isVectorTy()) {
-      VType = cast<FixedVectorType>(T);
-      break;
+      return cast<FixedVectorType>(T);
     }
   }
+  return nullptr;
+}
 
-  // If we did not find a vector operand, do not continue.
-  if (VType == nullptr)
-    return true;
+// determine if the given vector type represents a full
+// machine vector register.
+bool isVectorFull(FixedVectorType *VType) {
+  unsigned MaxElts = SystemZ::VectorBits / VType->getScalarSizeInBits();
+  return VType->getNumElements() >= MaxElts;
+}
 
-  // If the vector operand is not a full vector, the reduction
-  // should be expanded.
-  unsigned NumElts = VType->getNumElements();
-  unsigned ScalarSize = VType->getScalarSizeInBits();
-  unsigned MaxElts = SystemZ::VectorBits / ScalarSize;
-  if (NumElts < MaxElts)
+bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
+  // Always expand on Subtargets without vector instructions.
+  if (!ST->hasVector())
     return true;
 
-  // Handling of full vector operands depends on the
-  // individual intrinsic.
+  // Whether or not to expand is a per-intrinsic decision.
   switch (II->getIntrinsicID()) {
   default:
     return true;
   // Do not expand vector.reduce.add...
   case Intrinsic::vector_reduce_add:
-    // ...unless the scalar size is i64 or larger, since the
-    // performance benefit is dubious there
-    return ScalarSize >= 64;
+    auto *VType = getFirstOperandVectorType(II);
+    // ...unless the scalar size is i64 or larger,
+    // or the operand vector is not full, since the
+    // performance benefit is dubious in those cases
+    return (VType->getScalarSizeInBits() >= 64) || not isVectorFull(VType);
   }
 }

>From b1868f2d8a7a2a6f1a2c491905bc20424644db12 Mon Sep 17 00:00:00 2001
From: Dominik Steenken <dost at de.ibm.com>
Date: Wed, 17 Apr 2024 13:11:06 +0200
Subject: [PATCH 4/4] Move operand selection into the switch case

---
 .../SystemZ/SystemZTargetTransformInfo.cpp    | 28 ++++++-------------
 1 file changed, 9 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
index 74fc00cb9dfe9e..76abfad654437b 100644
--- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp
@@ -1323,23 +1323,12 @@ SystemZTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
 }
 
-// Find the type of the first vector operand of the intrinsic.
-// Returns nullptr in case no vector operand was found.
-FixedVectorType *getFirstOperandVectorType(const IntrinsicInst *II) {
-  for (unsigned I = 0; I < II->getNumOperands(); ++I) {
-    auto *T = II->getOperand(I)->getType();
-    if (T->isVectorTy()) {
-      return cast<FixedVectorType>(T);
-    }
-  }
-  return nullptr;
-}
-
-// determine if the given vector type represents a full
-// machine vector register.
-bool isVectorFull(FixedVectorType *VType) {
-  unsigned MaxElts = SystemZ::VectorBits / VType->getScalarSizeInBits();
-  return VType->getNumElements() >= MaxElts;
+// Find the type of the vector operand indicated by index.
+// Asserts that the operand indicated is actually a vector.
+FixedVectorType *getOperandVectorType(const IntrinsicInst *II, unsigned Index) {
+  auto *T = II->getOperand(Index)->getType();
+  assert (T->isVectorTy());
+  return cast<FixedVectorType>(T);
 }
 
 bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
@@ -1353,10 +1342,11 @@ bool SystemZTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
     return true;
   // Do not expand vector.reduce.add...
   case Intrinsic::vector_reduce_add:
-    auto *VType = getFirstOperandVectorType(II);
+    auto *VType = getOperandVectorType(II, 0);
     // ...unless the scalar size is i64 or larger,
     // or the operand vector is not full, since the
     // performance benefit is dubious in those cases
-    return (VType->getScalarSizeInBits() >= 64) || not isVectorFull(VType);
+    return (VType->getScalarSizeInBits() >= 64) ||
+           VType->getPrimitiveSizeInBits() < SystemZ::VectorBits;
   }
 }



More information about the llvm-commits mailing list