[llvm] [SLPVectorizer] Refactor HorizontalReduction::createOp (NFC) (PR #121549)

Mel Chen via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 00:24:27 PST 2025


https://github.com/Mel-Chen updated https://github.com/llvm/llvm-project/pull/121549

>From 3e593d0703a322b483b03ac1a4630d318f58948c Mon Sep 17 00:00:00 2001
From: Mel Chen <mel.chen at sifive.com>
Date: Thu, 2 Jan 2025 23:31:28 -0800
Subject: [PATCH 1/4] [SLPVectorizer] Refactor createOp

---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 58 +++++++------------
 1 file changed, 21 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36fed8937aec28..3069d771a10e82 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -19380,59 +19380,43 @@ class HorizontalReduction {
   /// Creates reduction operation with the current opcode.
   static Value *createOp(IRBuilderBase &Builder, RecurKind Kind, Value *LHS,
                          Value *RHS, const Twine &Name, bool UseSelect) {
-    unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+    if (UseSelect) {
+      if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind)) {
+        CmpInst::Predicate Pred = llvm::getMinMaxReductionPredicate(Kind);
+        Value *Cmp = Builder.CreateCmp(Pred, LHS, RHS, Name);
+        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+      }
+      if ((Kind == RecurKind::Or || Kind == RecurKind::And) &&
+          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType())) {
+        Value *TrueVal = Kind == RecurKind::Or ? Builder.getTrue() : RHS;
+        Value *FalseVal = Kind == RecurKind::Or ? RHS : Builder.getFalse();
+        return Builder.CreateSelect(LHS, TrueVal, FalseVal, Name);
+      }
+    }
+
     switch (Kind) {
     case RecurKind::Or:
-      if (UseSelect &&
-          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType()))
-        return Builder.CreateSelect(LHS, Builder.getTrue(), RHS, Name);
-      return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
-                                 Name);
     case RecurKind::And:
-      if (UseSelect &&
-          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType()))
-        return Builder.CreateSelect(LHS, RHS, Builder.getFalse(), Name);
-      return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
-                                 Name);
     case RecurKind::Add:
     case RecurKind::Mul:
     case RecurKind::Xor:
     case RecurKind::FAdd:
-    case RecurKind::FMul:
+    case RecurKind::FMul: {
+      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
+    }
     case RecurKind::FMax:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
     case RecurKind::FMin:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
     case RecurKind::FMaximum:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS);
     case RecurKind::FMinimum:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS);
     case RecurKind::SMax:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::smax, LHS, RHS);
     case RecurKind::SMin:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::smin, LHS, RHS);
     case RecurKind::UMax:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::umax, LHS, RHS);
-    case RecurKind::UMin:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::umin, LHS, RHS);
+    case RecurKind::UMin: {
+      Intrinsic::ID Id = llvm::getMinMaxReductionIntrinsicOp(Kind);
+      return Builder.CreateBinaryIntrinsic(Id, LHS, RHS);
+    }
     default:
       llvm_unreachable("Unknown reduction operation.");
     }

>From 64e4f22226573da360d673b58e9cb7db9f39a549 Mon Sep 17 00:00:00 2001
From: Mel Chen <mel.chen at sifive.com>
Date: Wed, 8 Jan 2025 23:57:41 -0800
Subject: [PATCH 2/4] Revert "[SLPVectorizer] Refactor createOp"

This reverts commit 3e593d0703a322b483b03ac1a4630d318f58948c.
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 58 ++++++++++++-------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3069d771a10e82..36fed8937aec28 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -19380,43 +19380,59 @@ class HorizontalReduction {
   /// Creates reduction operation with the current opcode.
   static Value *createOp(IRBuilderBase &Builder, RecurKind Kind, Value *LHS,
                          Value *RHS, const Twine &Name, bool UseSelect) {
-    if (UseSelect) {
-      if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind)) {
-        CmpInst::Predicate Pred = llvm::getMinMaxReductionPredicate(Kind);
-        Value *Cmp = Builder.CreateCmp(Pred, LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      if ((Kind == RecurKind::Or || Kind == RecurKind::And) &&
-          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType())) {
-        Value *TrueVal = Kind == RecurKind::Or ? Builder.getTrue() : RHS;
-        Value *FalseVal = Kind == RecurKind::Or ? RHS : Builder.getFalse();
-        return Builder.CreateSelect(LHS, TrueVal, FalseVal, Name);
-      }
-    }
-
+    unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
     switch (Kind) {
     case RecurKind::Or:
+      if (UseSelect &&
+          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType()))
+        return Builder.CreateSelect(LHS, Builder.getTrue(), RHS, Name);
+      return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
+                                 Name);
     case RecurKind::And:
+      if (UseSelect &&
+          LHS->getType() == CmpInst::makeCmpResultType(LHS->getType()))
+        return Builder.CreateSelect(LHS, RHS, Builder.getFalse(), Name);
+      return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
+                                 Name);
     case RecurKind::Add:
     case RecurKind::Mul:
     case RecurKind::Xor:
     case RecurKind::FAdd:
-    case RecurKind::FMul: {
-      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+    case RecurKind::FMul:
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
-    }
     case RecurKind::FMax:
+      return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
     case RecurKind::FMin:
+      return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
     case RecurKind::FMaximum:
+      return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS);
     case RecurKind::FMinimum:
+      return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS);
     case RecurKind::SMax:
+      if (UseSelect) {
+        Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
+        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+      }
+      return Builder.CreateBinaryIntrinsic(Intrinsic::smax, LHS, RHS);
     case RecurKind::SMin:
+      if (UseSelect) {
+        Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name);
+        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+      }
+      return Builder.CreateBinaryIntrinsic(Intrinsic::smin, LHS, RHS);
     case RecurKind::UMax:
-    case RecurKind::UMin: {
-      Intrinsic::ID Id = llvm::getMinMaxReductionIntrinsicOp(Kind);
-      return Builder.CreateBinaryIntrinsic(Id, LHS, RHS);
-    }
+      if (UseSelect) {
+        Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name);
+        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+      }
+      return Builder.CreateBinaryIntrinsic(Intrinsic::umax, LHS, RHS);
+    case RecurKind::UMin:
+      if (UseSelect) {
+        Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name);
+        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
+      }
+      return Builder.CreateBinaryIntrinsic(Intrinsic::umin, LHS, RHS);
     default:
       llvm_unreachable("Unknown reduction operation.");
     }

>From ab4c22266066c656bbd4c5c43111843bc811eec9 Mon Sep 17 00:00:00 2001
From: Mel Chen <mel.chen at sifive.com>
Date: Thu, 9 Jan 2025 00:04:07 -0800
Subject: [PATCH 3/4] Split patch

---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 38 ++++++-------------
 1 file changed, 12 insertions(+), 26 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36fed8937aec28..252f2e165d8bf3 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -19401,38 +19401,24 @@ class HorizontalReduction {
     case RecurKind::FMul:
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
-    case RecurKind::FMax:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
-    case RecurKind::FMin:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
-    case RecurKind::FMaximum:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS);
-    case RecurKind::FMinimum:
-      return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS);
     case RecurKind::SMax:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::smax, LHS, RHS);
     case RecurKind::SMin:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::smin, LHS, RHS);
     case RecurKind::UMax:
+    case RecurKind::UMin: {
       if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name);
-        return Builder.CreateSelect(Cmp, LHS, RHS, Name);
-      }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::umax, LHS, RHS);
-    case RecurKind::UMin:
-      if (UseSelect) {
-        Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name);
+        CmpInst::Predicate Pred = llvm::getMinMaxReductionPredicate(Kind);
+        Value *Cmp = Builder.CreateCmp(Pred, LHS, RHS, Name);
         return Builder.CreateSelect(Cmp, LHS, RHS, Name);
       }
-      return Builder.CreateBinaryIntrinsic(Intrinsic::umin, LHS, RHS);
+    }
+      [[fallthrough]];
+    case RecurKind::FMax:
+    case RecurKind::FMin:
+    case RecurKind::FMaximum:
+    case RecurKind::FMinimum: {
+      Intrinsic::ID Id = llvm::getMinMaxReductionIntrinsicOp(Kind);
+      return Builder.CreateBinaryIntrinsic(Id, LHS, RHS);
+    }
     default:
       llvm_unreachable("Unknown reduction operation.");
     }

>From d1a7fd9f13c463f75d962fb0d68967a5faa221a6 Mon Sep 17 00:00:00 2001
From: Mel Chen <mel.chen at sifive.com>
Date: Thu, 9 Jan 2025 00:22:08 -0800
Subject: [PATCH 4/4] Replace CreateCmp with CreateICmp

---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 252f2e165d8bf3..57b05571597278 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -19407,7 +19407,7 @@ class HorizontalReduction {
     case RecurKind::UMin: {
       if (UseSelect) {
         CmpInst::Predicate Pred = llvm::getMinMaxReductionPredicate(Kind);
-        Value *Cmp = Builder.CreateCmp(Pred, LHS, RHS, Name);
+        Value *Cmp = Builder.CreateICmp(Pred, LHS, RHS, Name);
         return Builder.CreateSelect(Cmp, LHS, RHS, Name);
       }
     }



More information about the llvm-commits mailing list