[llvm] [IRBuilder] Fold binary intrinsics (PR #80743)

Artem Tyurin via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 08:56:00 PST 2024


https://github.com/agentcooper updated https://github.com/llvm/llvm-project/pull/80743

>From 84a1fd55bb77d967e6a9195388f763d4af8e58f2 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Mon, 5 Feb 2024 22:25:43 +0100
Subject: [PATCH 1/2] [IRBuilder] Fold binary intrinsics

Fixes https://github.com/llvm/llvm-project/issues/61240.
---
 .../llvm/Analysis/InstSimplifyFolder.h        |  6 ++++++
 llvm/include/llvm/Analysis/TargetFolder.h     |  5 +++++
 llvm/include/llvm/IR/ConstantFolder.h         | 20 +++++++++++++++++++
 llvm/include/llvm/IR/IRBuilderFolder.h        |  3 +++
 llvm/include/llvm/IR/NoFolder.h               |  4 ++++
 llvm/lib/IR/IRBuilder.cpp                     |  2 ++
 .../Transforms/Vectorize/SLPVectorizer.cpp    |  8 --------
 7 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/Analysis/InstSimplifyFolder.h b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
index 23e2ea80e8cbe..d8ec313b49467 100644
--- a/llvm/include/llvm/Analysis/InstSimplifyFolder.h
+++ b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
@@ -117,6 +117,12 @@ class InstSimplifyFolder final : public IRBuilderFolder {
     return simplifyCastInst(Op, V, DestTy, SQ);
   }
 
+  Value *FoldBinaryIntrinsics(Intrinsic::ID ID, Value *LHS,
+                              Value *RHS) const override {
+    // TODO: should this be defined?
+    return nullptr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/Analysis/TargetFolder.h b/llvm/include/llvm/Analysis/TargetFolder.h
index 978e1002515fc..cabce12541435 100644
--- a/llvm/include/llvm/Analysis/TargetFolder.h
+++ b/llvm/include/llvm/Analysis/TargetFolder.h
@@ -191,6 +191,11 @@ class TargetFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
+  Value *FoldBinaryIntrinsics(Intrinsic::ID ID, Value *LHS, Value *RHS) const override {
+    // TODO: should this be defined? Until then no intrinsic is folded here.
+    return nullptr; // Unconditional: ID, LHS and RHS are not inspected.
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/ConstantFolder.h b/llvm/include/llvm/IR/ConstantFolder.h
index c2b30a65e32e2..001d7b1fdf44e 100644
--- a/llvm/include/llvm/IR/ConstantFolder.h
+++ b/llvm/include/llvm/IR/ConstantFolder.h
@@ -183,6 +183,26 @@ class ConstantFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
+  /// Fold maxnum/minnum when both operands are ConstantFP; else return nullptr.
+  Value *FoldBinaryIntrinsics(Intrinsic::ID ID, Value *LHS,
+                              Value *RHS) const override {
+    // dyn_cast<ConstantFP>: cast<> asserts on vector/undef/non-FP constants.
+    auto *LC = dyn_cast<ConstantFP>(LHS);
+    auto *RC = dyn_cast<ConstantFP>(RHS);
+    if (!LC || !RC)
+      return nullptr;
+    switch (ID) {
+    case Intrinsic::maxnum:
+      return ConstantFP::get(LHS->getType(),
+                             maxnum(LC->getValueAPF(), RC->getValueAPF()));
+    case Intrinsic::minnum:
+      return ConstantFP::get(LHS->getType(),
+                             minnum(LC->getValueAPF(), RC->getValueAPF()));
+    default: // TODO: handle more intrinsics
+      return nullptr;
+    }
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/IRBuilderFolder.h b/llvm/include/llvm/IR/IRBuilderFolder.h
index bd2324dfc5f1b..3e0fd093b3f66 100644
--- a/llvm/include/llvm/IR/IRBuilderFolder.h
+++ b/llvm/include/llvm/IR/IRBuilderFolder.h
@@ -73,6 +73,9 @@ class IRBuilderFolder {
   virtual Value *FoldCast(Instruction::CastOps Op, Value *V,
                           Type *DestTy) const = 0;
 
+  virtual Value *FoldBinaryIntrinsics(Intrinsic::ID ID, Value *LHS,
+                                      Value *RHS) const = 0;
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/NoFolder.h b/llvm/include/llvm/IR/NoFolder.h
index a612f98465aea..787dd10b4cbaa 100644
--- a/llvm/include/llvm/IR/NoFolder.h
+++ b/llvm/include/llvm/IR/NoFolder.h
@@ -112,6 +112,10 @@ class NoFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
+  Value *FoldBinaryIntrinsics(Intrinsic::ID ID, Value *LHS, Value *RHS) const override {
+    return nullptr; // Always null: no fold is attempted for any ID or operands.
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index b09b80f95871a..e5015c1368cea 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -922,6 +922,8 @@ CallInst *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS,
                                                Value *RHS,
                                                Instruction *FMFSource,
                                                const Twine &Name) {
+  if (Value *V = Folder.FoldBinaryIntrinsics(ID, LHS, RHS))
+    return (CallInst *) V; // TODO: should return value be changed to Value *?
   Module *M = BB->getModule();
   Function *Fn = Intrinsic::getDeclaration(M, ID, { LHS->getType() });
   return createCallHelper(Fn, {LHS, RHS}, Name, FMFSource);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index b8d04322de298..b88f6d060e28e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -14071,16 +14071,8 @@ class HorizontalReduction {
       return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS,
                                  Name);
     case RecurKind::FMax:
-      if (IsConstant)
-        return ConstantFP::get(LHS->getType(),
-                               maxnum(cast<ConstantFP>(LHS)->getValueAPF(),
-                                      cast<ConstantFP>(RHS)->getValueAPF()));
       return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS);
     case RecurKind::FMin:
-      if (IsConstant)
-        return ConstantFP::get(LHS->getType(),
-                               minnum(cast<ConstantFP>(LHS)->getValueAPF(),
-                                      cast<ConstantFP>(RHS)->getValueAPF()));
       return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
     case RecurKind::FMaximum:
       if (IsConstant)

>From 4044d93d4b6333a6d7dad24eff2e1de7a9f9d302 Mon Sep 17 00:00:00 2001
From: Artem Tyurin <artem.tyurin at gmail.com>
Date: Tue, 6 Feb 2024 17:55:46 +0100
Subject: [PATCH 2/2] Reuse simplifyBinaryIntrinsic

---
 llvm/include/llvm/Analysis/InstSimplifyFolder.h  | 3 +--
 llvm/include/llvm/Analysis/InstructionSimplify.h | 5 +++++
 llvm/lib/Analysis/InstructionSimplify.cpp        | 8 ++------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/Analysis/InstSimplifyFolder.h b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
index d8ec313b49467..fbcdf79927d42 100644
--- a/llvm/include/llvm/Analysis/InstSimplifyFolder.h
+++ b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
@@ -119,8 +119,7 @@ class InstSimplifyFolder final : public IRBuilderFolder {
 
   Value *FoldBinaryIntrinsics(Intrinsic::ID ID, Value *LHS,
                               Value *RHS) const override {
-    // TODO: should this be defined?
-    return nullptr;
+    return simplifyBinaryIntrinsic(ID, LHS->getType(), LHS, RHS, SQ, nullptr);
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index a29955a06cf4e..c48812c7923d6 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -186,6 +186,11 @@ Value *simplifyExtractElementInst(Value *Vec, Value *Idx,
 Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
                         const SimplifyQuery &Q);
 
+/// Try to fold binary intrinsic \p IID applied to \p Op0 and \p Op1; else null.
+Value *simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
+                               Value *Op0, Value *Op1, const SimplifyQuery &Q,
+                               const CallBase *Call);
+
 /// Given operands for a ShuffleVectorInst, fold the result or return null.
 /// See class ShuffleVectorInst for a description of the mask representation.
 Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index c10b9d9be101b..ee70b9fc2d0ae 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6379,11 +6379,9 @@ static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
   return nullptr;
 }
 
-static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
+Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType, Value *Op0, Value *Op1,
                                       const SimplifyQuery &Q,
                                       const CallBase *Call) {
-  Intrinsic::ID IID = F->getIntrinsicID();
-  Type *ReturnType = F->getReturnType();
   unsigned BitWidth = ReturnType->getScalarSizeInBits();
   switch (IID) {
   case Intrinsic::abs:
@@ -6667,8 +6665,6 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
     break;
   }
   case Intrinsic::vector_extract: {
-    Type *ReturnType = F->getReturnType();
-
     // (extract_vector (insert_vector _, X, 0), 0) -> X
     unsigned IdxN = cast<ConstantInt>(Op1)->getZExtValue();
     Value *X = nullptr;
@@ -6715,7 +6711,7 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
     return simplifyUnaryIntrinsic(F, Args[0], Q, Call);
 
   if (NumOperands == 2)
-    return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q, Call);
+    return simplifyBinaryIntrinsic(IID, F->getReturnType(), Args[0], Args[1], Q, Call);
 
   // Handle intrinsics with 3 or more arguments.
   switch (IID) {



More information about the llvm-commits mailing list