[Mlir-commits] [mlir] [MLIR] Add `IntegerDivisibilityAnalysis` and `InferIntDivisibilityOpInterface` (PR #197728)
llvmlistbot at llvm.org
Thu May 14 09:01:10 PDT 2026
llvmorg-github-actions[bot] wrote:
@llvm/pr-subscribers-mlir-core
Author: Alan Li (lialan)
<details>
<summary>Changes</summary>
This patch ports IREE's integer divisibility analysis (https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp) to upstream MLIR.
It introduces a dataflow analysis that tracks integer divisibility for SSA values (a guaranteed divisor, kept separately for the unsigned and signed interpretations), plus an op interface, `InferIntDivisibilityOpInterface`, through which ops participate.
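As a rough illustration of the lattice just described, here is a minimal, MLIR-free C++ sketch. The names `Div`, `DivLattice`, `minDivisibility`, and `join` are hypothetical; the sketch only mirrors the shape of `ConstantIntDivisibility`/`IntegerDivisibility` in the diff below, where an absent value means "uninitialized" and joining two facts keeps the GCD of their divisors.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>

// Hypothetical model of one lattice value: the SSA value is known to be a
// multiple of `udiv` (unsigned view) and of `sdiv` (signed view).
struct Div {
  uint64_t udiv;
  uint64_t sdiv;
  bool operator==(const Div &other) const {
    return udiv == other.udiv && sdiv == other.sdiv;
  }
};

// std::nullopt stands for "uninitialized" (no information has arrived yet).
using DivLattice = std::optional<Div>;

// The "no information" state: every integer is a multiple of 1.
inline Div minDivisibility() { return Div{1, 1}; }

// Join: an uninitialized side is the identity; otherwise keep the GCD, the
// largest divisor that both incoming facts still guarantee.
inline DivLattice join(const DivLattice &lhs, const DivLattice &rhs) {
  if (!lhs)
    return rhs;
  if (!rhs)
    return lhs;
  return Div{std::gcd(lhs->udiv, rhs->udiv), std::gcd(lhs->sdiv, rhs->sdiv)};
}
```

For example, joining "multiple of 8" with "multiple of 12" yields "multiple of 4", and joining with "multiple of 3" collapses to the no-information state `{1, 1}`.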
It adds:
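* `IntegerDivisibilityAnalysis`, which computes an `IntegerDivisibility` lattice value for each integer SSA value: the divisors the value is guaranteed to be a multiple of, tracked separately as `udiv` (unsigned) and `sdiv` (signed)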
* `InferIntDivisibilityOpInterface` interface
* External-model implementations for `arith` and `affine` ops
* `test-int-divisibility` test pass + lit tests
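To make the interface contract concrete, here is a hypothetical, MLIR-free C++ sketch of the callback pattern behind `inferResultDivisibility`: an op receives its operand divisors and reports each result's divisor through a `setResultDivs`-style callback. The function names and the two rules (for `x + y` and `x * y`) are illustrative assumptions over mathematical integers that track a single divisor; the patch's actual `arith`/`affine` rules are in the truncated part of the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>

// Hypothetical stand-in for SetIntDivisibilityFn; results are identified by
// index instead of by mlir::Value.
using SetDivFn = std::function<void(unsigned resultIdx, uint64_t divisor)>;

// x + y: if dx divides x and dy divides y, then gcd(dx, dy) divides x + y.
void inferAddDivisibility(const std::vector<uint64_t> &argDivs,
                          const SetDivFn &setResultDivs) {
  assert(argDivs.size() == 2 && "expected two operands");
  setResultDivs(0, std::gcd(argDivs[0], argDivs[1]));
}

// x * y: dx * dy divides x * y. If that product would overflow uint64_t,
// fall back to max(dx, dy), which still divides x * y.
void inferMulDivisibility(const std::vector<uint64_t> &argDivs,
                          const SetDivFn &setResultDivs) {
  assert(argDivs.size() == 2 && "expected two operands");
  uint64_t dx = argDivs[0], dy = argDivs[1];
  bool overflows = dy != 0 && dx > std::numeric_limits<uint64_t>::max() / dy;
  setResultDivs(0, overflows ? std::max(dx, dy) : dx * dy);
}
```

Reporting through a callback rather than a return value lets a multi-result op set each result independently, which matches how the analysis in the diff joins each reported value into the corresponding result lattice.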
Example:
Here is the usual way to load element `i` of an `i4` buffer that is emulated with an `i8` buffer:
```mlir
%byte_idx = arith.divui %i, %c2 : index
%nibble_idx = arith.remui %i, %c2 : index
%bit_off = arith.muli %nibble_idx, %c4 : index
%bit_off_i8 = arith.index_cast %bit_off : index to i8
%byte = memref.load %buf[%byte_idx] : memref<?xi8>
%shifted = arith.shrui %byte, %bit_off_i8 : i8
%nibble = arith.andi %shifted, %c0xF : i8
```
If `%i` comes from `affine.apply affine_map<(d0) -> (d0 * 8)>` applied to a loop induction variable, the analysis derives a divisor of 8 for `%i`. Because 8 is even, `%i mod 2` is always 0, so the bit offset is always 0 and the shift can be dropped, which gives more concise code. Other patterns can benefit from this analysis as well.
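The saving can be sketched in plain C++. The helpers `loadNibble` and `loadNibbleKnownEven` are hypothetical, and the packing order (element `2k` in the low nibble of byte `k`, element `2k+1` in the high nibble) is an assumption, not something the patch specifies. Once divisibility analysis proves `i` is a multiple of an even divisor, `i % 2 == 0`, so the shift amount is always 0 and only the mask remains.

```cpp
#include <cassert>
#include <cstdint>

// Generic path: the nibble index (i % 2) determines a shift of 0 or 4 bits.
// Assumes element 2k is the low nibble of byte k, element 2k+1 the high one.
uint8_t loadNibble(const uint8_t *buf, uint64_t i) {
  uint64_t byteIdx = i / 2;
  uint64_t bitOff = (i % 2) * 4;
  return static_cast<uint8_t>((buf[byteIdx] >> bitOff) & 0xF);
}

// Specialized path a rewrite could emit when `i` is known to be a multiple
// of an even `divisor`: the shift amount is provably 0, so it is dropped.
uint8_t loadNibbleKnownEven(const uint8_t *buf, uint64_t i, uint64_t divisor) {
  assert(divisor != 0 && divisor % 2 == 0 && i % divisor == 0);
  return static_cast<uint8_t>(buf[i / 2] & 0xF);
}
```

Both functions agree on every index that is a multiple of 8, which is exactly the guarantee the analysis provides in the example.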
---
Patch is 57.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/197728.diff
19 Files Affected:
- (added) mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h (+64)
- (added) mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h (+21)
- (added) mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h (+21)
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1)
- (added) mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h (+120)
- (added) mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td (+41)
- (modified) mlir/lib/Analysis/CMakeLists.txt (+3)
- (added) mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp (+135)
- (modified) mlir/lib/Dialect/Affine/IR/CMakeLists.txt (+2)
- (added) mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp (+368)
- (modified) mlir/lib/Dialect/Arith/IR/CMakeLists.txt (+3)
- (added) mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp (+162)
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+2)
- (added) mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp (+11)
- (modified) mlir/lib/RegisterAllDialects.cpp (+4)
- (added) mlir/test/Analysis/DataFlow/integer-divisibility.mlir (+152)
- (modified) mlir/test/lib/Analysis/CMakeLists.txt (+2)
- (added) mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp (+97)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
new file mode 100644
index 0000000000000..3a877647490a3
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
@@ -0,0 +1,64 @@
+//===- IntegerDivisibilityAnalysis.h - Integer divisibility -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer divisibility
+// inference. Operations participate in the analysis by implementing
+// `InferIntDivisibilityOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
+
+namespace mlir::dataflow {
+
+/// This lattice element represents the integer divisibility of an SSA value.
+class IntegerDivisibilityLattice : public Lattice<IntegerDivisibility> {
+public:
+ using Lattice::Lattice;
+};
+
+/// Integer divisibility analysis determines, for each integer-typed SSA
+/// value, a divisor that the value is guaranteed to be a multiple of. It
+/// uses operations that implement `InferIntDivisibilityOpInterface` and
+/// also sets the divisibility of induction variables of loops with known
+/// lower bounds and steps.
+///
+/// This analysis depends on DeadCodeAnalysis, and will be a silent no-op
+/// if DeadCodeAnalysis is not loaded in the same solver context.
+class IntegerDivisibilityAnalysis
+ : public SparseForwardDataFlowAnalysis<IntegerDivisibilityLattice> {
+public:
+ using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+ /// At an entry point, set the lattice to the most pessimistic state,
+ /// indicating that no further reasoning can be done.
+ void setToEntryState(IntegerDivisibilityLattice *lattice) override;
+
+ /// Visit an operation, invoking the transfer function.
+ LogicalResult
+ visitOperation(Operation *op,
+ ArrayRef<const IntegerDivisibilityLattice *> operands,
+ ArrayRef<IntegerDivisibilityLattice *> results) override;
+
+ /// Visit block arguments or operation results of an operation with region
+ /// control-flow for which values are not defined by region control-flow. This
+ /// function tries to infer the divisibility of loop induction variables based
+ /// on known loop bounds and steps.
+ void visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ValueRange successorInputs,
+ ArrayRef<IntegerDivisibilityLattice *> argLattices) override;
+};
+
+} // namespace mlir::dataflow
+
+#endif // MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
diff --git a/mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h
new file mode 100644
index 0000000000000..560bef63b56a8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- InferIntDivisibilityOpInterfaceImpl.h --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AFFINE_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+#define MLIR_DIALECT_AFFINE_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace affine {
+void registerInferIntDivisibilityOpInterfaceExternalModels(
+ DialectRegistry &registry);
+} // namespace affine
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h
new file mode 100644
index 0000000000000..0909790b8b7d0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- InferIntDivisibilityOpInterfaceImpl.h --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerInferIntDivisibilityOpInterfaceExternalModels(
+ DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_IR_INFERINTDIVISIBILITYOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 3cbc9df05f3d7..6461c68423c73 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
+add_mlir_interface(InferIntDivisibilityOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferStridedMetadataInterface)
add_mlir_interface(InferTypeOpInterface)
diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
new file mode 100644
index 0000000000000..374acee05cb10
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
@@ -0,0 +1,120 @@
+//===- InferIntDivisibilityOpInterface.h - Integer Divisibility -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer divisibility inference
+// interface defined in `InferIntDivisibilityOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
+#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include <numeric>
+#include <optional>
+
+namespace mlir {
+
+/// Statically known divisibility information for an integer SSA value.
+/// Tracks separate divisors for the unsigned and signed interpretations of
+/// the value so that subsequent analyses can use whichever is more precise.
+class ConstantIntDivisibility {
+public:
+ ConstantIntDivisibility() = default;
+ ConstantIntDivisibility(uint64_t udiv, uint64_t sdiv)
+ : udivVal(udiv), sdivVal(sdiv) {}
+
+ bool operator==(const ConstantIntDivisibility &other) const {
+ return udivVal == other.udivVal && sdivVal == other.sdivVal;
+ }
+
+ uint64_t udiv() const { return this->udivVal; }
+ uint64_t sdiv() const { return this->sdivVal; }
+
+ // Returns the union of this divisibility and `other`: the GCD of the
+ // divisors, computed separately for the unsigned and signed cases.
+ ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const {
+ return ConstantIntDivisibility(
+ /*udiv=*/std::gcd(udiv(), other.udiv()),
+ /*sdiv=*/std::gcd(sdiv(), other.sdiv()));
+ }
+
+private:
+ uint64_t udivVal;
+ uint64_t sdivVal;
+
+ friend raw_ostream &operator<<(raw_ostream &os,
+ const ConstantIntDivisibility &div);
+};
+
+inline raw_ostream &operator<<(raw_ostream &os,
+ const ConstantIntDivisibility &div) {
+ os << "ConstantIntDivisibility(udiv = " << div.udivVal
+ << ", sdiv = " << div.sdivVal << ")";
+ return os;
+}
+
+/// This lattice value represents the integer divisibility of an SSA value.
+class IntegerDivisibility {
+public:
+ IntegerDivisibility(ConstantIntDivisibility value)
+ : value(std::move(value)) {}
+ explicit IntegerDivisibility(
+ std::optional<ConstantIntDivisibility> value = std::nullopt)
+ : value(std::move(value)) {}
+ // Gets the minimum divisibility of 1 that is used to indicate that the value
+ // cannot be analyzed further.
+ static IntegerDivisibility getMinDivisibility() {
+ return IntegerDivisibility(ConstantIntDivisibility(1, 1));
+ }
+
+ bool isUninitialized() const { return !value.has_value(); }
+ const ConstantIntDivisibility &getValue() const {
+ assert(!isUninitialized());
+ return *value;
+ }
+
+ bool operator==(const IntegerDivisibility &rhs) const {
+ return value == rhs.value;
+ }
+
+ static IntegerDivisibility join(const IntegerDivisibility &lhs,
+ const IntegerDivisibility &rhs) {
+ if (lhs.isUninitialized()) {
+ return rhs;
+ }
+ if (rhs.isUninitialized()) {
+ return lhs;
+ }
+ return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue()));
+ }
+
+ void print(raw_ostream &os) const { os << value; }
+
+private:
+ std::optional<ConstantIntDivisibility> value;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os,
+ const IntegerDivisibility &div) {
+ div.print(os);
+ return os;
+}
+
+/// The type of the `setResultDivs` callback provided to ops implementing
+/// InferIntDivisibilityOpInterface. It should be called once for each integer
+/// result value and be passed the ConstantIntDivisibility corresponding to
+/// that value.
+using SetIntDivisibilityFn =
+ llvm::function_ref<void(Value, const ConstantIntDivisibility &)>;
+
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
new file mode 100644
index 0000000000000..c665475e0fd7f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
@@ -0,0 +1,41 @@
+//===- InferIntDivisibilityOpInterface.td - Integer Divisibility -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for divisibility analysis on scalar integers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
+#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def InferIntDivisibilityOpInterface :
+ OpInterface<"InferIntDivisibilityOpInterface"> {
+ let description = [{
+ Allows operations to participate in integer divisibility analysis.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Infer the divisibility of the results of this op given the
+ divisibility of its arguments. For each result value, the method
+ should call `setResultDivs` with that `Value` as an argument.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"inferResultDivisibility",
+ /*args=*/(ins
+ "::llvm::ArrayRef<::mlir::IntegerDivisibility>":$argDivs,
+ "::mlir::SetIntDivisibilityFn":$setResultDivs)
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index db10ebcf2c311..596ffaff428b5 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/DenseAnalysis.cpp
+ DataFlow/IntegerDivisibilityAnalysis.cpp
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
@@ -37,6 +38,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/DenseAnalysis.cpp
+ DataFlow/IntegerDivisibilityAnalysis.cpp
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
@@ -53,6 +55,7 @@ add_mlir_library(MLIRAnalysis
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
+ MLIRInferIntDivisibilityOpInterface
MLIRInferIntRangeInterface
MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
diff --git a/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
new file mode 100644
index 0000000000000..ba10a8b5a0060
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
@@ -0,0 +1,135 @@
+//===- IntegerDivisibilityAnalysis.cpp - Integer divisibility ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer divisibility
+// inference. Operations participate in the analysis by implementing
+// `InferIntDivisibilityOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h"
+
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-divisibility-analysis"
+
+using llvm::dbgs;
+
+namespace mlir::dataflow {
+
+void IntegerDivisibilityAnalysis::setToEntryState(
+ IntegerDivisibilityLattice *lattice) {
+ propagateIfChanged(lattice,
+ lattice->join(IntegerDivisibility::getMinDivisibility()));
+}
+
+LogicalResult IntegerDivisibilityAnalysis::visitOperation(
+ Operation *op, ArrayRef<const IntegerDivisibilityLattice *> operands,
+ ArrayRef<IntegerDivisibilityLattice *> results) {
+ auto inferrable = dyn_cast<InferIntDivisibilityOpInterface>(op);
+ if (!inferrable) {
+ setAllToEntryStates(results);
+ return success();
+ }
+
+ LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n");
+ auto argDivs = llvm::map_to_vector(
+ operands, [](const IntegerDivisibilityLattice *lattice) {
+ return lattice->getValue();
+ });
+ auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) {
+ auto result = dyn_cast<OpResult>(v);
+ if (!result) {
+ return;
+ }
+ assert(llvm::is_contained(op->getResults(), result));
+
+ LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n");
+ IntegerDivisibilityLattice *lattice = results[result.getResultNumber()];
+ IntegerDivisibility oldDiv = lattice->getValue();
+
+ ChangeResult changed = lattice->join(newDiv);
+
+ // Catch loop results with loop-variant divisibility and conservatively
+ // set them to divisibility 1 (no information) so we don't ratchet
+ // indefinitely (the dataflow analysis in MLIR doesn't attempt to work
+ // out trip counts and often can't).
+ bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedResult && !oldDiv.isUninitialized() &&
+ !(lattice->getValue() == oldDiv)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ changed |= lattice->join(IntegerDivisibility::getMinDivisibility());
+ }
+ propagateIfChanged(lattice, changed);
+ };
+
+ inferrable.inferResultDivisibility(argDivs, joinCallback);
+ return success();
+}
+
+void IntegerDivisibilityAnalysis::visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor, ValueRange successorInputs,
+ ArrayRef<IntegerDivisibilityLattice *> argLattices) {
+ // Get the constant divisibility, or query the lattice for Values.
+ auto getDivFromOfr = [&](std::optional<OpFoldResult> ofr, Block *block,
+ bool isUnsigned) -> uint64_t {
+ if (ofr.has_value()) {
+ if (auto constBound = getConstantIntValue(*ofr)) {
+ return constBound.value();
+ }
+ auto value = cast<Value>(ofr.value());
+ const IntegerDivisibilityLattice *lattice =
+ getLatticeElementFor(getProgramPointBefore(block), value);
+ if (lattice != nullptr && !lattice->getValue().isUninitialized()) {
+ return isUnsigned ? lattice->getValue().getValue().udiv()
+ : lattice->getValue().getValue().sdiv();
+ }
+ }
+ return isUnsigned
+ ? IntegerDivisibility::getMinDivisibility().getValue().udiv()
+ : IntegerDivisibility::getMinDivisibility().getValue().sdiv();
+ };
+
+ // Infer the divisibility of loop induction variables from their lower
+ // bounds and steps.
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+ std::optional<SmallVector<Value>> ivs = loop.getLoopInductionVars();
+ std::optional<SmallVector<OpFoldResult>> lbs = loop.getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> steps = loop.getLoopSteps();
+ if (!ivs || !lbs || !steps) {
+ return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, successorInputs, argLattices);
+ }
+ for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) {
+ IntegerDivisibilityLattice *ivEntry = getLatticeElement(iv);
+ Block *block = iv.getParentBlock();
+ uint64_t stepUDiv = getDivFromOfr(step, block, /*isUnsigned=*/true);
+ uint64_t stepSDiv = getDivFromOfr(step, block, /*isUnsigned=*/false);
+ uint64_t lbUDiv = getDivFromOfr(lb, block, /*isUnsigned=*/true);
+ uint64_t lbSDiv = getDivFromOfr(lb, block, /*isUnsigned=*/false);
+ ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv);
+ ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv);
+
+ // Loop induction variables are computed as `lb + i * step`. The
+ // divisibility for `i * step` is just the divisibility of `step`, so
+ // the total divisibility is obtained by unioning the step divisibility
+ // with the lower bound divisibility, which takes the GCD of the two.
+ ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv);
+ propagateIfChanged(ivEntry, ivEntry->join(ivDiv));
+ }
+ return;
+ }
+
+ return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, successorInputs, argLattices);
+}
+
+} // namespace mlir::dataflow
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 566bc060e5d38..1caf2fa396797 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect
AffineMemoryOpInterfaces.cpp
AffineOps.cpp
AffineValueMap.cpp
+ InferIntDivisibilityOpInterfaceImpl.cpp
InferIntRangeInterfaceImpls.cpp
ValueBoundsOpInterfaceImpl.cpp
@@ -16,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineDialect
MLIRArithDialect
MLIRDialectUtils
MLIRIR
+ MLIRInferIntDivisibilityOpInterface
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
diff --git a/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..30850cf0a4df0
--- /dev/null
+++ b/mlir/lib/D...
[truncated]
``````````
</details>
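The induction-variable rule in `visitNonControlFlowArguments` (the IV is `lb + i * step`, so its divisor is the GCD of the lower-bound and step divisors) can be checked numerically with a small standalone sketch. `ivDivisibility` and `ivValuesDivisibleBy` are hypothetical helpers over unsigned mathematical integers (no wraparound). As in `getDivFromOfr`, a constant bound is used directly as its own divisor, so a lower bound of 0 contributes `gcd(0, step) = step`.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>

// Divisor guaranteed for every value of iv = lb + k * step, given divisors
// of lb and step: a common divisor of both divides every term.
uint64_t ivDivisibility(uint64_t lbDiv, uint64_t stepDiv) {
  return std::gcd(stepDiv, lbDiv);
}

// Brute-force check over the first `tripCount` iterations of the loop.
bool ivValuesDivisibleBy(uint64_t lb, uint64_t step, uint64_t tripCount,
                         uint64_t divisor) {
  assert(divisor != 0 && "a divisor of 0 is not meaningful here");
  for (uint64_t k = 0; k < tripCount; ++k)
    if ((lb + k * step) % divisor != 0)
      return false;
  return true;
}
```

For a loop starting at 12 with step 8, the rule gives 4: every IV value is a multiple of 4, but not of 8, since the first value is 12.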
https://github.com/llvm/llvm-project/pull/197728