[Mlir-commits] [mlir] [MLIR] Add `IntegerDivisibilityAnalysis` and `InferIntDivisibilityOpInterface` (PR #197728)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 14 09:01:10 PDT 2026


llvmorg-github-actions[bot] wrote:



@llvm/pr-subscribers-mlir-core

Author: Alan Li (lialan)

<details>
<summary>Changes</summary>

This patch ports IREE's integer divisibility analysis (https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp) to upstream MLIR.

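It introduces a sparse forward dataflow analysis that tracks integer divisibility for SSA values (a divisor each value is guaranteed to be a multiple of, kept separately for the unsigned and signed interpretations), plus an op interface, `InferIntDivisibilityOpInterface`, through which ops participate.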
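A minimal standalone sketch of the lattice join in the diff below (`ConstantIntDivisibility::getUnion` and `IntegerDivisibility::join`), written in plain C++ without MLIR so it can be run in isolation. The names `Div`, `DivLattice`, `getValue`, `join` and `minDivisibility` are hypothetical stand-ins, not the patch's API.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>

// Simplified stand-in for ConstantIntDivisibility: one divisor for the
// unsigned view of a value and one for the signed view.
struct Div {
  uint64_t udiv;
  uint64_t sdiv;
  bool operator==(const Div &o) const {
    return udiv == o.udiv && sdiv == o.sdiv;
  }
};

// Simplified stand-in for IntegerDivisibility: std::nullopt means
// "uninitialized" (no information has reached this value yet).
using DivLattice = std::optional<Div>;

// Reading a value requires it to be initialized, as in the patch.
const Div &getValue(const DivLattice &l) {
  assert(l.has_value() && "lattice value is uninitialized");
  return *l;
}

// Join: uninitialized is the identity; otherwise keep the largest divisor
// both inputs are guaranteed to share (the GCD), separately per signedness.
DivLattice join(const DivLattice &lhs, const DivLattice &rhs) {
  if (!lhs)
    return rhs;
  if (!rhs)
    return lhs;
  return Div{std::gcd(lhs->udiv, rhs->udiv), std::gcd(lhs->sdiv, rhs->sdiv)};
}

// Divisor 1 is the pessimistic "nothing known" state used at entry points.
DivLattice minDivisibility() { return Div{1, 1}; }
```

Joining two known values keeps only the divisors they share (8 and 12 join to 4), and joining anything with the entry state `{1, 1}` yields `{1, 1}`.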

It adds:
* `IntegerDivisibilityAnalysis`, which computes an `IntegerDivisibility` lattice value (unsigned and signed divisors, `udiv`/`sdiv`) for each integer SSA value
* the `InferIntDivisibilityOpInterface` op interface
* external-model implementations of the interface for `arith` and `affine` ops
* a `test-int-divisibility` test pass and lit tests
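The `arith` and `affine` external models live in the truncated part of the diff, so the transfer functions below are only a hypothetical illustration of what an op might report through `inferResultDivisibility`, not the patch's actual rules. They assume mathematical (non-wrapping) integers and plain `uint64_t` divisors; `divOfConstant`, `divOfAdd`, `divOfMul`, `SetDivFn` and `inferMulDivisibility` are made-up names.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

// Divisor of a constant c: |c|, computed in unsigned arithmetic so that
// INT64_MIN does not overflow. A zero constant yields 0, which is the
// identity for std::gcd (gcd(0, d) == d).
uint64_t divOfConstant(int64_t c) {
  return c < 0 ? 0 - static_cast<uint64_t>(c) : static_cast<uint64_t>(c);
}

// x + y: x is a multiple of dx and y of dy, so both are multiples of
// gcd(dx, dy), and so is their sum.
uint64_t divOfAdd(uint64_t dx, uint64_t dy) { return std::gcd(dx, dy); }

// x * y: with x = dx * a and y = dy * b, the product is a multiple of
// dx * dy. If that does not fit in 64 bits, fall back to the larger of the
// two divisors, which still divides x * y.
uint64_t divOfMul(uint64_t dx, uint64_t dy) {
  if (dx != 0 && dy > std::numeric_limits<uint64_t>::max() / dx)
    return dx > dy ? dx : dy;
  return dx * dy;
}

// Callback shape loosely modelled on SetIntDivisibilityFn: report a divisor
// for the result with the given index.
using SetDivFn = std::function<void(unsigned resultIdx, uint64_t div)>;

// Hypothetical transfer function for a binary multiply op.
void inferMulDivisibility(const std::vector<uint64_t> &argDivs,
                          const SetDivFn &setResultDiv) {
  assert(argDivs.size() == 2 && "binary op expects two operand divisors");
  setResultDiv(/*resultIdx=*/0, divOfMul(argDivs[0], argDivs[1]));
}
```

For instance, `divOfMul(1, divOfConstant(8))` returns 8: a value with no known divisor multiplied by the constant 8 is a multiple of 8.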

Example:
The usual way to load element `%i` from an `i4` buffer emulated with an `i8` buffer is:
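```mlir
  %c2       = arith.constant 2 : index
  %c4       = arith.constant 4 : index
  %c0xF     = arith.constant 15 : i8
  %byte_idx = arith.divui %i, %c2 : index
  %nib_idx  = arith.remui %i, %c2 : index
  %bit_off  = arith.muli %nib_idx, %c4 : index
  %byte     = memref.load %buf[%byte_idx] : memref<?xi8>
  %off_i8   = arith.index_cast %bit_off : index to i8
  %shifted  = arith.shrui %byte, %off_i8 : i8
  %nibble   = arith.andi %shifted, %c0xF : i8
```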
If `%i` is produced by `affine.apply affine_map<(d0) -> (d0 * 8)>` applied to a loop induction variable, the analysis derives a divisor of 8 for `%i`. Since 2 divides 8, `%i mod 2` is known to be 0, so the bit offset is 0 and the shift can be dropped, which yields more concise code. Other patterns can benefit from this analysis as well.
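A small C++ model of the simplification this enables, assuming two 4-bit elements per byte with element 0 in the low nibble (the order implied by shifting right by the bit offset in the snippet above). `bitOffsetIsZero`, `loadNibble` and `loadNibbleAtEvenIndex` are hypothetical helpers, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>

// If the analysis proves index i is a multiple of divOfI and elemsPerByte
// divides divOfI, then i % elemsPerByte == 0, so the bit offset
// (i % elemsPerByte) * bitsPerElem is the constant 0 and the shift can be
// dropped.
bool bitOffsetIsZero(uint64_t divOfI, uint64_t elemsPerByte) {
  assert(elemsPerByte != 0 && "elements per byte must be non-zero");
  return divOfI % elemsPerByte == 0;
}

// General form of the snippet above: two 4-bit elements per byte,
// element 0 in the low nibble.
uint8_t loadNibble(const uint8_t *buf, uint64_t i) {
  uint8_t byte = buf[i / 2];
  unsigned bitOff = static_cast<unsigned>(i % 2) * 4;
  return static_cast<uint8_t>((byte >> bitOff) & 0xF);
}

// Simplified form, valid only when i is known to be even: no shift needed.
uint8_t loadNibbleAtEvenIndex(const uint8_t *buf, uint64_t i) {
  return static_cast<uint8_t>(buf[i / 2] & 0xF);
}
```

With a proven divisor of 8, `bitOffsetIsZero(8, 2)` holds, so `loadNibbleAtEvenIndex` may replace `loadNibble`; with divisor 1 it does not.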


---

Patch is 57.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/197728.diff


19 Files Affected:

- (added) mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h (+64) 
- (added) mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h (+21) 
- (added) mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h (+21) 
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1) 
- (added) mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h (+120) 
- (added) mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td (+41) 
- (modified) mlir/lib/Analysis/CMakeLists.txt (+3) 
- (added) mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp (+135) 
- (modified) mlir/lib/Dialect/Affine/IR/CMakeLists.txt (+2) 
- (added) mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp (+368) 
- (modified) mlir/lib/Dialect/Arith/IR/CMakeLists.txt (+3) 
- (added) mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp (+162) 
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+2) 
- (added) mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp (+11) 
- (modified) mlir/lib/RegisterAllDialects.cpp (+4) 
- (added) mlir/test/Analysis/DataFlow/integer-divisibility.mlir (+152) 
- (modified) mlir/test/lib/Analysis/CMakeLists.txt (+2) 
- (added) mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp (+97) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
new file mode 100644
index 0000000000000..3a877647490a3
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
@@ -0,0 +1,64 @@
+//===- IntegerDivisibilityAnalysis.h - Integer divisibility -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer divisibility
+// inference. Operations participate in the analysis by implementing
+// `InferIntDivisibilityOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
+
+namespace mlir::dataflow {
+
+/// This lattice element represents the integer divisibility of an SSA value.
+class IntegerDivisibilityLattice : public Lattice<IntegerDivisibility> {
+public:
+  using Lattice::Lattice;
+};
+
+/// Integer divisibility analysis determines, for each integer-typed SSA
+/// value, a divisor that the value is guaranteed to be a multiple of. It
+/// uses operations that implement `InferIntDivisibilityOpInterface` and
+/// also sets the divisibility of induction variables of loops with known
+/// lower bounds and steps.
+///
+/// This analysis depends on DeadCodeAnalysis, and will be a silent no-op
+/// if DeadCodeAnalysis is not loaded in the same solver context.
+class IntegerDivisibilityAnalysis
+    : public SparseForwardDataFlowAnalysis<IntegerDivisibilityLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  /// At an entry point, set the lattice to the most pessimistic state,
+  /// indicating that no further reasoning can be done.
+  void setToEntryState(IntegerDivisibilityLattice *lattice) override;
+
+  /// Visit an operation, invoking the transfer function.
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const IntegerDivisibilityLattice *> operands,
+                 ArrayRef<IntegerDivisibilityLattice *> results) override;
+
+  /// Visit block arguments or operation results of an operation with region
+  /// control-flow for which values are not defined by region control-flow. This
+  /// function tries to infer the divisibility of loop induction variables based
+  /// on known loop bounds and steps.
+  void visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &successor,
+      ValueRange successorInputs,
+      ArrayRef<IntegerDivisibilityLattice *> argLattices) override;
+};
+
+} // namespace mlir::dataflow
+
+#endif // MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
diff --git a/mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h
new file mode 100644
index 0000000000000..560bef63b56a8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- InferIntDivisibilityOpInterfaceImpl.h --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AFFINE_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+#define MLIR_DIALECT_AFFINE_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace affine {
+void registerInferIntDivisibilityOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace affine
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h
new file mode 100644
index 0000000000000..0909790b8b7d0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- InferIntDivisibilityOpInterfaceImpl.h --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerInferIntDivisibilityOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 3cbc9df05f3d7..6461c68423c73 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
 add_mlir_interface(IndexingMapOpInterface)
+add_mlir_interface(InferIntDivisibilityOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferStridedMetadataInterface)
 add_mlir_interface(InferTypeOpInterface)
diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
new file mode 100644
index 0000000000000..374acee05cb10
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
@@ -0,0 +1,120 @@
+//===- InferIntDivisibilityOpInterface.h - Integer Divisibility -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer divisibility inference
+// interface defined in `InferIntDivisibilityOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
+#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include <numeric>
+#include <optional>
+
+namespace mlir {
+
+/// Statically known divisibility information for an integer SSA value.
+/// Tracks separate divisors for the unsigned and signed interpretations of
+/// the value so that subsequent analyses can use whichever is more precise.
+class ConstantIntDivisibility {
+public:
+  ConstantIntDivisibility() = default;
+  ConstantIntDivisibility(uint64_t udiv, uint64_t sdiv)
+      : udivVal(udiv), sdivVal(sdiv) {}
+
+  bool operator==(const ConstantIntDivisibility &other) const {
+    return udivVal == other.udivVal && sdivVal == other.sdivVal;
+  }
+
+  uint64_t udiv() const { return this->udivVal; }
+  uint64_t sdiv() const { return this->sdivVal; }
+
+  // Returns the union (computed separately for signed and unsigned bounds)
+  // for this divisibility and `other`.
+  ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const {
+    return ConstantIntDivisibility(
+        /*udiv=*/std::gcd(udiv(), other.udiv()),
+        /*sdiv=*/std::gcd(sdiv(), other.sdiv()));
+  }
+
+private:
+  uint64_t udivVal;
+  uint64_t sdivVal;
+
+  friend raw_ostream &operator<<(raw_ostream &os,
+                                 const ConstantIntDivisibility &div);
+};
+
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const ConstantIntDivisibility &div) {
+  os << "ConstantIntDivisibility(udiv = " << div.udivVal
+     << ", sdiv = " << div.sdivVal << ")";
+  return os;
+}
+
+/// This lattice value represents the integer divisibility of an SSA value.
+class IntegerDivisibility {
+public:
+  IntegerDivisibility(ConstantIntDivisibility value)
+      : value(std::move(value)) {}
+  explicit IntegerDivisibility(
+      std::optional<ConstantIntDivisibility> value = std::nullopt)
+      : value(std::move(value)) {}
+  // Gets the minimum divisibility of 1 that is used to indicate that the value
+  // cannot be analyzed further.
+  static IntegerDivisibility getMinDivisibility() {
+    return IntegerDivisibility(ConstantIntDivisibility(1, 1));
+  }
+
+  bool isUninitialized() const { return !value.has_value(); }
+  const ConstantIntDivisibility &getValue() const {
+    assert(!isUninitialized());
+    return *value;
+  }
+
+  bool operator==(const IntegerDivisibility &rhs) const {
+    return value == rhs.value;
+  }
+
+  static IntegerDivisibility join(const IntegerDivisibility &lhs,
+                                  const IntegerDivisibility &rhs) {
+    if (lhs.isUninitialized()) {
+      return rhs;
+    }
+    if (rhs.isUninitialized()) {
+      return lhs;
+    }
+    return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue()));
+  }
+
+  void print(raw_ostream &os) const { os << value; }
+
+private:
+  std::optional<ConstantIntDivisibility> value;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const IntegerDivisibility &div) {
+  div.print(os);
+  return os;
+}
+
+/// The type of the `setResultDivs` callback provided to ops implementing
+/// InferIntDivisibilityOpInterface. It should be called once for each integer
+/// result value and be passed the ConstantIntDivisibility corresponding to
+/// that value.
+using SetIntDivisibilityFn =
+    llvm::function_ref<void(Value, const ConstantIntDivisibility &)>;
+
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
new file mode 100644
index 0000000000000..c665475e0fd7f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
@@ -0,0 +1,41 @@
+//===- InferIntDivisibilityOpInterface.td - Integer Divisibility -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for divisibility analysis on scalar integers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
+#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntDivisibilityOpInterface :
+    OpInterface<"InferIntDivisibilityOpInterface"> {
+  let description = [{
+    Allows operations to participate in integer divisibility analysis.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Infer the divisibility of the results of this op given the
+        divisibility of its arguments. For each result value, the method
+        should call `setResultDivs` with that `Value` as an argument.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"inferResultDivisibility",
+      /*args=*/(ins
+        "::llvm::ArrayRef<::mlir::IntegerDivisibility>":$argDivs,
+        "::mlir::SetIntDivisibilityFn":$setResultDivs)
+    >
+  ];
+}
+
+#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index db10ebcf2c311..596ffaff428b5 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES
   DataFlow/ConstantPropagationAnalysis.cpp
   DataFlow/DeadCodeAnalysis.cpp
   DataFlow/DenseAnalysis.cpp
+  DataFlow/IntegerDivisibilityAnalysis.cpp
   DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/LivenessAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
@@ -37,6 +38,7 @@ add_mlir_library(MLIRAnalysis
   DataFlow/ConstantPropagationAnalysis.cpp
   DataFlow/DeadCodeAnalysis.cpp
   DataFlow/DenseAnalysis.cpp
+  DataFlow/IntegerDivisibilityAnalysis.cpp
   DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/LivenessAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
@@ -53,6 +55,7 @@ add_mlir_library(MLIRAnalysis
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRFunctionInterfaces
+  MLIRInferIntDivisibilityOpInterface
   MLIRInferIntRangeInterface
   MLIRInferStridedMetadataInterface
   MLIRInferTypeOpInterface
diff --git a/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
new file mode 100644
index 0000000000000..ba10a8b5a0060
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
@@ -0,0 +1,135 @@
+//===- IntegerDivisibilityAnalysis.cpp - Integer divisibility ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer divisibility
+// inference. Operations participate in the analysis by implementing
+// `InferIntDivisibilityOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h"
+
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-divisibility-analysis"
+
+using llvm::dbgs;
+
+namespace mlir::dataflow {
+
+void IntegerDivisibilityAnalysis::setToEntryState(
+    IntegerDivisibilityLattice *lattice) {
+  propagateIfChanged(lattice,
+                     lattice->join(IntegerDivisibility::getMinDivisibility()));
+}
+
+LogicalResult IntegerDivisibilityAnalysis::visitOperation(
+    Operation *op, ArrayRef<const IntegerDivisibilityLattice *> operands,
+    ArrayRef<IntegerDivisibilityLattice *> results) {
+  auto inferrable = dyn_cast<InferIntDivisibilityOpInterface>(op);
+  if (!inferrable) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n");
+  auto argDivs = llvm::map_to_vector(
+      operands, [](const IntegerDivisibilityLattice *lattice) {
+        return lattice->getValue();
+      });
+  auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) {
+    auto result = dyn_cast<OpResult>(v);
+    if (!result) {
+      return;
+    }
+    assert(llvm::is_contained(op->getResults(), result));
+
+    LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n");
+    IntegerDivisibilityLattice *lattice = results[result.getResultNumber()];
+    IntegerDivisibility oldDiv = lattice->getValue();
+
+    ChangeResult changed = lattice->join(newDiv);
+
+    // Catch loop results with loop-variant divisibility and conservatively
+    // set them to divisibility 1 (no information) so we don't ratchet
+    // indefinitely (the dataflow analysis in MLIR doesn't attempt to work
+    // out trip counts and often can't).
+    bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+      return op->hasTrait<OpTrait::IsTerminator>();
+    });
+    if (isYieldedResult && !oldDiv.isUninitialized() &&
+        !(lattice->getValue() == oldDiv)) {
+      LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+      changed |= lattice->join(IntegerDivisibility::getMinDivisibility());
+    }
+    propagateIfChanged(lattice, changed);
+  };
+
+  inferrable.inferResultDivisibility(argDivs, joinCallback);
+  return success();
+}
+
+void IntegerDivisibilityAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &successor, ValueRange successorInputs,
+    ArrayRef<IntegerDivisibilityLattice *> argLattices) {
+  // Get the constant divisibility, or query the lattice for Values.
+  auto getDivFromOfr = [&](std::optional<OpFoldResult> ofr, Block *block,
+                           bool isUnsigned) -> uint64_t {
+    if (ofr.has_value()) {
+      if (auto constBound = getConstantIntValue(*ofr)) {
+        return constBound.value();
+      }
+      auto value = cast<Value>(ofr.value());
+      const IntegerDivisibilityLattice *lattice =
+          getLatticeElementFor(getProgramPointBefore(block), value);
+      if (lattice != nullptr && !lattice->getValue().isUninitialized()) {
+        return isUnsigned ? lattice->getValue().getValue().udiv()
+                          : lattice->getValue().getValue().sdiv();
+      }
+    }
+    return isUnsigned
+               ? IntegerDivisibility::getMinDivisibility().getValue().udiv()
+               : IntegerDivisibility::getMinDivisibility().getValue().sdiv();
+  };
+
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    std::optional<SmallVector<Value>> ivs = loop.getLoopInductionVars();
+    std::optional<SmallVector<OpFoldResult>> lbs = loop.getLoopLowerBounds();
+    std::optional<SmallVector<OpFoldResult>> steps = loop.getLoopSteps();
+    if (!ivs || !lbs || !steps) {
+      return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+          op, successor, successorInputs, argLattices);
+    }
+    for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) {
+      IntegerDivisibilityLattice *ivEntry = getLatticeElement(iv);
+      Block *block = iv.getParentBlock();
+      uint64_t stepUDiv = getDivFromOfr(step, block, /*unsigned=*/true);
+      uint64_t stepSDiv = getDivFromOfr(step, block, /*unsigned=*/false);
+      uint64_t lbUDiv = getDivFromOfr(lb, block, /*unsigned=*/true);
+      uint64_t lbSDiv = getDivFromOfr(lb, block, /*unsigned=*/false);
+      ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv);
+      ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv);
+
+      // Loop induction variables are computed as `lb + i * step`. The
+      // divisibility for `i * step` is just the divisibility of `step`, so
+      // the total divisibility is obtained by unioning the step divisibility
+      // with the lower bound divisibility, which takes the GCD of the two.
+      ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv);
+      propagateIfChanged(ivEntry, ivEntry->join(ivDiv));
+    }
+    return;
+  }
+
+  return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+      op, successor, successorInputs, argLattices);
+}
+
+} // namespace mlir::dataflow
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 566bc060e5d38..1caf2fa396797 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect
   AffineMemoryOpInterfaces.cpp
   AffineOps.cpp
   AffineValueMap.cpp
+  InferIntDivisibilityOpInterfaceImpl.cpp
   InferIntRangeInterfaceImpls.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
@@ -16,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineDialect
   MLIRArithDialect
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntDivisibilityOpInterface
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
diff --git a/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..30850cf0a4df0
--- /dev/null
+++ b/mlir/lib/D...
[truncated]

``````````

</details>
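The two analysis-side rules in `IntegerDivisibilityAnalysis.cpp` can be sketched on scalars: the induction-variable rule from `visitNonControlFlowArguments` (the divisor of `lb + k * step` is the GCD of the lower-bound and step divisors) and the loop-variant guard from `visitOperation` (a yielded result whose divisor changes after initialization is reset to 1). This is a simplified model, not the patch's code: it handles constant bounds only, whereas the patch also reads non-constant bounds and steps from the lattice, and it tracks a single divisor instead of separate `udiv`/`sdiv`. `magnitude`, `ivDivisibility` and `joinWithWidening` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>

// |c| computed in unsigned arithmetic so INT64_MIN does not overflow.
uint64_t magnitude(int64_t c) {
  return c < 0 ? 0 - static_cast<uint64_t>(c) : static_cast<uint64_t>(c);
}

// An induction variable takes the values lb + k * step for k = 0, 1, 2, ...
// g = gcd(|lb|, |step|) divides lb and every k * step, hence every value.
// Since gcd(0, s) == s, a zero lower bound leaves the step's divisor intact.
uint64_t ivDivisibility(int64_t lb, int64_t step) {
  assert(step != 0 && "loop step must be non-zero");
  return std::gcd(magnitude(lb), magnitude(step));
}

// Scalar model of the loop-variant guard: join the new divisor into the old
// one (GCD), but if the value feeds a terminator and an already-initialized
// divisor changes, drop straight to 1 rather than shrinking it gradually
// over repeated visits.
uint64_t joinWithWidening(std::optional<uint64_t> oldDiv, uint64_t newDiv,
                          bool feedsTerminator) {
  uint64_t joined = oldDiv ? std::gcd(*oldDiv, newDiv) : newDiv;
  if (feedsTerminator && oldDiv && joined != *oldDiv)
    return 1;
  return joined;
}
```

The sketch takes the magnitude of constant bounds so that a negative lower bound contributes its absolute value: with `lb = -4, step = 8` the values -4, 4, 12, ... are all multiples of 4, while with `lb = -1, step = 5` the values -1, 4, 9, ... share no divisor larger than 1.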


https://github.com/llvm/llvm-project/pull/197728


More information about the Mlir-commits mailing list