[llvm] 3f23c7f - [InstSimplify] Actually use NewOps for calls in simplifyInstructionWithOperands

Arthur Eubanks via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 22 09:26:26 PDT 2023


Author: Arthur Eubanks
Date: 2023-03-22T09:22:00-07:00
New Revision: 3f23c7f5bedc8786d3f4567d2331a7efcbb2a77e

URL: https://github.com/llvm/llvm-project/commit/3f23c7f5bedc8786d3f4567d2331a7efcbb2a77e
DIFF: https://github.com/llvm/llvm-project/commit/3f23c7f5bedc8786d3f4567d2331a7efcbb2a77e.diff

LOG: [InstSimplify] Actually use NewOps for calls in simplifyInstructionWithOperands

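Resolves the "TODO: Use NewOps" in the Instruction::Call case of
simplifyInstructionWithOperands: simplifyCall now takes the callee and the
argument list explicitly (operand bundle operands excluded) instead of reading
them from the call instruction.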

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D146599

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/unittests/Transforms/Utils/LocalTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 861fa3b20a495..826bd45d8057b 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -302,8 +302,9 @@ Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
 Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF,
                      const SimplifyQuery &Q);
 
-/// Given a callsite, fold the result or return null.
-Value *simplifyCall(CallBase *Call, const SimplifyQuery &Q);
+/// Given a callsite, callee, and arguments, fold the result or return null.
+Value *simplifyCall(CallBase *Call, Value *Callee, ArrayRef<Value *> Args,
+                    const SimplifyQuery &Q);
 
 /// Given a constrained FP intrinsic call, tries to compute its simplified
 /// version. Returns a simplified result or null.

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index ecb0cdbd13c62..eaf0af92484d7 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6391,10 +6391,13 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
   return nullptr;
 }
 
-static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
-
-  unsigned NumOperands = Call->arg_size();
-  Function *F = cast<Function>(Call->getCalledFunction());
+static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
+                                ArrayRef<Value *> Args,
+                                const SimplifyQuery &Q) {
+  // Operand bundles should not be in Args.
+  assert(Call->arg_size() == Args.size());
+  unsigned NumOperands = Args.size();
+  Function *F = cast<Function>(Callee);
   Intrinsic::ID IID = F->getIntrinsicID();
 
   // Most of the intrinsics with no operands have some kind of side effect.
@@ -6420,18 +6423,17 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
   }
 
   if (NumOperands == 1)
-    return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q);
+    return simplifyUnaryIntrinsic(F, Args[0], Q);
 
   if (NumOperands == 2)
-    return simplifyBinaryIntrinsic(F, Call->getArgOperand(0),
-                                   Call->getArgOperand(1), Q);
+    return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q);
 
   // Handle intrinsics with 3 or more arguments.
   switch (IID) {
   case Intrinsic::masked_load:
   case Intrinsic::masked_gather: {
-    Value *MaskArg = Call->getArgOperand(2);
-    Value *PassthruArg = Call->getArgOperand(3);
+    Value *MaskArg = Args[2];
+    Value *PassthruArg = Args[3];
     // If the mask is all zeros or undef, the "passthru" argument is the result.
     if (maskIsAllZeroOrUndef(MaskArg))
       return PassthruArg;
@@ -6439,8 +6441,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
   }
   case Intrinsic::fshl:
   case Intrinsic::fshr: {
-    Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1),
-          *ShAmtArg = Call->getArgOperand(2);
+    Value *Op0 = Args[0], *Op1 = Args[1], *ShAmtArg = Args[2];
 
     // If both operands are undef, the result is undef.
     if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1))
@@ -6448,14 +6449,14 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
 
     // If shift amount is undef, assume it is zero.
     if (Q.isUndefValue(ShAmtArg))
-      return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);
+      return Args[IID == Intrinsic::fshl ? 0 : 1];
 
     const APInt *ShAmtC;
     if (match(ShAmtArg, m_APInt(ShAmtC))) {
       // If there's effectively no shift, return the 1st arg or 2nd arg.
       APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth());
       if (ShAmtC->urem(BitWidth).isZero())
-        return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);
+        return Args[IID == Intrinsic::fshl ? 0 : 1];
     }
 
     // Rotating zero by anything is zero.
@@ -6469,31 +6470,24 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
     return nullptr;
   }
   case Intrinsic::experimental_constrained_fma: {
-    Value *Op0 = Call->getArgOperand(0);
-    Value *Op1 = Call->getArgOperand(1);
-    Value *Op2 = Call->getArgOperand(2);
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    if (Value *V =
-            simplifyFPOp({Op0, Op1, Op2}, {}, Q, *FPI->getExceptionBehavior(),
-                         *FPI->getRoundingMode()))
+    if (Value *V = simplifyFPOp(Args, {}, Q, *FPI->getExceptionBehavior(),
+                                *FPI->getRoundingMode()))
       return V;
     return nullptr;
   }
   case Intrinsic::fma:
   case Intrinsic::fmuladd: {
-    Value *Op0 = Call->getArgOperand(0);
-    Value *Op1 = Call->getArgOperand(1);
-    Value *Op2 = Call->getArgOperand(2);
-    if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q, fp::ebIgnore,
+    if (Value *V = simplifyFPOp(Args, {}, Q, fp::ebIgnore,
                                 RoundingMode::NearestTiesToEven))
       return V;
     return nullptr;
   }
   case Intrinsic::smul_fix:
   case Intrinsic::smul_fix_sat: {
-    Value *Op0 = Call->getArgOperand(0);
-    Value *Op1 = Call->getArgOperand(1);
-    Value *Op2 = Call->getArgOperand(2);
+    Value *Op0 = Args[0];
+    Value *Op1 = Args[1];
+    Value *Op2 = Args[2];
     Type *ReturnType = F->getReturnType();
 
     // Canonicalize constant operand as Op1 (ConstantFolding handles the case
@@ -6520,9 +6514,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
     return nullptr;
   }
   case Intrinsic::vector_insert: {
-    Value *Vec = Call->getArgOperand(0);
-    Value *SubVec = Call->getArgOperand(1);
-    Value *Idx = Call->getArgOperand(2);
+    Value *Vec = Args[0];
+    Value *SubVec = Args[1];
+    Value *Idx = Args[2];
     Type *ReturnType = F->getReturnType();
 
     // (insert_vector Y, (extract_vector X, 0), 0) -> X
@@ -6539,51 +6533,52 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
   }
   case Intrinsic::experimental_constrained_fadd: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFAddInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFAddInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_fsub: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFSubInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFSubInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_fmul: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFMulInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFMulInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_fdiv: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFDivInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFDivInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   case Intrinsic::experimental_constrained_frem: {
     auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
-    return simplifyFRemInst(
-        FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
-        Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
+    return simplifyFRemInst(Args[0], Args[1], FPI->getFastMathFlags(), Q,
+                            *FPI->getExceptionBehavior(),
+                            *FPI->getRoundingMode());
   }
   default:
     return nullptr;
   }
 }
 
-static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) {
-  auto *F = dyn_cast<Function>(Call->getCalledOperand());
+static Value *tryConstantFoldCall(CallBase *Call, Value *Callee,
+                                  ArrayRef<Value *> Args,
+                                  const SimplifyQuery &Q) {
+  auto *F = dyn_cast<Function>(Callee);
   if (!F || !canConstantFoldCallTo(Call, F))
     return nullptr;
 
   SmallVector<Constant *, 4> ConstantArgs;
-  unsigned NumArgs = Call->arg_size();
-  ConstantArgs.reserve(NumArgs);
-  for (auto &Arg : Call->args()) {
-    Constant *C = dyn_cast<Constant>(&Arg);
+  ConstantArgs.reserve(Args.size());
+  for (Value *Arg : Args) {
+    Constant *C = dyn_cast<Constant>(Arg);
     if (!C) {
-      if (isa<MetadataAsValue>(Arg.get()))
+      if (isa<MetadataAsValue>(Arg))
         continue;
       return nullptr;
     }
@@ -6593,7 +6588,11 @@ static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) {
   return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI);
 }
 
-Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) {
+Value *llvm::simplifyCall(CallBase *Call, Value *Callee, ArrayRef<Value *> Args,
+                          const SimplifyQuery &Q) {
+  // Args should not contain operand bundle operands.
+  assert(Call->arg_size() == Args.size());
+
   // musttail calls can only be simplified if they are also DCEd.
   // As we can't guarantee this here, don't simplify them.
   if (Call->isMustTailCall())
@@ -6601,16 +6600,15 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) {
 
   // call undef -> poison
   // call null -> poison
-  Value *Callee = Call->getCalledOperand();
   if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee))
     return PoisonValue::get(Call->getType());
 
-  if (Value *V = tryConstantFoldCall(Call, Q))
+  if (Value *V = tryConstantFoldCall(Call, Callee, Args, Q))
     return V;
 
   auto *F = dyn_cast<Function>(Callee);
   if (F && F->isIntrinsic())
-    if (Value *Ret = simplifyIntrinsic(Call, Q))
+    if (Value *Ret = simplifyIntrinsic(Call, Callee, Args, Q))
       return Ret;
 
   return nullptr;
@@ -6618,9 +6616,10 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) {
 
 Value *llvm::simplifyConstrainedFPCall(CallBase *Call, const SimplifyQuery &Q) {
   assert(isa<ConstrainedFPIntrinsic>(Call));
-  if (Value *V = tryConstantFoldCall(Call, Q))
+  SmallVector<Value *, 4> Args(Call->args());
+  if (Value *V = tryConstantFoldCall(Call, Call->getCalledOperand(), Args, Q))
     return V;
-  if (Value *Ret = simplifyIntrinsic(Call, Q))
+  if (Value *Ret = simplifyIntrinsic(Call, Call->getCalledOperand(), Args, Q))
     return Ret;
   return nullptr;
 }
@@ -6775,8 +6774,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
   case Instruction::PHI:
     return simplifyPHINode(cast<PHINode>(I), NewOps, Q);
   case Instruction::Call:
-    // TODO: Use NewOps
-    return simplifyCall(cast<CallInst>(I), Q);
+    return simplifyCall(
+        cast<CallInst>(I), NewOps.back(),
+        NewOps.drop_back(1 + cast<CallInst>(I)->getNumTotalBundleOperands()), Q);
   case Instruction::Freeze:
     return llvm::simplifyFreezeInst(NewOps[0], Q);
 #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc:

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 2b61b58dbc36a..0fbd62e8a41c0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1288,9 +1288,15 @@ foldShuffledIntrinsicOperands(IntrinsicInst *II,
 Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   // Don't try to simplify calls without uses. It will not do anything useful,
   // but will result in the following folds being skipped.
-  if (!CI.use_empty())
-    if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI)))
+  if (!CI.use_empty()) {
+    SmallVector<Value *, 4> Args;
+    Args.reserve(CI.arg_size());
+    for (Value *Op : CI.args())
+      Args.push_back(Op);
+    if (Value *V = simplifyCall(&CI, CI.getCalledOperand(), Args,
+                                SQ.getWithInstruction(&CI)))
       return replaceInstUsesWith(CI, V);
+  }
 
   if (Value *FreedOp = getFreedOperand(&CI, &TLI))
     return visitFree(CI, FreedOp);

diff  --git a/llvm/unittests/Transforms/Utils/LocalTest.cpp b/llvm/unittests/Transforms/Utils/LocalTest.cpp
index d6b09b35f2caf..443f1f09915fd 100644
--- a/llvm/unittests/Transforms/Utils/LocalTest.cpp
+++ b/llvm/unittests/Transforms/Utils/LocalTest.cpp
@@ -598,7 +598,8 @@ TEST(Local, SimplifyVScaleWithRange) {
 
   // Test that simplifyCall won't try to query it's parent function for
   // vscale_range attributes in order to simplify llvm.vscale -> constant.
-  EXPECT_EQ(simplifyCall(CI, SimplifyQuery(M.getDataLayout())), nullptr);
+  EXPECT_EQ(simplifyCall(CI, VScale, {}, SimplifyQuery(M.getDataLayout())),
+            nullptr);
   delete CI;
 }
 


        


More information about the llvm-commits mailing list