[llvm] 1411452 - [IRBuilder] Fold binary intrinsics (#80743)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 15 01:58:29 PDT 2024


Author: Artem Tyurin
Date: 2024-03-15T09:58:25+01:00
New Revision: 141145232f915b44aef6e3854f091da03c41a2b6

URL: https://github.com/llvm/llvm-project/commit/141145232f915b44aef6e3854f091da03c41a2b6
DIFF: https://github.com/llvm/llvm-project/commit/141145232f915b44aef6e3854f091da03c41a2b6.diff

LOG: [IRBuilder] Fold binary intrinsics (#80743)

Fixes https://github.com/llvm/llvm-project/issues/61240.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ConstantFolding.h
    llvm/include/llvm/Analysis/InstSimplifyFolder.h
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/include/llvm/Analysis/TargetFolder.h
    llvm/include/llvm/IR/ConstantFolder.h
    llvm/include/llvm/IR/IRBuilder.h
    llvm/include/llvm/IR/IRBuilderFolder.h
    llvm/include/llvm/IR/NoFolder.h
    llvm/lib/Analysis/ConstantFolding.cpp
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/IR/IRBuilder.cpp
    llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/unittests/IR/IRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ConstantFolding.h b/llvm/include/llvm/Analysis/ConstantFolding.h
index fd885a8c0ea3e5..c54b1e8f01d2b6 100644
--- a/llvm/include/llvm/Analysis/ConstantFolding.h
+++ b/llvm/include/llvm/Analysis/ConstantFolding.h
@@ -22,6 +22,11 @@
 #include <stdint.h>
 
 namespace llvm {
+
+namespace Intrinsic {
+using ID = unsigned;
+}
+
 class APInt;
 template <typename T> class ArrayRef;
 class CallBase;
@@ -187,6 +192,10 @@ Constant *ConstantFoldCall(const CallBase *Call, Function *F,
                            ArrayRef<Constant *> Operands,
                            const TargetLibraryInfo *TLI = nullptr);
 
+Constant *ConstantFoldBinaryIntrinsic(Intrinsic::ID ID, Constant *LHS,
+                                      Constant *RHS, Type *Ty,
+                                      Instruction *FMFSource);
+
 /// ConstantFoldLoadThroughBitcast - try to cast constant to destination type
 /// returning null if unsuccessful. Can cast pointer to pointer or pointer to
 /// integer and vice versa if their sizes are equal.

diff  --git a/llvm/include/llvm/Analysis/InstSimplifyFolder.h b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
index 23e2ea80e8cbe6..8a3269d6add0e0 100644
--- a/llvm/include/llvm/Analysis/InstSimplifyFolder.h
+++ b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
@@ -117,6 +117,12 @@ class InstSimplifyFolder final : public IRBuilderFolder {
     return simplifyCastInst(Op, V, DestTy, SQ);
   }
 
+  Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
+                             Instruction *FMFSource) const override {
+    return simplifyBinaryIntrinsic(ID, Ty, LHS, RHS, SQ,
+                                   dyn_cast_if_present<CallBase>(FMFSource));
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index a29955a06cf4e0..03d7ad12c12d8f 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -186,6 +186,11 @@ Value *simplifyExtractElementInst(Value *Vec, Value *Idx,
 Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
                         const SimplifyQuery &Q);
 
+/// Given operands for a BinaryIntrinsic, fold the result or return null.
+Value *simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType, Value *Op0,
+                               Value *Op1, const SimplifyQuery &Q,
+                               const CallBase *Call);
+
 /// Given operands for a ShuffleVectorInst, fold the result or return null.
 /// See class ShuffleVectorInst for a description of the mask representation.
 Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,

diff  --git a/llvm/include/llvm/Analysis/TargetFolder.h b/llvm/include/llvm/Analysis/TargetFolder.h
index 978e1002515fc0..b4105ad76c02e2 100644
--- a/llvm/include/llvm/Analysis/TargetFolder.h
+++ b/llvm/include/llvm/Analysis/TargetFolder.h
@@ -191,6 +191,15 @@ class TargetFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
+  Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
+                             Instruction *FMFSource) const override {
+    auto *C1 = dyn_cast<Constant>(LHS);
+    auto *C2 = dyn_cast<Constant>(RHS);
+    if (C1 && C2)
+      return ConstantFoldBinaryIntrinsic(ID, C1, C2, Ty, FMFSource);
+    return nullptr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/ConstantFolder.h b/llvm/include/llvm/IR/ConstantFolder.h
index c2b30a65e32e25..3e74a563a58421 100644
--- a/llvm/include/llvm/IR/ConstantFolder.h
+++ b/llvm/include/llvm/IR/ConstantFolder.h
@@ -18,8 +18,8 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Constants.h"
 #include "llvm/IR/ConstantFold.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilderFolder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Operator.h"
@@ -89,7 +89,7 @@ class ConstantFolder final : public IRBuilderFolder {
   }
 
   Value *FoldUnOpFMF(Instruction::UnaryOps Opc, Value *V,
-                      FastMathFlags FMF) const override {
+                     FastMathFlags FMF) const override {
     if (Constant *C = dyn_cast<Constant>(V))
       return ConstantFoldUnaryInstruction(Opc, C);
     return nullptr;
@@ -183,6 +183,12 @@ class ConstantFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
+  Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
+                             Instruction *FMFSource) const override {
+    // Use TargetFolder or InstSimplifyFolder instead.
+    return nullptr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index f2922311097e9b..c07ffea7115115 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -962,9 +962,9 @@ class IRBuilderBase {
 
   /// Create a call to intrinsic \p ID with 2 operands which is mangled on the
   /// first type.
-  CallInst *CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS,
-                                  Instruction *FMFSource = nullptr,
-                                  const Twine &Name = "");
+  Value *CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS,
+                               Instruction *FMFSource = nullptr,
+                               const Twine &Name = "");
 
   /// Create a call to intrinsic \p ID with \p Args, mangled using \p Types. If
   /// \p FMFSource is provided, copy fast-math-flags from that instruction to
@@ -983,7 +983,7 @@ class IRBuilderBase {
                             const Twine &Name = "");
 
   /// Create call to the minnum intrinsic.
-  CallInst *CreateMinNum(Value *LHS, Value *RHS, const Twine &Name = "") {
+  Value *CreateMinNum(Value *LHS, Value *RHS, const Twine &Name = "") {
     if (IsFPConstrained) {
       return CreateConstrainedFPUnroundedBinOp(
           Intrinsic::experimental_constrained_minnum, LHS, RHS, nullptr, Name);
@@ -993,7 +993,7 @@ class IRBuilderBase {
   }
 
   /// Create call to the maxnum intrinsic.
-  CallInst *CreateMaxNum(Value *LHS, Value *RHS, const Twine &Name = "") {
+  Value *CreateMaxNum(Value *LHS, Value *RHS, const Twine &Name = "") {
     if (IsFPConstrained) {
       return CreateConstrainedFPUnroundedBinOp(
           Intrinsic::experimental_constrained_maxnum, LHS, RHS, nullptr, Name);
@@ -1003,19 +1003,19 @@ class IRBuilderBase {
   }
 
   /// Create call to the minimum intrinsic.
-  CallInst *CreateMinimum(Value *LHS, Value *RHS, const Twine &Name = "") {
+  Value *CreateMinimum(Value *LHS, Value *RHS, const Twine &Name = "") {
     return CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS, nullptr, Name);
   }
 
   /// Create call to the maximum intrinsic.
-  CallInst *CreateMaximum(Value *LHS, Value *RHS, const Twine &Name = "") {
+  Value *CreateMaximum(Value *LHS, Value *RHS, const Twine &Name = "") {
     return CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS, nullptr, Name);
   }
 
   /// Create call to the copysign intrinsic.
-  CallInst *CreateCopySign(Value *LHS, Value *RHS,
-                           Instruction *FMFSource = nullptr,
-                           const Twine &Name = "") {
+  Value *CreateCopySign(Value *LHS, Value *RHS,
+                        Instruction *FMFSource = nullptr,
+                        const Twine &Name = "") {
     return CreateBinaryIntrinsic(Intrinsic::copysign, LHS, RHS, FMFSource,
                                  Name);
   }

diff  --git a/llvm/include/llvm/IR/IRBuilderFolder.h b/llvm/include/llvm/IR/IRBuilderFolder.h
index bd2324dfc5f1ba..3020f2684ee457 100644
--- a/llvm/include/llvm/IR/IRBuilderFolder.h
+++ b/llvm/include/llvm/IR/IRBuilderFolder.h
@@ -73,6 +73,10 @@ class IRBuilderFolder {
   virtual Value *FoldCast(Instruction::CastOps Op, Value *V,
                           Type *DestTy) const = 0;
 
+  virtual Value *
+  FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
+                      Instruction *FMFSource = nullptr) const = 0;
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/NoFolder.h b/llvm/include/llvm/IR/NoFolder.h
index a612f98465aeaa..7bb5d5e696e9ea 100644
--- a/llvm/include/llvm/IR/NoFolder.h
+++ b/llvm/include/llvm/IR/NoFolder.h
@@ -112,6 +112,11 @@ class NoFolder final : public IRBuilderFolder {
     return nullptr;
   }
 
+  Value *FoldBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, Type *Ty,
+                             Instruction *FMFSource) const override {
+    return nullptr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Cast/Conversion Operators
   //===--------------------------------------------------------------------===//

diff  --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 8b7031e7fe4a6f..6b2e88d06b18c7 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -2534,12 +2534,73 @@ static Constant *evaluateCompare(const APFloat &Op1, const APFloat &Op2,
   return nullptr;
 }
 
-static Constant *ConstantFoldScalarCall2(StringRef Name,
-                                         Intrinsic::ID IntrinsicID,
-                                         Type *Ty,
-                                         ArrayRef<Constant *> Operands,
-                                         const TargetLibraryInfo *TLI,
-                                         const CallBase *Call) {
+static Constant *ConstantFoldLibCall2(StringRef Name, Type *Ty,
+                                      ArrayRef<Constant *> Operands,
+                                      const TargetLibraryInfo *TLI) {
+  if (!TLI)
+    return nullptr;
+
+  LibFunc Func = NotLibFunc;
+  if (!TLI->getLibFunc(Name, Func))
+    return nullptr;
+
+  const auto *Op1 = dyn_cast<ConstantFP>(Operands[0]);
+  if (!Op1)
+    return nullptr;
+
+  const auto *Op2 = dyn_cast<ConstantFP>(Operands[1]);
+  if (!Op2)
+    return nullptr;
+
+  const APFloat &Op1V = Op1->getValueAPF();
+  const APFloat &Op2V = Op2->getValueAPF();
+
+  switch (Func) {
+  default:
+    break;
+  case LibFunc_pow:
+  case LibFunc_powf:
+  case LibFunc_pow_finite:
+  case LibFunc_powf_finite:
+    if (TLI->has(Func))
+      return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
+    break;
+  case LibFunc_fmod:
+  case LibFunc_fmodf:
+    if (TLI->has(Func)) {
+      APFloat V = Op1->getValueAPF();
+      if (APFloat::opStatus::opOK == V.mod(Op2->getValueAPF()))
+        return ConstantFP::get(Ty->getContext(), V);
+    }
+    break;
+  case LibFunc_remainder:
+  case LibFunc_remainderf:
+    if (TLI->has(Func)) {
+      APFloat V = Op1->getValueAPF();
+      if (APFloat::opStatus::opOK == V.remainder(Op2->getValueAPF()))
+        return ConstantFP::get(Ty->getContext(), V);
+    }
+    break;
+  case LibFunc_atan2:
+  case LibFunc_atan2f:
+    // atan2(+/-0.0, +/-0.0) is known to raise an exception on some libm
+    // (Solaris), so we do not assume a known result for that.
+    if (Op1V.isZero() && Op2V.isZero())
+      return nullptr;
+    [[fallthrough]];
+  case LibFunc_atan2_finite:
+  case LibFunc_atan2f_finite:
+    if (TLI->has(Func))
+      return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
+    break;
+  }
+
+  return nullptr;
+}
+
+static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
+                                            ArrayRef<Constant *> Operands,
+                                            const CallBase *Call) {
   assert(Operands.size() == 2 && "Wrong number of operands.");
 
   if (Ty->isFloatingPointTy()) {
@@ -2569,7 +2630,8 @@ static Constant *ConstantFoldScalarCall2(StringRef Name,
         return nullptr;
       const APFloat &Op2V = Op2->getValueAPF();
 
-      if (const auto *ConstrIntr = dyn_cast<ConstrainedFPIntrinsic>(Call)) {
+      if (const auto *ConstrIntr =
+              dyn_cast_if_present<ConstrainedFPIntrinsic>(Call)) {
         RoundingMode RM = getEvaluationRoundingMode(ConstrIntr);
         APFloat Res = Op1V;
         APFloat::opStatus St;
@@ -2632,52 +2694,6 @@ static Constant *ConstantFoldScalarCall2(StringRef Name,
         return ConstantFP::get(Ty->getContext(), Op1V * Op2V);
       }
 
-      if (!TLI)
-        return nullptr;
-
-      LibFunc Func = NotLibFunc;
-      if (!TLI->getLibFunc(Name, Func))
-        return nullptr;
-
-      switch (Func) {
-      default:
-        break;
-      case LibFunc_pow:
-      case LibFunc_powf:
-      case LibFunc_pow_finite:
-      case LibFunc_powf_finite:
-        if (TLI->has(Func))
-          return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
-        break;
-      case LibFunc_fmod:
-      case LibFunc_fmodf:
-        if (TLI->has(Func)) {
-          APFloat V = Op1->getValueAPF();
-          if (APFloat::opStatus::opOK == V.mod(Op2->getValueAPF()))
-            return ConstantFP::get(Ty->getContext(), V);
-        }
-        break;
-      case LibFunc_remainder:
-      case LibFunc_remainderf:
-        if (TLI->has(Func)) {
-          APFloat V = Op1->getValueAPF();
-          if (APFloat::opStatus::opOK == V.remainder(Op2->getValueAPF()))
-            return ConstantFP::get(Ty->getContext(), V);
-        }
-        break;
-      case LibFunc_atan2:
-      case LibFunc_atan2f:
-        // atan2(+/-0.0, +/-0.0) is known to raise an exception on some libm
-        // (Solaris), so we do not assume a known result for that.
-        if (Op1V.isZero() && Op2V.isZero())
-          return nullptr;
-        [[fallthrough]];
-      case LibFunc_atan2_finite:
-      case LibFunc_atan2f_finite:
-        if (TLI->has(Func))
-          return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
-        break;
-      }
     } else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
       switch (IntrinsicID) {
       case Intrinsic::ldexp: {
@@ -3168,8 +3184,13 @@ static Constant *ConstantFoldScalarCall(StringRef Name,
   if (Operands.size() == 1)
     return ConstantFoldScalarCall1(Name, IntrinsicID, Ty, Operands, TLI, Call);
 
-  if (Operands.size() == 2)
-    return ConstantFoldScalarCall2(Name, IntrinsicID, Ty, Operands, TLI, Call);
+  if (Operands.size() == 2) {
+    if (Constant *FoldedLibCall =
+            ConstantFoldLibCall2(Name, Ty, Operands, TLI)) {
+      return FoldedLibCall;
+    }
+    return ConstantFoldIntrinsicCall2(IntrinsicID, Ty, Operands, Call);
+  }
 
   if (Operands.size() == 3)
     return ConstantFoldScalarCall3(Name, IntrinsicID, Ty, Operands, TLI, Call);
@@ -3376,6 +3397,13 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID,
 
 } // end anonymous namespace
 
+Constant *llvm::ConstantFoldBinaryIntrinsic(Intrinsic::ID ID, Constant *LHS,
+                                            Constant *RHS, Type *Ty,
+                                            Instruction *FMFSource) {
+  return ConstantFoldIntrinsicCall2(ID, Ty, {LHS, RHS},
+                                    dyn_cast_if_present<CallBase>(FMFSource));
+}
+
 Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F,
                                  ArrayRef<Constant *> Operands,
                                  const TargetLibraryInfo *TLI) {

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index ce651783caf16b..d14897a86314e1 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6384,11 +6384,10 @@ static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
   return nullptr;
 }
 
-static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
-                                      const SimplifyQuery &Q,
-                                      const CallBase *Call) {
-  Intrinsic::ID IID = F->getIntrinsicID();
-  Type *ReturnType = F->getReturnType();
+Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
+                                     Value *Op0, Value *Op1,
+                                     const SimplifyQuery &Q,
+                                     const CallBase *Call) {
   unsigned BitWidth = ReturnType->getScalarSizeInBits();
   switch (IID) {
   case Intrinsic::abs:
@@ -6646,19 +6645,21 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
     // float, if the ninf flag is set.
     const APFloat *C;
     if (match(Op1, m_APFloat(C)) &&
-        (C->isInfinity() || (Call->hasNoInfs() && C->isLargest()))) {
+        (C->isInfinity() || (Call && Call->hasNoInfs() && C->isLargest()))) {
       // minnum(X, -inf) -> -inf
       // maxnum(X, +inf) -> +inf
       // minimum(X, -inf) -> -inf if nnan
       // maximum(X, +inf) -> +inf if nnan
-      if (C->isNegative() == IsMin && (!PropagateNaN || Call->hasNoNaNs()))
+      if (C->isNegative() == IsMin &&
+          (!PropagateNaN || (Call && Call->hasNoNaNs())))
         return ConstantFP::get(ReturnType, *C);
 
       // minnum(X, +inf) -> X if nnan
       // maxnum(X, -inf) -> X if nnan
       // minimum(X, +inf) -> X
       // maximum(X, -inf) -> X
-      if (C->isNegative() != IsMin && (PropagateNaN || Call->hasNoNaNs()))
+      if (C->isNegative() != IsMin &&
+          (PropagateNaN || (Call && Call->hasNoNaNs())))
         return Op0;
     }
 
@@ -6672,8 +6673,6 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
     break;
   }
   case Intrinsic::vector_extract: {
-    Type *ReturnType = F->getReturnType();
-
     // (extract_vector (insert_vector _, X, 0), 0) -> X
     unsigned IdxN = cast<ConstantInt>(Op1)->getZExtValue();
     Value *X = nullptr;
@@ -6720,7 +6719,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
     return simplifyUnaryIntrinsic(F, Args[0], Q, Call);
 
   if (NumOperands == 2)
-    return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q, Call);
+    return simplifyBinaryIntrinsic(IID, F->getReturnType(), Args[0], Args[1], Q,
+                                   Call);
 
   // Handle intrinsics with 3 or more arguments.
   switch (IID) {

diff  --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index b09b80f95871a1..d6746d1d438242 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -918,12 +918,14 @@ CallInst *IRBuilderBase::CreateUnaryIntrinsic(Intrinsic::ID ID, Value *V,
   return createCallHelper(Fn, {V}, Name, FMFSource);
 }
 
-CallInst *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS,
-                                               Value *RHS,
-                                               Instruction *FMFSource,
-                                               const Twine &Name) {
+Value *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS,
+                                            Value *RHS, Instruction *FMFSource,
+                                            const Twine &Name) {
   Module *M = BB->getModule();
   Function *Fn = Intrinsic::getDeclaration(M, ID, { LHS->getType() });
+  if (Value *V = Folder.FoldBinaryIntrinsic(ID, LHS, RHS, Fn->getReturnType(),
+                                            FMFSource))
+    return V;
   return createCallHelper(Fn, {LHS, RHS}, Name, FMFSource);
 }
 

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index fb829fab0a2c19..5b7fa13f2e835f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -767,19 +767,21 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     // Checking for NaN before canonicalization provides better fidelity when
     // mapping other operations onto fmed3 since the order of operands is
     // unchanged.
-    CallInst *NewCall = nullptr;
+    Value *V = nullptr;
     if (match(Src0, PatternMatch::m_NaN()) || isa<UndefValue>(Src0)) {
-      NewCall = IC.Builder.CreateMinNum(Src1, Src2);
+      V = IC.Builder.CreateMinNum(Src1, Src2);
     } else if (match(Src1, PatternMatch::m_NaN()) || isa<UndefValue>(Src1)) {
-      NewCall = IC.Builder.CreateMinNum(Src0, Src2);
+      V = IC.Builder.CreateMinNum(Src0, Src2);
     } else if (match(Src2, PatternMatch::m_NaN()) || isa<UndefValue>(Src2)) {
-      NewCall = IC.Builder.CreateMaxNum(Src0, Src1);
+      V = IC.Builder.CreateMaxNum(Src0, Src1);
     }
 
-    if (NewCall) {
-      NewCall->copyFastMathFlags(&II);
-      NewCall->takeName(&II);
-      return IC.replaceInstUsesWith(II, NewCall);
+    if (V) {
+      if (auto *CI = dyn_cast<CallInst>(V)) {
+        CI->copyFastMathFlags(&II);
+        CI->takeName(&II);
+      }
+      return IC.replaceInstUsesWith(II, V);
     }
 
     bool Swap = false;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 694b18017babc5..8537dbc6fe531b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2289,13 +2289,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         default:
           llvm_unreachable("unexpected intrinsic ID");
         }
-        Instruction *NewCall = Builder.CreateBinaryIntrinsic(
+        Value *V = Builder.CreateBinaryIntrinsic(
             IID, X, ConstantFP::get(Arg0->getType(), Res), II);
         // TODO: Conservatively intersecting FMF. If Res == C2, the transform
         //       was a simplification (so Arg0 and its original flags could
         //       propagate?)
-        NewCall->andIRFlags(M);
-        return replaceInstUsesWith(*II, NewCall);
+        if (auto *CI = dyn_cast<CallInst>(V))
+          CI->andIRFlags(M);
+        return replaceInstUsesWith(*II, V);
       }
     }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aee18d770f729d..ee76a6294428b3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1191,7 +1191,7 @@ static Value *canonicalizeSPF(ICmpInst &Cmp, Value *TrueVal, Value *FalseVal,
                           match(RHS, m_NSWNeg(m_Specific(LHS)));
     Constant *IntMinIsPoisonC =
         ConstantInt::get(Type::getInt1Ty(Cmp.getContext()), IntMinIsPoison);
-    Instruction *Abs =
+    Value *Abs =
         IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC);
 
     if (SPF == SelectPatternFlavor::SPF_NABS)

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index acb738ef281e0d..57cb0b1a2e6d04 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -15078,13 +15078,13 @@ class HorizontalReduction {
       if (IsConstant)
         return ConstantFP::get(LHS->getType(),
                                maximum(cast<ConstantFP>(LHS)->getValueAPF(),
-                                      cast<ConstantFP>(RHS)->getValueAPF()));
+                                       cast<ConstantFP>(RHS)->getValueAPF()));
       return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS);
     case RecurKind::FMinimum:
       if (IsConstant)
         return ConstantFP::get(LHS->getType(),
                                minimum(cast<ConstantFP>(LHS)->getValueAPF(),
-                                      cast<ConstantFP>(RHS)->getValueAPF()));
+                                       cast<ConstantFP>(RHS)->getValueAPF()));
       return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS);
     case RecurKind::SMax:
       if (IsConstant || UseSelect) {

diff  --git a/llvm/unittests/IR/IRBuilderTest.cpp b/llvm/unittests/IR/IRBuilderTest.cpp
index 139e8832c97b91..77bd93569ea93b 100644
--- a/llvm/unittests/IR/IRBuilderTest.cpp
+++ b/llvm/unittests/IR/IRBuilderTest.cpp
@@ -57,7 +57,7 @@ TEST_F(IRBuilderTest, Intrinsics) {
   IRBuilder<> Builder(BB);
   Value *V;
   Instruction *I;
-  CallInst *Call;
+  Value *Result;
   IntrinsicInst *II;
 
   V = Builder.CreateLoad(GV->getValueType(), GV);
@@ -65,78 +65,80 @@ TEST_F(IRBuilderTest, Intrinsics) {
   I->setHasNoInfs(true);
   I->setHasNoNaNs(false);
 
-  Call = Builder.CreateMinNum(V, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateMinNum(V, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::minnum);
 
-  Call = Builder.CreateMaxNum(V, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateMaxNum(V, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::maxnum);
 
-  Call = Builder.CreateMinimum(V, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateMinimum(V, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::minimum);
 
-  Call = Builder.CreateMaximum(V, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateMaximum(V, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::maximum);
 
-  Call = Builder.CreateIntrinsic(Intrinsic::readcyclecounter, {}, {});
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateIntrinsic(Intrinsic::readcyclecounter, {}, {});
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::readcyclecounter);
 
-  Call = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::fabs);
   EXPECT_FALSE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, V, I);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, V, I);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::fabs);
   EXPECT_TRUE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateBinaryIntrinsic(Intrinsic::pow, V, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateBinaryIntrinsic(Intrinsic::pow, V, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::pow);
   EXPECT_FALSE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateBinaryIntrinsic(Intrinsic::pow, V, V, I);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateBinaryIntrinsic(Intrinsic::pow, V, V, I);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::pow);
   EXPECT_TRUE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateIntrinsic(Intrinsic::fma, {V->getType()}, {V, V, V});
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateIntrinsic(Intrinsic::fma, {V->getType()}, {V, V, V});
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::fma);
   EXPECT_FALSE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateIntrinsic(Intrinsic::fma, {V->getType()}, {V, V, V}, I);
-  II = cast<IntrinsicInst>(Call);
+  Result =
+      Builder.CreateIntrinsic(Intrinsic::fma, {V->getType()}, {V, V, V}, I);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::fma);
   EXPECT_TRUE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateIntrinsic(Intrinsic::fma, {V->getType()}, {V, V, V}, I);
-  II = cast<IntrinsicInst>(Call);
+  Result =
+      Builder.CreateIntrinsic(Intrinsic::fma, {V->getType()}, {V, V, V}, I);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::fma);
   EXPECT_TRUE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateUnaryIntrinsic(Intrinsic::roundeven, V);
-  II = cast<IntrinsicInst>(Call);
+  Result = Builder.CreateUnaryIntrinsic(Intrinsic::roundeven, V);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::roundeven);
   EXPECT_FALSE(II->hasNoInfs());
   EXPECT_FALSE(II->hasNoNaNs());
 
-  Call = Builder.CreateIntrinsic(
+  Result = Builder.CreateIntrinsic(
       Intrinsic::set_rounding, {},
       {Builder.getInt32(static_cast<uint32_t>(RoundingMode::TowardZero))});
-  II = cast<IntrinsicInst>(Call);
+  II = cast<IntrinsicInst>(Result);
   EXPECT_EQ(II->getIntrinsicID(), Intrinsic::set_rounding);
 }
 


        


More information about the llvm-commits mailing list