[llvm] [IndVars] Support shl by constant and or disjoint in getExtendedOperandRecurrence. (PR #84282)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 22:51:58 PST 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/84282

>From 64a8bd479c655ef0f9b8bd61b1aaa4373fb39645 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 6 Mar 2024 22:18:05 -0800
Subject: [PATCH 1/3] [IndVars] Add test for missed support for shl by constant
 and or disjoint in getExtendedOperandRecurrence. NFC

---
 .../IndVarSimplify/iv-widen-elim-ext.ll       | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll b/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
index 0e21bf8ba347a8..b6f1f5f651d277 100644
--- a/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
+++ b/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
@@ -493,3 +493,57 @@ for.body:                                         ; preds = %for.body.lr.ph, %fo
   %cmp = icmp ult i32 %add, %length
   br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
 }
+
+; Test that we can handle shl and disjoint or in getExtendedOperandRecurrence.
+define void @foo7(i32 %n, ptr %a, i32 %x) {
+; CHECK-LABEL: @foo7(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP6:%.*]] = icmp sgt i32 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP6]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_COND_CLEANUP:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[ADD1:%.*]] = add nsw i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[ADD1]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[N]] to i64
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[FOR_BODY_LR_PH]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[INDVARS_IV]] to i32
+; CHECK-NEXT:    [[MUL:%.*]] = shl nsw i32 [[TMP2]], 1
+; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[MUL]], 1
+; CHECK-NEXT:    [[IDXPROM:%.*]] = sext i32 [[ADD]] to i64
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[IDXPROM]]
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[INDVARS_IV]] to i32
+; CHECK-NEXT:    store i32 [[TMP3]], ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nsw i64 [[INDVARS_IV]], [[TMP0]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV_NEXT]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]]
+;
+entry:
+  %cmp6 = icmp sgt i32 %n, 0
+  br i1 %cmp6, label %for.body.lr.ph, label %for.cond.cleanup
+
+for.body.lr.ph:                                   ; preds = %entry
+  %add1 = add nsw i32 %x, 2
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.body
+  %i.07 = phi i32 [ 0, %for.body.lr.ph ], [ %add2, %for.body ]
+  %mul = shl nsw i32 %i.07, 1
+  %add = or disjoint i32 %mul, 1
+  %idxprom = sext i32 %add to i64
+  %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom
+  store i32 %i.07, ptr %arrayidx, align 4
+  %add2 = add nsw i32 %add1, %i.07
+  %cmp = icmp slt i32 %add2, %n
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}

>From ce3d292b3339af4dc03c3a9e8499a31974acb300 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 6 Mar 2024 22:33:01 -0800
Subject: [PATCH 2/3] [IndVars] Support shl by constant and or disjoint in
 getExtendedOperandRecurrence.

We can treat a shift by constant as a multiply by a power of 2
and we can treat an or disjoint as a 'add nsw nuw'.

I've added a helper struct similar to a struct used in ScalarEvolution.cpp
to represent the opcode, operands, and NSW/NUW flags for normal add/sub/mul
and shl/or that are being treated as mul/add.

I don't think we need to teach cloneIVUser about this. It will continue
to clone them using cloneBitwiseIVUser. After the cloning we will ask
for the SCEV expression for the cloned IV user and verify that it matches
the AddRec returned by getExtendedOperandRecurrence. Since SCEV also
knows how to convert shl to mul and or disjoint to add nsw nuw, this should
usually match. If it doesn't match, the cloned IV user will be deleted.
---
 llvm/lib/Transforms/Utils/SimplifyIndVar.cpp  | 100 +++++++++++++++---
 .../IndVarSimplify/iv-widen-elim-ext.ll       |  12 +--
 2 files changed, 88 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index b8fa985fa3462e..7896f019132a04 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -1381,6 +1381,76 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
   };
 }
 
+namespace {
+
+// Represents an interesting integer binary operation for
+// getExtendedOperandRecurrence. This may be a shl that is being treated as a
+// multiply or an 'or disjoint' that is being treated as 'add nsw nuw'.
+struct BinaryOp {
+  unsigned Opcode;
+  std::array<Value *, 2> Operands;
+  bool IsNSW = false;
+  bool IsNUW = false;
+
+  explicit BinaryOp(Instruction *Op)
+      : Opcode(Op->getOpcode()),
+        Operands({Op->getOperand(0), Op->getOperand(1)}) {
+    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
+      IsNSW = OBO->hasNoSignedWrap();
+      IsNUW = OBO->hasNoUnsignedWrap();
+    }
+  }
+
+  explicit BinaryOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
+                    bool IsNSW = false, bool IsNUW = false)
+      : Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {}
+};
+
+} // end anonymous namespace
+
+static std::optional<BinaryOp> matchBinaryOp(Instruction *Op) {
+  switch (Op->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+    return BinaryOp(Op);
+  case Instruction::Or: {
+    // Convert or disjoint into add nuw nsw.
+    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
+      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
+                      /*IsNSW=*/true, /*IsNUW=*/true);
+    break;
+  }
+  case Instruction::Shl: {
+    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
+      unsigned BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
+
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (SA->getValue().ult(BitWidth)) {
+        // We can safely preserve the nuw flag in all cases. It's also safe to
+        // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
+        // requires special handling. It can be preserved as long as we're not
+        // left shifting by bitwidth - 1.
+        bool IsNUW = Op->hasNoUnsignedWrap();
+        bool IsNSW =
+            Op->hasNoSignedWrap() && (IsNUW || SA->getValue().ult(BitWidth - 1));
+
+        ConstantInt *X = ConstantInt::get(
+            Op->getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return BinaryOp(Instruction::Mul, Op->getOperand(0), X, IsNSW, IsNUW);
+      }
+    }
+
+    break;
+  }
+  }
+
+  return std::nullopt;
+}
+
 /// No-wrap operations can transfer sign extension of their result to their
 /// operands. Generate the SCEV value for the widened operation without
 /// actually modifying the IR yet. If the expression after extending the
@@ -1388,24 +1458,21 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
 /// extension used.
 WidenIV::WidenedRecTy
 WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
-  // Handle the common case of add<nsw/nuw>
-  const unsigned OpCode = DU.NarrowUse->getOpcode();
-  // Only Add/Sub/Mul instructions supported yet.
-  if (OpCode != Instruction::Add && OpCode != Instruction::Sub &&
-      OpCode != Instruction::Mul)
+  auto Op = matchBinaryOp(DU.NarrowUse);
+  if (!Op)
     return {nullptr, ExtendKind::Unknown};
 
+  assert((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
+          Op->Opcode == Instruction::Mul) && "Unexpected opcode");
+
   // One operand (NarrowDef) has already been extended to WideDef. Now determine
   // if extending the other will lead to a recurrence.
-  const unsigned ExtendOperIdx =
-      DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0;
-  assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU");
+  const unsigned ExtendOperIdx = Op->Operands[0] == DU.NarrowDef ? 1 : 0;
+  assert(Op->Operands[1-ExtendOperIdx] == DU.NarrowDef && "bad DU");
 
-  const OverflowingBinaryOperator *OBO =
-    cast<OverflowingBinaryOperator>(DU.NarrowUse);
   ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
-  if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) &&
-      !(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) {
+  if (!(ExtKind == ExtendKind::Sign && Op->IsNSW) &&
+      !(ExtKind == ExtendKind::Zero && Op->IsNUW)) {
     ExtKind = ExtendKind::Unknown;
 
     // For a non-negative NarrowDef, we can choose either type of
@@ -1413,16 +1480,15 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
     // (see above), and we only hit this code if we need to check
     // the opposite case.
     if (DU.NeverNegative) {
-      if (OBO->hasNoSignedWrap()) {
+      if (Op->IsNSW) {
         ExtKind = ExtendKind::Sign;
-      } else if (OBO->hasNoUnsignedWrap()) {
+      } else if (Op->IsNUW) {
         ExtKind = ExtendKind::Zero;
       }
     }
   }
 
-  const SCEV *ExtendOperExpr =
-      SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx));
+  const SCEV *ExtendOperExpr = SE->getSCEV(Op->Operands[ExtendOperIdx]);
   if (ExtKind == ExtendKind::Sign)
     ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType);
   else if (ExtKind == ExtendKind::Zero)
@@ -1443,7 +1509,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
   if (ExtendOperIdx == 0)
     std::swap(lhs, rhs);
   const SCEVAddRecExpr *AddRec =
-      dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode));
+      dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, Op->Opcode));
 
   if (!AddRec || AddRec->getLoop() != L)
     return {nullptr, ExtendKind::Unknown};
diff --git a/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll b/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
index b6f1f5f651d277..59a0241bfe9fde 100644
--- a/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
+++ b/llvm/test/Transforms/IndVarSimplify/iv-widen-elim-ext.ll
@@ -511,13 +511,11 @@ define void @foo7(i32 %n, ptr %a, i32 %x) {
 ; CHECK-NEXT:    ret void
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 0, [[FOR_BODY_LR_PH]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[INDVARS_IV]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = shl nsw i32 [[TMP2]], 1
-; CHECK-NEXT:    [[ADD:%.*]] = or disjoint i32 [[MUL]], 1
-; CHECK-NEXT:    [[IDXPROM:%.*]] = sext i32 [[ADD]] to i64
-; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[IDXPROM]]
-; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[INDVARS_IV]] to i32
-; CHECK-NEXT:    store i32 [[TMP3]], ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = or disjoint i64 [[TMP2]], 1
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = trunc i64 [[INDVARS_IV]] to i32
+; CHECK-NEXT:    store i32 [[TMP4]], ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nsw i64 [[INDVARS_IV]], [[TMP0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV_NEXT]], [[TMP1]]
 ; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]]

>From 5f30e22d399f84e3c929d5b4f129eb98898912c2 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 6 Mar 2024 22:51:44 -0800
Subject: [PATCH 3/3] fixup! clang-format

---
 llvm/lib/Transforms/Utils/SimplifyIndVar.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index 7896f019132a04..f6d72265305d7d 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -1435,11 +1435,12 @@ static std::optional<BinaryOp> matchBinaryOp(Instruction *Op) {
         // requires special handling. It can be preserved as long as we're not
         // left shifting by bitwidth - 1.
         bool IsNUW = Op->hasNoUnsignedWrap();
-        bool IsNSW =
-            Op->hasNoSignedWrap() && (IsNUW || SA->getValue().ult(BitWidth - 1));
+        bool IsNSW = Op->hasNoSignedWrap() &&
+                     (IsNUW || SA->getValue().ult(BitWidth - 1));
 
-        ConstantInt *X = ConstantInt::get(
-            Op->getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        ConstantInt *X =
+            ConstantInt::get(Op->getContext(),
+                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
         return BinaryOp(Instruction::Mul, Op->getOperand(0), X, IsNSW, IsNUW);
       }
     }
@@ -1463,12 +1464,13 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
     return {nullptr, ExtendKind::Unknown};
 
   assert((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
-          Op->Opcode == Instruction::Mul) && "Unexpected opcode");
+          Op->Opcode == Instruction::Mul) &&
+         "Unexpected opcode");
 
   // One operand (NarrowDef) has already been extended to WideDef. Now determine
   // if extending the other will lead to a recurrence.
   const unsigned ExtendOperIdx = Op->Operands[0] == DU.NarrowDef ? 1 : 0;
-  assert(Op->Operands[1-ExtendOperIdx] == DU.NarrowDef && "bad DU");
+  assert(Op->Operands[1 - ExtendOperIdx] == DU.NarrowDef && "bad DU");
 
   ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
   if (!(ExtKind == ExtendKind::Sign && Op->IsNSW) &&



More information about the llvm-commits mailing list