[llvm] 2595931 - [IndVars] Support shl by constant and or disjoint in getExtendedOperandRecurrence. (#84282)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 12:52:15 PDT 2024


Author: Craig Topper
Date: 2024-03-14T12:52:12-07:00
New Revision: 25959310a50240c320283d1e4b23046fed152313

URL: https://github.com/llvm/llvm-project/commit/25959310a50240c320283d1e4b23046fed152313
DIFF: https://github.com/llvm/llvm-project/commit/25959310a50240c320283d1e4b23046fed152313.diff

LOG: [IndVars] Support shl by constant and or disjoint in getExtendedOperandRecurrence. (#84282)

We can treat a shift by constant as a multiply by a power of 2,
and we can treat an 'or disjoint' as an 'add nsw nuw'.

I've added a helper struct, similar to a struct used in
ScalarEvolution.cpp, to represent the opcode, operands, and NSW/NUW
flags for normal add/sub/mul and for shl/or that are being treated as
mul/add.

I don't think we need to teach cloneIVUser about this. It will continue
to clone them using cloneBitwiseIVUser. After the cloning we will ask
for the SCEV expression for the cloned IV user and verify that it
matches the AddRec returned by getExtendedOperandRecurrence. Since SCEV
also knows how to convert shl to mul and or disjoint to add nsw nuw,
this should usually match. If it doesn't match, the cloned IV user will
be deleted.

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
    llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index b8fa985fa3462e..f6d72265305d7d 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -1381,6 +1381,77 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
   };
 }
 
+namespace {
+
+// Represents an interesting integer binary operation for
+// getExtendedOperandRecurrence. This may be a shl that is being treated as a
+// multiply or an 'or disjoint' that is being treated as 'add nsw nuw'.
+struct BinaryOp {
+  unsigned Opcode;
+  std::array<Value *, 2> Operands;
+  bool IsNSW = false;
+  bool IsNUW = false;
+
+  explicit BinaryOp(Instruction *Op)
+      : Opcode(Op->getOpcode()),
+        Operands({Op->getOperand(0), Op->getOperand(1)}) {
+    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
+      IsNSW = OBO->hasNoSignedWrap();
+      IsNUW = OBO->hasNoUnsignedWrap();
+    }
+  }
+
+  explicit BinaryOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
+                    bool IsNSW = false, bool IsNUW = false)
+      : Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {}
+};
+
+} // end anonymous namespace
+
+static std::optional<BinaryOp> matchBinaryOp(Instruction *Op) {
+  switch (Op->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+    return BinaryOp(Op);
+  case Instruction::Or: {
+    // Convert or disjoint into add nuw nsw.
+    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
+      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
+                      /*IsNSW=*/true, /*IsNUW=*/true);
+    break;
+  }
+  case Instruction::Shl: {
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+      unsigned BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
+
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (SA->getValue().ult(BitWidth)) {
+        // We can safely preserve the nuw flag in all cases. It's also safe to
+        // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
+        // requires special handling. It can be preserved as long as we're not
+        // left shifting by bitwidth - 1.
+        bool IsNUW = Op->hasNoUnsignedWrap();
+        bool IsNSW = Op->hasNoSignedWrap() &&
+                     (IsNUW || SA->getValue().ult(BitWidth - 1));
+
+        ConstantInt *X =
+            ConstantInt::get(Op->getContext(),
+                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return BinaryOp(Instruction::Mul, Op->getOperand(0), X, IsNSW, IsNUW);
+      }
+    }
+
+    break;
+  }
+  }
+
+  return std::nullopt;
+}
+
 /// No-wrap operations can transfer sign extension of their result to their
 /// operands. Generate the SCEV value for the widened operation without
 /// actually modifying the IR yet. If the expression after extending the
@@ -1388,24 +1459,22 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
 /// extension used.
 WidenIV::WidenedRecTy
 WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
-  // Handle the common case of add<nsw/nuw>
-  const unsigned OpCode = DU.NarrowUse->getOpcode();
-  // Only Add/Sub/Mul instructions supported yet.
-  if (OpCode != Instruction::Add && OpCode != Instruction::Sub &&
-      OpCode != Instruction::Mul)
+  auto Op = matchBinaryOp(DU.NarrowUse);
+  if (!Op)
     return {nullptr, ExtendKind::Unknown};
 
+  assert((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
+          Op->Opcode == Instruction::Mul) &&
+         "Unexpected opcode");
+
   // One operand (NarrowDef) has already been extended to WideDef. Now determine
   // if extending the other will lead to a recurrence.
-  const unsigned ExtendOperIdx =
-      DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0;
-  assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU");
+  const unsigned ExtendOperIdx = Op->Operands[0] == DU.NarrowDef ? 1 : 0;
+  assert(Op->Operands[1 - ExtendOperIdx] == DU.NarrowDef && "bad DU");
 
-  const OverflowingBinaryOperator *OBO =
-    cast<OverflowingBinaryOperator>(DU.NarrowUse);
   ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
-  if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) &&
-      !(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) {
+  if (!(ExtKind == ExtendKind::Sign && Op->IsNSW) &&
+      !(ExtKind == ExtendKind::Zero && Op->IsNUW)) {
     ExtKind = ExtendKind::Unknown;
 
     // For a non-negative NarrowDef, we can choose either type of
@@ -1413,16 +1482,15 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
     // (see above), and we only hit this code if we need to check
     // the opposite case.
     if (DU.NeverNegative) {
-      if (OBO->hasNoSignedWrap()) {
+      if (Op->IsNSW) {
         ExtKind = ExtendKind::Sign;
-      } else if (OBO->hasNoUnsignedWrap()) {
+      } else if (Op->IsNUW) {
         ExtKind = ExtendKind::Zero;
       }
     }
   }
 
-  const SCEV *ExtendOperExpr =
-      SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx));
+  const SCEV *ExtendOperExpr = SE->getSCEV(Op->Operands[ExtendOperIdx]);
   if (ExtKind == ExtendKind::Sign)
     ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType);
   else if (ExtKind == ExtendKind::Zero)
@@ -1443,7 +1511,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
   if (ExtendOperIdx == 0)
     std::swap(lhs, rhs);
   const SCEVAddRecExpr *AddRec =
-      dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode));
+      dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, Op->Opcode));
 
   if (!AddRec || AddRec->getLoop() != L)
     return {nullptr, ExtendKind::Unknown};

diff --git a/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll b/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
index 0e21bf8ba347a8..59a0241bfe9fde 100644
--- a/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
+++ b/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
@@ -493,3 +493,55 @@ for.body:                                         ; preds = %for.body.lr.ph, %fo
   %cmp = icmp ult i32 %add, %length
   br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
 }
+
+; Test that we can handle shl and disjoint or in getExtendedOperandRecurrence.
+define void @foo7(i32 %n, ptr %a, i32 %x) {
+; CHECK-LABEL: @foo7(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP6:%.*]] = icmp sgt i32 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP6]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[ADD1:%.*]] = add nsw i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[ADD1]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[N]] to i64
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[FOR_BODY_LR_PH]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = or disjoint i64 [[TMP2]], 1
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = trunc i64 [[INDVARS_IV]] to i32
+; CHECK-NEXT:    store i32 [[TMP4]], ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nsw i64 [[INDVARS_IV]], [[TMP0]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV_NEXT]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]]
+;
+entry:
+  %cmp6 = icmp sgt i32 %n, 0
+  br i1 %cmp6, label %for.body.lr.ph, label %for.cond.cleanup
+
+for.body.lr.ph:                                   ; preds = %entry
+  %add1 = add nsw i32 %x, 2
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.body
+  %i.07 = phi i32 [ 0, %for.body.lr.ph ], [ %add2, %for.body ]
+  %mul = shl nsw i32 %i.07, 1
+  %add = or disjoint i32 %mul, 1
+  %idxprom = sext i32 %add to i64
+  %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
+  store i32 %i.07, ptr %arrayidx, align 4
+  %add2 = add nsw i32 %add1, %i.07
+  %cmp = icmp slt i32 %add2, %n
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}


        


More information about the llvm-commits mailing list