[Mlir-commits] [llvm] [mlir] Reland "[MLIR] Add `IntegerDivisibilityAnalysis` and `InferIntDivisibilityOpInterface`" (PR #198110)
Alan Li
llvmlistbot at llvm.org
Mon May 18 08:31:07 PDT 2026
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/198110
>From bb2e3a1802fd01e481641e34f74e7866f4d9ed8d Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Sat, 16 May 2026 09:47:24 -0700
Subject: [PATCH 1/4] Reland "[MLIR] Add `IntegerDivisibilityAnalysis` and
`InferIntDivisibilityOpInterface`"
This relands llvm/llvm-project#197728 (reverted in #198048).
The previous landing broke shared-library builds because
`mlir/lib/Analysis/CMakeLists.txt` did not link `MLIRDialectUtils`
even though `IntegerDivisibilityAnalysis.cpp` calls
`mlir::getConstantIntValue(OpFoldResult)` (defined in
`mlir/lib/Dialect/Utils/StaticValueUtils.cpp`). The link dependency
is added in a follow-up commit.
---
.../DataFlow/IntegerDivisibilityAnalysis.h | 64 ++++
.../mlir/Dialect/Affine/IR/AffineOps.td | 12 +-
mlir/include/mlir/Dialect/Arith/IR/Arith.h | 1 +
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 28 +-
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 +
.../InferIntDivisibilityOpInterface.h | 120 +++++++
.../InferIntDivisibilityOpInterface.td | 41 +++
mlir/lib/Analysis/CMakeLists.txt | 3 +
.../DataFlow/IntegerDivisibilityAnalysis.cpp | 135 ++++++++
mlir/lib/Dialect/Affine/IR/CMakeLists.txt | 2 +
.../InferIntDivisibilityOpInterfaceImpl.cpp | 312 ++++++++++++++++++
mlir/lib/Dialect/Arith/IR/CMakeLists.txt | 3 +
.../InferIntDivisibilityOpInterfaceImpl.cpp | 122 +++++++
mlir/lib/Interfaces/CMakeLists.txt | 2 +
.../InferIntDivisibilityOpInterface.cpp | 11 +
.../DataFlow/integer-divisibility.mlir | 152 +++++++++
mlir/test/lib/Analysis/CMakeLists.txt | 2 +
.../TestIntegerDivisibilityAnalysis.cpp | 93 ++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
19 files changed, 1095 insertions(+), 11 deletions(-)
create mode 100644 mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
create mode 100644 mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
create mode 100644 mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
create mode 100644 mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
create mode 100644 mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp
create mode 100644 mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp
create mode 100644 mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp
create mode 100644 mlir/test/Analysis/DataFlow/integer-divisibility.mlir
create mode 100644 mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
new file mode 100644
index 0000000000000..3a877647490a3
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h
@@ -0,0 +1,64 @@
+//===- IntegerDivisibilityAnalysis.h - Integer divisibility -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer divisibility
+// inference. Operations participate in the analysis by implementing
+// `InferIntDivisibilityOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
+
+namespace mlir::dataflow {
+
+/// Lattice element holding the `IntegerDivisibility` state tracked for a
+/// single SSA value. It adds nothing to `Lattice<IntegerDivisibility>` other
+/// than a distinct type name for the analysis to use.
+class IntegerDivisibilityLattice : public Lattice<IntegerDivisibility> {
+public:
+ // Inherit the constructors of the base lattice unchanged.
+ using Lattice::Lattice;
+};
+
+/// Integer divisibility analysis determines, for each integer-typed SSA
+/// value, a divisor that the value is guaranteed to be a multiple of. It
+/// uses operations that implement `InferIntDivisibilityOpInterface` and
+/// also sets the divisibility of the induction variables of loops that
+/// expose their induction variables, lower bounds and steps through
+/// `LoopLikeOpInterface`.
+///
+/// A divisor of 1 is the most pessimistic state: it carries no information.
+///
+/// This analysis depends on DeadCodeAnalysis, and will be a silent no-op
+/// if DeadCodeAnalysis is not loaded in the same solver context.
+class IntegerDivisibilityAnalysis
+ : public SparseForwardDataFlowAnalysis<IntegerDivisibilityLattice> {
+public:
+ // Reuse the constructors of the sparse forward analysis base class.
+ using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+ /// At an entry point, set the lattice to the most pessimistic state
+ /// (divisibility 1), indicating that no further reasoning can be done.
+ void setToEntryState(IntegerDivisibilityLattice *lattice) override;
+
+ /// Visit an operation, invoking the transfer function. Results of
+ /// operations that do not implement `InferIntDivisibilityOpInterface` are
+ /// set to their entry state; otherwise the divisibilities reported by the
+ /// interface are joined into the result lattices.
+ LogicalResult
+ visitOperation(Operation *op,
+ ArrayRef<const IntegerDivisibilityLattice *> operands,
+ ArrayRef<IntegerDivisibilityLattice *> results) override;
+
+ /// Visit block arguments or operation results of an operation with region
+ /// control-flow for which values are not defined by region control-flow. This
+ /// function tries to infer the divisibility of loop induction variables from
+ /// the divisibility of the loop lower bounds and steps, and defers to the
+ /// base class for everything else.
+ void visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ValueRange successorInputs,
+ ArrayRef<IntegerDivisibilityLattice *> argLattices) override;
+};
+
+} // namespace mlir::dataflow
+
+#endif // MLIR_ANALYSIS_DATAFLOW_INTEGERDIVISIBILITYANALYSIS_H
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index b2a4cf7f488bd..3d7cbcc375d2a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -16,6 +16,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntDivisibilityOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
@@ -43,7 +44,9 @@ def ImplicitAffineTerminator
: SingleBlockImplicitTerminator<"AffineYieldOp">;
def AffineApplyOp : Affine_Op<"apply",
- [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
+ [Pure,
+ DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "affine apply operation";
let description = [{
The `affine.apply` operation applies an [affine mapping](#affine-maps)
@@ -570,7 +573,8 @@ class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
let hasVerifier = 1;
}
-def AffineMinOp : AffineMinMaxOpBase<"min", [Pure]> {
+def AffineMinOp : AffineMinMaxOpBase<"min",
+ [Pure, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "min operation";
let description = [{
Syntax:
@@ -594,7 +598,8 @@ def AffineMinOp : AffineMinMaxOpBase<"min", [Pure]> {
}];
}
-def AffineMaxOp : AffineMinMaxOpBase<"max", [Pure]> {
+def AffineMaxOp : AffineMinMaxOpBase<"max",
+ [Pure, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "max operation";
let description = [{
The `affine.max` operation computes the maximum value result from a multi-result
@@ -1071,6 +1076,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
[Pure, Elementwise,
+ DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>,
// Infer linear_index type from the first result type during parsing.
TypesMatchWith<"linear_index type must match result types",
"multi_index", "linear_index", "$_self[0]">
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 0fc3db8e993d8..bf6eb18df2e8a 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -15,6 +15,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index fa85b840e2707..1f8b07aed3f0d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntDivisibilityOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -223,6 +224,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
[ConstantLike, Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllTypesMatch<["value", "result"]>,
+ DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "integer or floating point constant";
let description = [{
@@ -270,7 +272,8 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// AddIOp
//===----------------------------------------------------------------------===//
-def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
+def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi",
+ [Commutative, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "integer addition operation";
let description = [{
Performs N-bit addition on the operands. The operands are interpreted as
@@ -416,7 +419,8 @@ def Arith_SubUIExtendedOp : Arith_Op<"subui_extended", [Pure,
// SubIOp
//===----------------------------------------------------------------------===//
-def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
+def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi",
+ [DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = [{
Integer subtraction operation.
}];
@@ -461,7 +465,9 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
//===----------------------------------------------------------------------===//
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
- [Commutative, DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]
+ [Commutative,
+ DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]
> {
let summary = [{
Integer multiplication operation.
@@ -593,7 +599,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
//===----------------------------------------------------------------------===//
def Arith_DivUIOp : Arith_IntBinaryOpWithExactFlag<"divui",
- [ConditionallySpeculatable]> {
+ [ConditionallySpeculatable,
+ DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "unsigned integer division operation";
let description = [{
Unsigned integer division. Rounds towards zero. Treats the leading bit as
@@ -1191,7 +1198,8 @@ def Arith_MaxNumFOp : Arith_FloatBinaryOp<"maxnumf", [Commutative]> {
// MaxSIOp
//===----------------------------------------------------------------------===//
-def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
+def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi",
+ [Commutative, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "signed integer maximum operation";
let hasFolder = 1;
}
@@ -1200,7 +1208,8 @@ def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> {
// MaxUIOp
//===----------------------------------------------------------------------===//
-def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> {
+def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui",
+ [Commutative, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "unsigned integer maximum operation";
let hasFolder = 1;
}
@@ -1250,7 +1259,8 @@ def Arith_MinNumFOp : Arith_FloatBinaryOp<"minnumf", [Commutative]> {
// MinSIOp
//===----------------------------------------------------------------------===//
-def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
+def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi",
+ [Commutative, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "signed integer minimum operation";
let hasFolder = 1;
}
@@ -1259,7 +1269,8 @@ def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> {
// MinUIOp
//===----------------------------------------------------------------------===//
-def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
+def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui",
+ [Commutative, DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>]> {
let summary = "unsigned integer minimum operation";
let hasFolder = 1;
}
@@ -2004,6 +2015,7 @@ class BooleanConditionOrMatchingShape<string condition, string result> :
def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
BooleanConditionOrMatchingShape<"condition", "result">,
+ DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
let summary = "select operation";
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 3cbc9df05f3d7..6461c68423c73 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
+add_mlir_interface(InferIntDivisibilityOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferStridedMetadataInterface)
add_mlir_interface(InferTypeOpInterface)
diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
new file mode 100644
index 0000000000000..374acee05cb10
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.h
@@ -0,0 +1,120 @@
+//===- InferIntDivisibilityOpInterface.h - Integer Divisibility -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of the integer divisibility inference
+// interface defined in `InferIntDivisibilityOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
+#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+#include <numeric>
+#include <optional>
+
+namespace mlir {
+
+/// Statically known divisibility information for an integer SSA value.
+/// Tracks separate divisors for the unsigned and signed interpretations of
+/// the value so that subsequent analyses can use whichever is more precise.
+///
+/// A default-constructed object holds the divisor 1 for both interpretations,
+/// i.e. it makes no claim about the value.
+class ConstantIntDivisibility {
+public:
+  ConstantIntDivisibility() = default;
+  ConstantIntDivisibility(uint64_t udiv, uint64_t sdiv)
+      : udivVal(udiv), sdivVal(sdiv) {}
+
+  bool operator==(const ConstantIntDivisibility &other) const {
+    return udivVal == other.udivVal && sdivVal == other.sdivVal;
+  }
+
+  /// Divisor of the value under its unsigned interpretation.
+  uint64_t udiv() const { return udivVal; }
+  /// Divisor of the value under its signed interpretation.
+  uint64_t sdiv() const { return sdivVal; }
+
+  /// Returns the divisibility that holds for a value described by either
+  /// `this` or `other`: the GCD of the two divisors, computed independently
+  /// for the unsigned and the signed component.
+  ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const {
+    return ConstantIntDivisibility(
+        /*udiv=*/std::gcd(udivVal, other.udivVal),
+        /*sdiv=*/std::gcd(sdivVal, other.sdivVal));
+  }
+
+private:
+  // Initialized in-class so that `ConstantIntDivisibility()` never exposes
+  // indeterminate values. 1 divides every integer, so it is always a sound
+  // default.
+  uint64_t udivVal = 1;
+  uint64_t sdivVal = 1;
+
+  friend raw_ostream &operator<<(raw_ostream &os,
+                                 const ConstantIntDivisibility &div);
+};
+
+/// Streams `div` as `ConstantIntDivisibility(udiv = <u>, sdiv = <s>)`.
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const ConstantIntDivisibility &div) {
+  return os << "ConstantIntDivisibility(udiv = " << div.udiv()
+            << ", sdiv = " << div.sdiv() << ")";
+}
+
+/// Lattice value describing the integer divisibility of an SSA value. It is
+/// either uninitialized (nothing has been inferred yet) or holds a
+/// `ConstantIntDivisibility`.
+class IntegerDivisibility {
+public:
+  /// Implicitly wraps a known divisibility.
+  IntegerDivisibility(ConstantIntDivisibility value)
+      : value(std::move(value)) {}
+  /// Wraps an optional divisibility; the default argument yields the
+  /// uninitialized state.
+  explicit IntegerDivisibility(
+      std::optional<ConstantIntDivisibility> value = std::nullopt)
+      : value(std::move(value)) {}
+
+  /// Returns the minimum divisibility of 1, which is used to indicate that
+  /// the value cannot be analyzed further.
+  static IntegerDivisibility getMinDivisibility() {
+    return IntegerDivisibility(ConstantIntDivisibility(1, 1));
+  }
+
+  /// True while no divisibility has been recorded.
+  bool isUninitialized() const { return !value.has_value(); }
+
+  /// Returns the recorded divisibility; must not be called on an
+  /// uninitialized value.
+  const ConstantIntDivisibility &getValue() const {
+    assert(!isUninitialized());
+    return *value;
+  }
+
+  bool operator==(const IntegerDivisibility &rhs) const {
+    return value == rhs.value;
+  }
+
+  /// Joins two lattice values. An uninitialized side is ignored; when both
+  /// sides are known the result is the union (GCD) of their divisibilities.
+  static IntegerDivisibility join(const IntegerDivisibility &lhs,
+                                  const IntegerDivisibility &rhs) {
+    const bool lhsKnown = !lhs.isUninitialized();
+    const bool rhsKnown = !rhs.isUninitialized();
+    if (lhsKnown && rhsKnown)
+      return IntegerDivisibility(lhs.getValue().getUnion(rhs.getValue()));
+    return lhsKnown ? lhs : rhs;
+  }
+
+  void print(raw_ostream &os) const { os << value; }
+
+private:
+  std::optional<ConstantIntDivisibility> value;
+};
+
+/// Streams `div` by forwarding to `IntegerDivisibility::print`.
+inline raw_ostream &operator<<(raw_ostream &os,
+ const IntegerDivisibility &div) {
+ div.print(os);
+ return os;
+}
+
+/// The type of the `setResultDivs` callback provided to ops implementing
+/// InferIntDivisibilityOpInterface. It should be called once for each integer
+/// result value and be passed the ConstantIntDivisibility corresponding to
+/// that value.
+using SetIntDivisibilityFn =
+ llvm::function_ref<void(Value, const ConstantIntDivisibility &)>;
+
+} // end namespace mlir
+
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE_H
diff --git a/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
new file mode 100644
index 0000000000000..c665475e0fd7f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/InferIntDivisibilityOpInterface.td
@@ -0,0 +1,41 @@
+//===- InferIntDivisibilityOpInterface.td - Integer Divisibility -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interface for divisibility analysis on scalar integers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
+#define MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+// Op interface with a single method, `inferResultDivisibility`, through which
+// an operation reports the divisibility of its results.
+def InferIntDivisibilityOpInterface :
+ OpInterface<"InferIntDivisibilityOpInterface"> {
+ let description = [{
+ Allows operations to participate in integer divisibility analysis.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ // void inferResultDivisibility(argDivs, setResultDivs): `argDivs` carries
+ // the argument divisibilities and `setResultDivs` receives the results.
+ InterfaceMethod<
+ /*desc=*/[{
+ Infer the divisibility of the results of this op given the
+ divisibility of its arguments. For each result value, the method
+ should call `setResultDivs` with that `Value` as an argument.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"inferResultDivisibility",
+ /*args=*/(ins
+ "::llvm::ArrayRef<::mlir::IntegerDivisibility>":$argDivs,
+ "::mlir::SetIntDivisibilityFn":$setResultDivs)
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_INFERINTDIVISIBILITYOPINTERFACE
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index db10ebcf2c311..596ffaff428b5 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_OPTIONAL_SOURCES
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/DenseAnalysis.cpp
+ DataFlow/IntegerDivisibilityAnalysis.cpp
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
@@ -37,6 +38,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/DenseAnalysis.cpp
+ DataFlow/IntegerDivisibilityAnalysis.cpp
DataFlow/IntegerRangeAnalysis.cpp
DataFlow/LivenessAnalysis.cpp
DataFlow/SparseAnalysis.cpp
@@ -53,6 +55,7 @@ add_mlir_library(MLIRAnalysis
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRFunctionInterfaces
+ MLIRInferIntDivisibilityOpInterface
MLIRInferIntRangeInterface
MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
diff --git a/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
new file mode 100644
index 0000000000000..ba10a8b5a0060
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/IntegerDivisibilityAnalysis.cpp
@@ -0,0 +1,135 @@
+//===- IntegerDivisibilityAnalysis.cpp - Integer divisibility ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer divisibility
+// inference. Operations participate in the analysis by implementing
+// `InferIntDivisibilityOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h"
+
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-divisibility-analysis"
+
+using llvm::dbgs;
+
+namespace mlir::dataflow {
+
+// Entry states carry no information: join the minimum divisibility (1) into
+// the lattice and notify dependents if that changed it.
+void IntegerDivisibilityAnalysis::setToEntryState(
+    IntegerDivisibilityLattice *lattice) {
+  ChangeResult changed =
+      lattice->join(IntegerDivisibility::getMinDivisibility());
+  propagateIfChanged(lattice, changed);
+}
+
+// Transfer function of the analysis. Operations that do not implement
+// `InferIntDivisibilityOpInterface` get all of their results reset to the
+// entry state; for all others the interface is asked to infer the result
+// divisibilities, which are then joined into the result lattices.
+LogicalResult IntegerDivisibilityAnalysis::visitOperation(
+    Operation *op, ArrayRef<const IntegerDivisibilityLattice *> operands,
+    ArrayRef<IntegerDivisibilityLattice *> results) {
+  auto inferrable = dyn_cast<InferIntDivisibilityOpInterface>(op);
+  if (!inferrable) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n");
+
+  // Snapshot the current operand states; they are the inputs of the inference.
+  SmallVector<IntegerDivisibility> argDivs;
+  argDivs.reserve(operands.size());
+  for (const IntegerDivisibilityLattice *operand : operands)
+    argDivs.push_back(operand->getValue());
+
+  // Invoked by the op for every result it can say something about.
+  auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) {
+    auto result = dyn_cast<OpResult>(v);
+    if (!result)
+      return;
+    assert(llvm::is_contained(op->getResults(), result));
+
+    LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n");
+    IntegerDivisibilityLattice *lattice = results[result.getResultNumber()];
+    const IntegerDivisibility previous = lattice->getValue();
+    ChangeResult changed = lattice->join(newDiv);
+
+    // A result that is consumed by a terminator and whose already-initialized
+    // state just moved is treated as a loop-variant loop result. Such values
+    // are conservatively dropped to divisibility 1 (no information) so the
+    // state does not ratchet indefinitely: the dataflow framework does not
+    // try to work out trip counts and often cannot.
+    const bool feedsTerminator =
+        llvm::any_of(v.getUsers(), [](Operation *user) {
+          return user->hasTrait<OpTrait::IsTerminator>();
+        });
+    const bool stateMoved =
+        !previous.isUninitialized() && !(lattice->getValue() == previous);
+    if (feedsTerminator && stateMoved) {
+      LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+      changed |= lattice->join(IntegerDivisibility::getMinDivisibility());
+    }
+    propagateIfChanged(lattice, changed);
+  };
+
+  inferrable.inferResultDivisibility(argDivs, joinCallback);
+  return success();
+}
+
+// Infers the divisibility of loop induction variables. For an op implementing
+// `LoopLikeOpInterface` that exposes its induction variables, lower bounds and
+// steps, every induction variable receives the GCD of the divisibility of its
+// lower bound and of its step. All other cases are forwarded to the base
+// class implementation.
+void IntegerDivisibilityAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &successor, ValueRange successorInputs,
+    ArrayRef<IntegerDivisibilityLattice *> argLattices) {
+  // Returns the divisor known for `ofr` under the requested interpretation:
+  // derived from the constant when `ofr` is one, read from the lattice when
+  // it is an SSA value with an initialized state, and 1 (no information)
+  // otherwise.
+  auto getDivFromOfr = [&](std::optional<OpFoldResult> ofr, Block *block,
+                           bool isUnsigned) -> uint64_t {
+    if (ofr.has_value()) {
+      if (std::optional<int64_t> constBound = getConstantIntValue(*ofr)) {
+        uint64_t bits = static_cast<uint64_t>(*constBound);
+        if (*constBound >= 0)
+          return bits;
+        // The constant is negative, so its raw two's-complement bits must not
+        // be used as a divisor: e.g. gcd(2^64 - 1, 3) == 3 although -1 is not
+        // a multiple of 3.
+        // Signed interpretation: a value is divisible by its magnitude.
+        // `0 - bits` is |c| and is well defined even for INT64_MIN.
+        if (!isUnsigned)
+          return 0 - bits;
+        // Unsigned interpretation: the value is 2^N - |c| for a bit width N
+        // that is not known here. The largest power of two dividing |c|
+        // divides that value for every such N.
+        return bits & (0 - bits);
+      }
+      auto value = cast<Value>(ofr.value());
+      const IntegerDivisibilityLattice *lattice =
+          getLatticeElementFor(getProgramPointBefore(block), value);
+      if (lattice != nullptr && !lattice->getValue().isUninitialized()) {
+        return isUnsigned ? lattice->getValue().getValue().udiv()
+                          : lattice->getValue().getValue().sdiv();
+      }
+    }
+    return isUnsigned
+               ? IntegerDivisibility::getMinDivisibility().getValue().udiv()
+               : IntegerDivisibility::getMinDivisibility().getValue().sdiv();
+  };
+
+  // Infer the divisibility of the induction variables of loop-like ops.
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    std::optional<SmallVector<Value>> ivs = loop.getLoopInductionVars();
+    std::optional<SmallVector<OpFoldResult>> lbs = loop.getLoopLowerBounds();
+    std::optional<SmallVector<OpFoldResult>> steps = loop.getLoopSteps();
+    if (!ivs || !lbs || !steps) {
+      return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+          op, successor, successorInputs, argLattices);
+    }
+    for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) {
+      IntegerDivisibilityLattice *ivEntry = getLatticeElement(iv);
+      Block *block = iv.getParentBlock();
+      uint64_t stepUDiv = getDivFromOfr(step, block, /*isUnsigned=*/true);
+      uint64_t stepSDiv = getDivFromOfr(step, block, /*isUnsigned=*/false);
+      uint64_t lbUDiv = getDivFromOfr(lb, block, /*isUnsigned=*/true);
+      uint64_t lbSDiv = getDivFromOfr(lb, block, /*isUnsigned=*/false);
+      ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv);
+      ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv);
+
+      // Loop induction variables are computed as `lb + i * step`. The
+      // divisibility for `i * step` is just the divisibility of `step`, so
+      // the total divisibility is obtained by unioning the step divisibility
+      // with the lower bound divisibility, which takes the GCD of the two.
+      ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv);
+      propagateIfChanged(ivEntry, ivEntry->join(ivDiv));
+    }
+    return;
+  }
+
+  return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+      op, successor, successorInputs, argLattices);
+}
+
+} // namespace mlir::dataflow
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
index 566bc060e5d38..1caf2fa396797 100644
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRAffineDialect
AffineMemoryOpInterfaces.cpp
AffineOps.cpp
AffineValueMap.cpp
+ InferIntDivisibilityOpInterfaceImpl.cpp
InferIntRangeInterfaceImpls.cpp
ValueBoundsOpInterfaceImpl.cpp
@@ -16,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineDialect
MLIRArithDialect
MLIRDialectUtils
MLIRIR
+ MLIRInferIntDivisibilityOpInterface
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRLoopLikeInterface
diff --git a/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..451ec98ef3737
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/IR/InferIntDivisibilityOpInterfaceImpl.cpp
@@ -0,0 +1,312 @@
+//===- InferIntDivisibilityOpInterfaceImpl.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Direct implementations of `InferIntDivisibilityOpInterface` for affine ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
+
+#include <cstdlib>
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+static ConstantIntDivisibility
+getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility) {
+ if (!divisibility.isUninitialized())
+ return divisibility.getValue();
+ APInt intVal;
+ if (matchPattern(v, m_ConstantInt(&intVal))) {
+ uint64_t udiv = intVal.getZExtValue();
+ uint64_t sdiv = std::abs(intVal.getSExtValue());
+ return ConstantIntDivisibility(udiv, sdiv);
+ }
+ return ConstantIntDivisibility(1, 1);
+}
+
+/// Visits affine expressions and recursively calculates the divisibilities of
+/// each subexpression. The final divisibilities of the expression and its
+/// subexpressions will be stored in the map for which a reference is provided
+/// to the AffineExprDivisibilityFinder (i.e., `divisibilityMap`).
+class AffineExprDivisibilityFinder
+ : public AffineExprVisitor<AffineExprDivisibilityFinder,
+ ConstantIntDivisibility> {
+public:
+ using ExprDivisibilityMap =
+ llvm::DenseMap<AffineExpr, ConstantIntDivisibility>;
+ AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap)
+ : divisibilityMap(divisibilityMap) {}
+
+ ConstantIntDivisibility visitConstantExpr(AffineConstantExpr expr) {
+ // Constant expressions are trivial, since they are always static.
+ uint64_t constValue = std::abs(expr.getValue());
+ return ConstantIntDivisibility(constValue, constValue);
+ }
+
+ ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) {
+ // Dim expressions cannot be analyzed further, so return the divisibility
+ // in `divisibilityMap` if it has been populated by the caller, or fallback
+ // to the minimum divisibility.
+ if (divisibilityMap.contains(expr))
+ return divisibilityMap[expr];
+ return IntegerDivisibility::getMinDivisibility().getValue();
+ }
+
+ ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) {
+ // Symbol expressions cannot be analyzed further, so return the divisibility
+ // in `divisibilityMap` if it has been populated by the caller, or fallback
+ // to the minimum divisibility.
+ if (divisibilityMap.contains(expr))
+ return divisibilityMap[expr];
+ return IntegerDivisibility::getMinDivisibility().getValue();
+ }
+
+ /// Infer the divisibility of an addition or subtraction expression by
+ /// recursively visiting the LHS and RHS, and then unioning the results.
+ ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) {
+ if (divisibilityMap.contains(expr))
+ return divisibilityMap[expr];
+ // The divisibility of an addition is the GCD of its constituents'
+ // divisibilities.
+ ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+ ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
+ return lhsDiv.getUnion(rhsDiv);
+ }
+
+ /// Infer the divisibility of a multiplication expression by recursively
+ /// visiting the LHS and RHS, and then multiplying the results.
+ ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) {
+ if (divisibilityMap.contains(expr))
+ return divisibilityMap[expr];
+ // The divisibility of a multiplication is the product of its constituents'
+ // divisibilities.
+ ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+ ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
+ return ConstantIntDivisibility(lhsDiv.udiv() * rhsDiv.udiv(),
+ lhsDiv.sdiv() * rhsDiv.sdiv());
+ }
+
+ ConstantIntDivisibility visitFloorDivExpr(AffineBinaryOpExpr expr) {
+ return visitDivExpr(expr);
+ }
+
+ ConstantIntDivisibility visitCeilDivExpr(AffineBinaryOpExpr expr) {
+ return visitDivExpr(expr);
+ }
+
+ /// Infer the divisibility of a mod expression. If the RHS is a constant,
+ /// the result divisibility is gcd(lhs_divisibility, rhs_constant), since
+ /// (d * k) mod c is always divisible by gcd(d, c). Furthermore, if the
+ /// LHS divisibility is itself divisible by the constant (i.e., d % c == 0),
+ /// then (d * k) mod c is always zero, represented as divisibility 0.
+ ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) {
+ if (divisibilityMap.contains(expr))
+ return divisibilityMap[expr];
+ auto constRhs = dyn_cast<AffineConstantExpr>(expr.getRHS());
+ if (!constRhs || constRhs.getValue() == 0)
+ return ConstantIntDivisibility(1, 1);
+ auto constValue = static_cast<uint64_t>(std::abs(constRhs.getValue()));
+ ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+ // If the LHS is always a multiple of constValue, x mod constValue is
+ // always zero. Divisibility 0 is the lattice top (every integer divides 0).
+ uint64_t modUDiv = (lhsDiv.udiv() % constValue == 0)
+ ? 0
+ : std::gcd(lhsDiv.udiv(), constValue);
+ uint64_t modSDiv = (lhsDiv.sdiv() % constValue == 0)
+ ? 0
+ : std::gcd(lhsDiv.sdiv(), constValue);
+ return ConstantIntDivisibility(modUDiv, modSDiv);
+ }
+
+private:
+ ConstantIntDivisibility visitInvalidExpr(AffineBinaryOpExpr expr) {
+ return IntegerDivisibility::getMinDivisibility().getValue();
+ }
+
+ /// Helper shared by ceildiv and floordiv implementations. Returns the minimum
+ /// divisibility as a fallback if the divisor is not a constant or is zero,
+ /// because the divisibility cannot be inferred in this case. Otherwise, this
+ /// function recursively visits the dividend and returns the dividend's
+ /// divisibility divided by |divisor| when that division is exact, or 1 if not.
+ ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) {
+ if (divisibilityMap.contains(expr))
+ return divisibilityMap[expr];
+ auto constRhs = dyn_cast<AffineConstantExpr>(expr.getRHS());
+ // Division by zero is undefined, so return the minimum divisibility.
+ if (!constRhs || constRhs.getValue() == 0)
+ return ConstantIntDivisibility(1, 1);
+ auto constValue = static_cast<uint64_t>(std::abs(constRhs.getValue()));
+ ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+ uint64_t divUDiv =
+ lhsDiv.udiv() % constValue == 0 ? lhsDiv.udiv() / constValue : 1;
+ uint64_t divSDiv =
+ lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1;
+ return ConstantIntDivisibility(divUDiv, divSDiv);
+ }
+
+ ExprDivisibilityMap &divisibilityMap;
+};
+
+/// Returns the divisibilities of each AffineMap result based on the
+/// divisibilities of its dims and symbols. The `dimAndSymbolDivisibilities`
+/// should contain the divisibilities of the dims, followed by the
+/// divisibilities of the symbols in ascending order by their positions.
+SmallVector<ConstantIntDivisibility> getResultDivisibilities(
+ AffineMap map,
+ ArrayRef<ConstantIntDivisibility> dimAndSymbolDivisibilities) {
+ // Seed the AffineExprDivisibilityFinder with the dimAndSymbolDivisibilities.
+ llvm::DenseMap<AffineExpr, ConstantIntDivisibility> exprDivisibilityMap;
+ SmallVector<AffineExpr> inputExprs;
+ inputExprs.append(llvm::map_to_vector(
+ llvm::seq<int64_t>(map.getNumDims()),
+ [&](int64_t dim) { return getAffineDimExpr(dim, map.getContext()); }));
+ inputExprs.append(llvm::map_to_vector(
+ llvm::seq<int64_t>(map.getNumSymbols()),
+ [&](int64_t sym) { return getAffineSymbolExpr(sym, map.getContext()); }));
+ for (auto [expr, divisibility] :
+ llvm::zip_equal(inputExprs, dimAndSymbolDivisibilities)) {
+ exprDivisibilityMap[expr] = divisibility;
+ }
+ AffineExprDivisibilityFinder divisibilityFinder(exprDivisibilityMap);
+
+ // Walk each result expression and compute their divisibilities.
+ SmallVector<ConstantIntDivisibility> resultDivisibilities;
+ for (AffineExpr resultExpr : map.getResults())
+ resultDivisibilities.push_back(divisibilityFinder.visit(resultExpr));
+ return resultDivisibilities;
+}
+
+/// Infer the result divisibility of an affine.min or affine.max operation
+/// based on its operand divisibilities. The result divisibility is the GCD
+/// of the divisibilities of each of the affine map results, because the result
+/// of the affine.min/max op could be any of these results.
+template <typename MinOrMaxTy>
+void inferAffineMinOrMaxResultDivisibility(
+ MinOrMaxTy minOrMaxOp, ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ static_assert(llvm::is_one_of<MinOrMaxTy, AffineMinOp, AffineMaxOp>::value,
+ "MinOrMaxTy must be AffineMinOp or AffineMaxOp");
+ SmallVector<ConstantIntDivisibility> operandDivisibilities;
+ for (auto [operand, divisibility] :
+ llvm::zip(minOrMaxOp.getOperands(), argDivs)) {
+ operandDivisibilities.push_back(
+ getDivisibilityOfOperand(operand, divisibility));
+ }
+
+ SmallVector<ConstantIntDivisibility> resultDivisibilities =
+ getResultDivisibilities(minOrMaxOp.getMap(), operandDivisibilities);
+
+ ConstantIntDivisibility resultDivisibility =
+ resultDivisibilities.pop_back_val();
+ for (auto divisibility : resultDivisibilities)
+ resultDivisibility = resultDivisibility.getUnion(divisibility);
+ setResultDivs(minOrMaxOp.getResult(), resultDivisibility);
+}
+
+} // namespace
+
+void AffineApplyOp::inferResultDivisibility(
+ ArrayRef<IntegerDivisibility> argDivs, SetIntDivisibilityFn setResultDivs) {
+ SmallVector<ConstantIntDivisibility> operandDivisibilities;
+ for (auto [operand, divisibility] : llvm::zip(getOperands(), argDivs)) {
+ operandDivisibilities.push_back(
+ getDivisibilityOfOperand(operand, divisibility));
+ }
+
+ SmallVector<ConstantIntDivisibility> resultDivisibilities =
+ getResultDivisibilities(getMap(), operandDivisibilities);
+ for (auto [result, divisibility] :
+ llvm::zip_equal(getOperation()->getResults(), resultDivisibilities)) {
+ setResultDivs(result, divisibility);
+ }
+}
+
+void AffineMinOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferAffineMinOrMaxResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void AffineMaxOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferAffineMinOrMaxResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void AffineDelinearizeIndexOp::inferResultDivisibility(
+ ArrayRef<IntegerDivisibility> argDivs, SetIntDivisibilityFn setResultDivs) {
+ MLIRContext *ctx = getContext();
+
+ // Operands are: [linear_index, dynamic_basis_values...]
+ ConstantIntDivisibility linearDiv =
+ getDivisibilityOfOperand(getLinearIndex(), argDivs[0]);
+
+ ArrayRef<int64_t> staticBasis = getStaticBasis();
+ int64_t numResults = getNumResults();
+
+ // Build affine expressions for each result.
+ // Dim 0 = linear index, symbols = dynamic basis values.
+ AffineExpr linearExpr = getAffineDimExpr(0, ctx);
+
+ // Collect operand divisibilities: [linear_index_div, dynamic_basis_divs...]
+ SmallVector<ConstantIntDivisibility> operandDivs;
+ operandDivs.push_back(linearDiv);
+
+ // Map static/dynamic basis values to affine expressions.
+ int64_t dynIdx = 0;
+ SmallVector<AffineExpr> basisExprs;
+ for (int64_t i = 0, e = static_cast<int64_t>(staticBasis.size()); i < e;
+ ++i) {
+ if (ShapedType::isDynamic(staticBasis[i])) {
+ basisExprs.push_back(getAffineSymbolExpr(dynIdx, ctx));
+ operandDivs.push_back(getDivisibilityOfOperand(getDynamicBasis()[dynIdx],
+ argDivs[1 + dynIdx]));
+ dynIdx++;
+ } else {
+ basisExprs.push_back(getAffineConstantExpr(staticBasis[i], ctx));
+ }
+ }
+
+ // The computation basis skips the outer bound if present.
+ bool hasOuter = hasOuterBound();
+ int64_t basisStart = hasOuter ? 1 : 0;
+
+ // Each result[i] can be expressed as an affine expression of the linear
+ // index using the effective basis (after dropping outer bound if present).
+ // Effective basis B[k] = basisExprs[basisStart + k], for k = 0..N-2.
+ // Stride s[i] = product of B[i..N-2] = product of
+ // basisExprs[basisStart+i .. end].
+ //
+ // result[0] = x floordiv s[0]
+ // result[i>0] = (x floordiv s[i]) mod B[i-1]
+ // For i=N-1, s[N-1]=1, so result[N-1] = x mod B[N-2].
+
+ AffineExpr stride = getAffineConstantExpr(1, ctx);
+ for (int64_t i = numResults - 1; i >= 0; --i) {
+ AffineExpr resultExpr;
+ if (i == 0) {
+ resultExpr = linearExpr.floorDiv(stride);
+ } else {
+ resultExpr =
+ (linearExpr.floorDiv(stride)) % basisExprs[basisStart + i - 1];
+ }
+
+ AffineMap resultMap = AffineMap::get(1, dynIdx, resultExpr, ctx);
+ SmallVector<ConstantIntDivisibility> divs =
+ getResultDivisibilities(resultMap, operandDivs);
+ setResultDivs(getResult(i), divs[0]);
+
+ if (i > 0)
+ stride = basisExprs[basisStart + i - 1] * stride;
+ }
+}
diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index 4beb99ccfdfba..3423e11a7d0f0 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
ArithOps.cpp
ArithDialect.cpp
+ InferIntDivisibilityOpInterfaceImpl.cpp
InferIntRangeInterfaceImpls.cpp
ValueBoundsOpInterfaceImpl.cpp
)
@@ -12,6 +13,7 @@ add_public_tablegen_target(MLIRArithCanonicalizationIncGen)
add_mlir_dialect_library(MLIRArithDialect
ArithOps.cpp
ArithDialect.cpp
+ InferIntDivisibilityOpInterfaceImpl.cpp
InferIntRangeInterfaceImpls.cpp
ADDITIONAL_HEADER_DIRS
@@ -24,6 +26,7 @@ add_mlir_dialect_library(MLIRArithDialect
LINK_LIBS PUBLIC
MLIRCastInterfaces
MLIRDialect
+ MLIRInferIntDivisibilityOpInterface
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..b23ee108ca4a3
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp
@@ -0,0 +1,122 @@
+//===- InferIntDivisibilityOpInterfaceImpl.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Direct implementations of `InferIntDivisibilityOpInterface` for arith ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
+
+#include <cstdlib>
+
+using namespace mlir;
+using namespace mlir::arith;
+
+static ConstantIntDivisibility
+getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility) {
+ if (!divisibility.isUninitialized())
+ return divisibility.getValue();
+ APInt intVal;
+ if (matchPattern(v, m_ConstantInt(&intVal))) {
+ uint64_t udiv = intVal.getZExtValue();
+ uint64_t sdiv = std::abs(intVal.getSExtValue());
+ return ConstantIntDivisibility(udiv, sdiv);
+ }
+ return ConstantIntDivisibility(1, 1);
+}
+
+// Result divisibility is the GCD (union) of the operand divisibilities.
+template <typename OpTy>
+static void
+inferBinaryGCDResultDivisibility(OpTy op, ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ auto lhsDiv = getDivisibilityOfOperand(op.getLhs(), argDivs[0]);
+ auto rhsDiv = getDivisibilityOfOperand(op.getRhs(), argDivs[1]);
+ setResultDivs(op.getResult(), lhsDiv.getUnion(rhsDiv));
+}
+
+void ConstantOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ auto constAttr = dyn_cast_if_present<IntegerAttr>(getValue());
+ if (!constAttr)
+ return;
+ const APInt &value = constAttr.getValue();
+ uint64_t udiv = value.getZExtValue();
+ uint64_t sdiv = std::abs(value.getSExtValue());
+ setResultDivs(getResult(), ConstantIntDivisibility(udiv, sdiv));
+}
+
+void AddIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void SubIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void MinUIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void MaxUIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void MinSIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void MaxSIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
+}
+
+void MulIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ auto lhsDivisibility = getDivisibilityOfOperand(getLhs(), argDivs[0]);
+ auto rhsDivisibility = getDivisibilityOfOperand(getRhs(), argDivs[1]);
+
+ uint64_t mulUDiv = lhsDivisibility.udiv() * rhsDivisibility.udiv();
+ uint64_t mulSDiv = lhsDivisibility.sdiv() * rhsDivisibility.sdiv();
+
+ setResultDivs(getResult(), ConstantIntDivisibility(mulUDiv, mulSDiv));
+}
+
+void DivUIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ APInt intVal;
+ if (!matchPattern(getRhs(), m_ConstantInt(&intVal)))
+ return;
+
+ auto lhsDivisibility = getDivisibilityOfOperand(getLhs(), argDivs[0]);
+
+ uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0
+ ? lhsDivisibility.udiv() / intVal.getZExtValue()
+ : 1;
+ uint64_t divSDiv =
+ lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0
+ ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue())
+ : 1;
+
+ setResultDivs(getResult(), ConstantIntDivisibility(divUDiv, divSDiv));
+}
+
+void SelectOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+ SetIntDivisibilityFn setResultDivs) {
+ // argDivs[0] is the condition (i1), argDivs[1] is true, argDivs[2] is false.
+ auto trueDiv = getDivisibilityOfOperand(getTrueValue(), argDivs[1]);
+ auto falseDiv = getDivisibilityOfOperand(getFalseValue(), argDivs[2]);
+ setResultDivs(getResult(), trueDiv.getUnion(falseDiv));
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 41e890cb408ba..d20d290c45c01 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
FunctionImplementation.cpp
FunctionInterfaces.cpp
IndexingMapOpInterface.cpp
+ InferIntDivisibilityOpInterface.cpp
InferIntRangeInterface.cpp
InferStridedMetadataInterface.cpp
InferTypeOpInterface.cpp
@@ -66,6 +67,7 @@ add_mlir_library(MLIRFunctionInterfaces
)
add_mlir_interface_library(IndexingMapOpInterface)
+add_mlir_interface_library(InferIntDivisibilityOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_library(MLIRInferStridedMetadataInterface
diff --git a/mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp b/mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp
new file mode 100644
index 0000000000000..acd7cd9530b5c
--- /dev/null
+++ b/mlir/lib/Interfaces/InferIntDivisibilityOpInterface.cpp
@@ -0,0 +1,11 @@
+//===- InferIntDivisibilityOpInterface.cpp - Integer divisibility inference ==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.h"
+
+#include "mlir/Interfaces/InferIntDivisibilityOpInterface.cpp.inc"
diff --git a/mlir/test/Analysis/DataFlow/integer-divisibility.mlir b/mlir/test/Analysis/DataFlow/integer-divisibility.mlir
new file mode 100644
index 0000000000000..7f9466e949d4c
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/integer-divisibility.mlir
@@ -0,0 +1,152 @@
+// RUN: mlir-opt --split-input-file --test-int-divisibility-analysis --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @constant
+func.func @constant() -> index {
+ %0 = arith.constant 8 : index
+ // CHECK: divisibility = "udiv = 8, sdiv = 8"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @muli_constant
+func.func @muli_constant(%arg0 : index) -> index {
+ %c4 = arith.constant 4 : index
+ %0 = arith.muli %arg0, %c4 : index
+ // CHECK: divisibility = "udiv = 4, sdiv = 4"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @addi_gcd_of_muli_operands
+func.func @addi_gcd_of_muli_operands(%arg0 : index, %arg1 : index) -> index {
+ %c8 = arith.constant 8 : index
+ %c12 = arith.constant 12 : index
+ %a = arith.muli %arg0, %c8 : index
+ %b = arith.muli %arg1, %c12 : index
+ %0 = arith.addi %a, %b : index
+ // gcd(8, 12) = 4.
+ // CHECK: divisibility = "udiv = 4, sdiv = 4"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @addi_same_divisibility
+func.func @addi_same_divisibility(%arg0 : index, %arg1 : index) -> index {
+ %c16 = arith.constant 16 : index
+ %a = arith.muli %arg0, %c16 : index
+ %b = arith.muli %arg1, %c16 : index
+ %0 = arith.addi %a, %b : index
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mul
+func.func @affine_apply_mul(%arg0 : index) -> index {
+ %c2 = arith.constant 2 : index
+ %seed = arith.muli %arg0, %c2 : index
+ %0 = affine.apply affine_map<(d0) -> (d0 * 16)>(%seed)
+ // 2 * 16 = 32.
+ // CHECK: divisibility = "udiv = 32, sdiv = 32"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mul_then_floordiv
+func.func @affine_apply_mul_then_floordiv(%arg0 : index) -> index {
+ %0 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg0)
+ %1 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%0)
+ // 16 floordiv 4 = 4.
+ // CHECK: divisibility = "udiv = 4, sdiv = 4"
+ %2 = "test.int_divisibility"(%1) : (index) -> index
+ return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mod_zero
+func.func @affine_apply_mod_zero(%arg0 : index) -> index {
+ %0 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg0)
+ %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%0)
+ // 16 % 16 == 0, so x mod 16 is always 0 -> divisibility 0 (lattice top).
+ // CHECK: divisibility = "udiv = 0, sdiv = 0"
+ %2 = "test.int_divisibility"(%1) : (index) -> index
+ return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_constant
+func.func @affine_apply_constant() -> index {
+ %0 = affine.apply affine_map<() -> (64)>()
+ // CHECK: divisibility = "udiv = 64, sdiv = 64"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_constant_step
+func.func @scf_for_constant_step() {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c8 = arith.constant 8 : index
+ scf.for %iv = %c0 to %c64 step %c8 {
+ // CHECK: divisibility = "udiv = 8, sdiv = 8"
+ %0 = "test.int_divisibility"(%iv) : (index) -> index
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_nontrivial_gcd
+func.func @scf_for_nontrivial_gcd() {
+ %c12 = arith.constant 12 : index
+ %c100 = arith.constant 100 : index
+ %c18 = arith.constant 18 : index
+ scf.for %iv = %c12 to %c100 step %c18 {
+ // gcd(12, 18) = 6.
+ // CHECK: divisibility = "udiv = 6, sdiv = 6"
+ %0 = "test.int_divisibility"(%iv) : (index) -> index
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_coprime
+func.func @scf_for_coprime() {
+ %c15 = arith.constant 15 : index
+ %c100 = arith.constant 100 : index
+ %c8 = arith.constant 8 : index
+ scf.for %iv = %c15 to %c100 step %c8 {
+ // gcd(15, 8) = 1.
+ // CHECK: divisibility = "udiv = 1, sdiv = 1"
+ %0 = "test.int_divisibility"(%iv) : (index) -> index
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mul_plus_const
+func.func @affine_apply_mul_plus_const(%arg0 : index) -> index {
+ %c4 = arith.constant 4 : index
+ %seed = arith.muli %arg0, %c4 : index
+ %0 = affine.apply affine_map<(d0) -> (d0 * 8 + 16)>(%seed)
+ // seed has udiv = 4, multiplied by 8 -> 32, then +16. gcd(32,16) = 16.
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %1 = "test.int_divisibility"(%0) : (index) -> index
+ return %1 : index
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index c37671ade37b3..d86af5017f24b 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_library(MLIRTestAnalysis
DataFlow/TestDeadCodeAnalysis.cpp
DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+ DataFlow/TestIntegerDivisibilityAnalysis.cpp
DataFlow/TestLivenessAnalysis.cpp
DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
DataFlow/TestStridedMetadataRangeAnalysis.cpp
@@ -27,6 +28,7 @@ add_mlir_library(MLIRTestAnalysis
mlir_target_link_libraries(MLIRTestAnalysis PUBLIC
MLIRAffineDialect
MLIRAnalysis
+ MLIRArithDialect
MLIRFunctionInterfaces
MLIRMemRefDialect
MLIRPass
diff --git a/mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp
new file mode 100644
index 0000000000000..626cbc0fac7aa
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestIntegerDivisibilityAnalysis.cpp
@@ -0,0 +1,93 @@
+//===- TestIntegerDivisibilityAnalysis.cpp - Test int divisibility --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerDivisibilityAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+struct TestIntegerDivisibilityAnalysisPass
+ : public PassWrapper<TestIntegerDivisibilityAnalysisPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestIntegerDivisibilityAnalysisPass)
+
+ StringRef getArgument() const override {
+ return "test-int-divisibility-analysis";
+ }
+ StringRef getDescription() const override {
+ return "Test integer divisibility analysis by annotating "
+ "'test.int_divisibility' ops with the divisibility of their "
+ "operand.";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, affine::AffineDialect>();
+ }
+
+ void runOnOperation() override {
+ Operation *rootOp = getOperation();
+ MLIRContext *context = &getContext();
+
+ // The pass is rooted on `test.int_divisibility` ops, which are expected
+ // to have a single operand for which to annotate divisibility information.
+ SmallVector<std::pair<Operation *, Value>> queryOps;
+ rootOp->walk([&](Operation *op) {
+ if (op->getName().getStringRef() == "test.int_divisibility" &&
+ op->getNumOperands() == 1)
+ queryOps.emplace_back(op, op->getOperand(0));
+ });
+
+ DataFlowSolver solver;
+ // DeadCodeAnalysis is the base analysis that allows the solver to traverse
+ // control flow. It is required by IntegerDivisibilityAnalysis.
+ solver.load<DeadCodeAnalysis>();
+ // SparseConstantPropagation allows the solver to call
+ // visitNonControlFlowArguments and analyze arguments like loop induction
+ // variables.
+ solver.load<SparseConstantPropagation>();
+ solver.load<IntegerDivisibilityAnalysis>();
+ if (failed(solver.initializeAndRun(rootOp)))
+ return signalPassFailure();
+
+ for (auto &[op, value] : queryOps) {
+ const auto *lattice =
+ solver.lookupState<IntegerDivisibilityLattice>(value);
+ if (!lattice || lattice->getValue().isUninitialized()) {
+ op->setAttr("divisibility", StringAttr::get(context, "uninitialized"));
+ continue;
+ }
+
+ // Format for the divisibility information is "udiv = X, sdiv = Y".
+ const auto &div = lattice->getValue().getValue();
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ os << "udiv = " << div.udiv() << ", sdiv = " << div.sdiv();
+ op->setAttr("divisibility", StringAttr::get(context, result));
+ }
+ }
+};
+} // end anonymous namespace
+
+namespace mlir::test {
+void registerTestIntegerDivisibilityAnalysisPass() {
+ PassRegistration<TestIntegerDivisibilityAnalysisPass>();
+}
+} // end namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index c4754b3a08551..13c0934f34656 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -104,6 +104,7 @@ void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
+void registerTestIntegerDivisibilityAnalysisPass();
void registerTestInterfaces();
void registerTestIRVisitorsPass();
void registerTestLastModifiedPass();
@@ -253,6 +254,7 @@ static void registerTestPasses() {
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
+ mlir::test::registerTestIntegerDivisibilityAnalysisPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestIrdlTestDialectConversionPass();
mlir::test::registerTestIRVisitorsPass();
>From 11b1f9bf408e1971317582edbdf480f5ea9ea446 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Sat, 16 May 2026 09:47:44 -0700
Subject: [PATCH 2/4] [MLIR][CMake] Link MLIRDialectUtils into MLIRAnalysis
`IntegerDivisibilityAnalysis.cpp` calls
`mlir::getConstantIntValue(OpFoldResult)`, which is defined in
`mlir/lib/Dialect/Utils/StaticValueUtils.cpp` (part of
`MLIRDialectUtils`). Without this link dependency, shared-library
builds of `libMLIRAnalysis.so` fail with undefined references.
This was the root cause of #198048 reverting #197728.
---
mlir/lib/Analysis/CMakeLists.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 596ffaff428b5..2712ab503663e 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -54,6 +54,7 @@ add_mlir_library(MLIRAnalysis
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
+ MLIRDialectUtils
MLIRFunctionInterfaces
MLIRInferIntDivisibilityOpInterface
MLIRInferIntRangeInterface
>From 1984e8bb3009deb9ca1e99c912c7ff90cce741aa Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Sun, 17 May 2026 18:55:01 -0700
Subject: [PATCH 3/4] [MLIR][Bazel] Add InferIntDivisibilityOpInterface and
divisibility analysis
Add Bazel build rules mirroring the CMake changes from the
IntegerDivisibilityAnalysis reland:
- New `InferIntDivisibilityOpInterfaceTdFiles` td_library,
`InferIntDivisibilityOpInterfaceIncGen` gentbl, and
`InferIntDivisibilityOpInterface` cc_library targets.
- Wire `:InferIntDivisibilityOpInterface` into the `Analysis`,
`ArithDialect`, and `AffineDialect` cc_library deps.
- Add `InferIntDivisibilityOpInterfaceImpl.cpp` to the explicit
`ArithDialect` srcs (AffineDialect uses a glob).
- Add `:InferIntDivisibilityOpInterfaceTdFiles` to
`ArithOpsTdFiles` and `AffineOpsTdFiles` deps so the .td
include resolves.
`TestAnalysis` uses a `lib/Analysis/DataFlow/*.cpp` glob and picks
up `TestIntegerDivisibilityAnalysis.cpp` automatically, so the test
BUILD file's `srcs` need no change.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply at anthropic.com>
---
.../llvm-project-overlay/mlir/BUILD.bazel | 35 +++++++++++++++++++
1 file changed, 35 insertions(+)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 456cf1e06f539..0d5a3642513d4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1456,6 +1456,13 @@ gentbl_cc_library(
deps = [":DialectFoldInterfaceTdFiles"],
)
+td_library(
+ name = "InferIntDivisibilityOpInterfaceTdFiles",
+ srcs = ["include/mlir/Interfaces/InferIntDivisibilityOpInterface.td"],
+ includes = ["include"],
+ deps = [":OpBaseTdFiles"],
+)
+
td_library(
name = "InferIntRangeInterfaceTdFiles",
srcs = ["include/mlir/Interfaces/InferIntRangeInterface.td"],
@@ -1583,6 +1590,7 @@ td_library(
deps = [
":ArithOpsTdFiles",
":FuncTdFiles",
+ ":InferIntDivisibilityOpInterfaceTdFiles",
":LoopLikeInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
@@ -4059,6 +4067,7 @@ cc_library(
":ControlFlowInterfaces",
":DialectUtils",
":IR",
+ ":InferIntDivisibilityOpInterface",
":InferIntRangeCommon",
":InferIntRangeInterface",
":InliningUtils",
@@ -7834,6 +7843,28 @@ cc_library(
],
)
+gentbl_cc_library(
+ name = "InferIntDivisibilityOpInterfaceIncGen",
+ tbl_outs = {
+ "include/mlir/Interfaces/InferIntDivisibilityOpInterface.h.inc": ["-gen-op-interface-decls"],
+ "include/mlir/Interfaces/InferIntDivisibilityOpInterface.cpp.inc": ["-gen-op-interface-defs"],
+ },
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Interfaces/InferIntDivisibilityOpInterface.td",
+ deps = [":InferIntDivisibilityOpInterfaceTdFiles"],
+)
+
+cc_library(
+ name = "InferIntDivisibilityOpInterface",
+ srcs = ["lib/Interfaces/InferIntDivisibilityOpInterface.cpp"],
+ hdrs = ["include/mlir/Interfaces/InferIntDivisibilityOpInterface.h"],
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":InferIntDivisibilityOpInterfaceIncGen",
+ ],
+)
+
gentbl_cc_library(
name = "InferIntRangeInterfaceIncGen",
tbl_outs = {
@@ -8882,6 +8913,7 @@ cc_library(
":DialectUtils",
":FunctionInterfaces",
":IR",
+ ":InferIntDivisibilityOpInterface",
":InferIntRangeInterface",
":InferStridedMetadataInterface",
":LoopLikeInterface",
@@ -12826,6 +12858,7 @@ td_library(
":BuiltinDialectTdFiles",
":CastInterfacesTdFiles",
":ControlFlowInterfacesTdFiles",
+ ":InferIntDivisibilityOpInterfaceTdFiles",
":InferIntRangeInterfaceTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
@@ -12913,6 +12946,7 @@ cc_library(
"include/mlir/Interfaces/ValueBoundsOpInterface.h",
"lib/Dialect/Arith/IR/ArithDialect.cpp",
"lib/Dialect/Arith/IR/ArithOps.cpp",
+ "lib/Dialect/Arith/IR/InferIntDivisibilityOpInterfaceImpl.cpp",
"lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp",
] + glob([
"include/mlir/Analysis/**/*.h",
@@ -12936,6 +12970,7 @@ cc_library(
":ConvertToLLVMInterface",
":DestinationStyleOpInterface",
":IR",
+ ":InferIntDivisibilityOpInterface",
":InferIntRangeCommon",
":InferIntRangeInterface",
":InferStridedMetadataInterface",
>From c767687d783f823136414561efe4e39544ae204a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 18 May 2026 08:30:53 -0700
Subject: [PATCH 4/4] [MLIR][Bazel] Add ArithDialect dep to TestAnalysis
`TestIntegerDivisibilityAnalysis.cpp` includes
`mlir/Dialect/Arith/IR/Arith.h`, but the `TestAnalysis` cc_library
didn't depend on `//mlir:ArithDialect`. With module-style header
checking this surfaces as:
error: module ...:TestAnalysis does not depend on a module
exporting 'mlir/Dialect/Arith/IR/Arith.h'
Verified with `bazel build --config=generic_clang
@llvm-project//mlir/...` (4834 targets, all green).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply at anthropic.com>
---
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel | 1 +
1 file changed, 1 insertion(+)
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index adec60503ede0..b93d63fde404a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -101,6 +101,7 @@ cc_library(
"//mlir:AffineAnalysis",
"//mlir:AffineDialect",
"//mlir:Analysis",
+ "//mlir:ArithDialect",
"//mlir:CallOpInterfaces",
"//mlir:ControlFlowDialect",
"//mlir:ControlFlowInterfaces",
More information about the Mlir-commits
mailing list