[Mlir-commits] [mlir] [mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (PR #145068)
Fabian Mora
llvmlistbot at llvm.org
Fri Jun 20 12:05:37 PDT 2025
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/145068
>From b3ba28e399dec36d3cab6acf0738afae169846df Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 20 Jun 2025 16:21:52 +0000
Subject: [PATCH 1/3] [mlir][affine|ValueBounds] Add transform to simplify
affine min max ops with ValueBoundsOpInterface
---
.../Affine/TransformOps/AffineTransformOps.td | 37 +++++
.../Dialect/Affine/Transforms/Transforms.h | 19 +++
.../mlir/Interfaces/ValueBoundsOpInterface.h | 25 ++-
.../TransformOps/AffineTransformOps.cpp | 19 +++
.../Dialect/Affine/Transforms/CMakeLists.txt | 1 +
.../Transforms/SimplifyAffineMinMax.cpp | 153 ++++++++++++++++++
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 59 ++++++-
.../transform-op-simplify-min-max-ops.mlir | 72 +++++++++
.../test/Dialect/Linalg/transform-op-pad.mlir | 35 ++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 10 ++
mlir/test/lib/Dialect/Test/TestOps.td | 19 +++
11 files changed, 447 insertions(+), 2 deletions(-)
create mode 100644 mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
create mode 100644 mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 70b127fd063ca..4659ae28ee093 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -63,4 +63,41 @@ def SimplifyBoundedAffineOpsOp
}];
}
+def SimplifyMinMaxAffineOpsOp :
+ Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
+ TransformOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformEachOpTrait
+ ]> {
+ let description = [{
+ Simplify all the affine.min / affine.max ops being targeted or nested in the
+ target operation, using the `mlir::affine::simplifyAffineMinMaxOps`
+ transform.
+
+ Example:
+ ```
+ %0 = transform.structured.match ops{["gpu.launch", "affine.max"]} in %arg1
+ transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+ ```
+
+ #### Return modes
+
+ This transform consumes the target handle and does not produce any results.
+ This transforms never produces errors.
+
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // Affine_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 5c538d28c1835..b0578eb159c11 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -34,6 +34,8 @@ namespace affine {
class AffineApplyOp;
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;
+class AffineMaxOp;
+class AffineMinOp;
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
@@ -127,6 +129,23 @@ OpFoldResult materializeComputedBound(
OpBuilder &b, Location loc, AffineMap boundMap,
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
+/// Tries to simplify all affine min or max operations under `topOp`. The
+/// transform works by finding disjoint sets of affine result expressions
+/// bounded by a common affine expression on the min/max operation. It populates
+/// `modifiedOps` with all the operations modified by the transform.
+///
+/// In concrete terms, given an operation like:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
+/// If `d0 < 128` and `128 < s1 < s0`, the transform will update the op to:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
+void simplifyAffineMinMaxOps(RewriterBase &rewriter, Operation *topOp,
+ SmallVectorImpl<Operation *> &modifiedOps);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
} // namespace affine
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 337314143c80c..39206b89ef8c6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -135,10 +135,17 @@ class ValueBoundsConstraintSet
/// Construct a variable for a map and its operands.
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
- Variable(AffineMap map, ArrayRef<Value> mapOperands);
+ Variable(AffineMap map, ValueRange mapOperands);
MLIRContext *getContext() const { return map.getContext(); }
+ /// Returns the affine map.
+ AffineMap getMap() const { return map; }
+
+ /// Returns the map operands.
+ ValueDimList &getOperands() { return mapOperands; }
+ const ValueDimList &getOperands() const { return mapOperands; }
+
private:
friend class ValueBoundsConstraintSet;
AffineMap map;
@@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
/// prove the relation or until it ran out of IR.
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);
+ /// This function is similar to `ValueBoundsConstraintSet::compare`, except
+ /// that it returns false if `!(lhs cmp rhs)`, and `std::nullopt` if the
+ /// values couldn't be compared.
+ static std::optional<bool> strongCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs);
/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
@@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
/// constraints.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+ /// Return "true" if, based on the current state of the constraint system,
+ /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
+ /// can be proven. Otherwise it returns `std::nullopt` meaning the values are
+ /// unordered with respect to the constraints.
+ ///
+ /// This function does not analyze any IR and does not populate any additional
+ /// constraints.
+ std::optional<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+ int64_t rhsPos);
+
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index c9fe4474a68fa..6bd4dd23c7af5 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -148,6 +149,24 @@ void SimplifyBoundedAffineOpsOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// SimplifyMinMaxAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SimplifyMinMaxAffineOpsOp::applyToOne(
+ TransformRewriter &rewriter, Operation *target,
+ ApplyToEachResultList &results, TransformState &state) {
+ SmallVector<Operation *> modifiedOps;
+ simplifyAffineMinMaxOps(rewriter, target, modifiedOps);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void SimplifyMinMaxAffineOpsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 1c82822b2bd7f..c792200f4a49a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
+ SimplifyAffineMinMax.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
new file mode 100644
index 0000000000000..0ddf2ec192a3e
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -0,0 +1,153 @@
+//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transform to simplify min/max affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/IntEqClasses.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-min-max"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Simplifies an affine min/max operation by proving there's a lower or upper
+/// bound.
+template <typename AffineOp>
+static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
+ using Variable = ValueBoundsConstraintSet::Variable;
+ using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+
+ AffineMap affineMap = affineOp.getMap();
+ ValueRange operands = affineOp.getOperands();
+ static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
+
+ LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+
+ // Create a `Variable` list with values corresponding to each of the results
+ // in the affine map.
+ SmallVector<Variable> variables = llvm::map_to_vector(
+ llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
+ [&](unsigned i) {
+ return Variable(affineMap.getSliceMap(i, 1), operands);
+ });
+
+ // Get the comparison operation.
+ ComparisonOperator cmpOp =
+ isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
+
+ // Find disjoint sets bounded by a common value.
+ llvm::IntEqClasses boundedClasses(variables.size());
+ DenseMap<unsigned, Variable *> bounds;
+ for (auto &&[i, v] : llvm::enumerate(variables)) {
+ unsigned eqClass = boundedClasses.findLeader(i);
+
+ // If the class already has a bound continue.
+ if (bounds.contains(eqClass))
+ continue;
+
+ // Initialize the bound.
+ Variable *bound = &v;
+
+ LLVM_DEBUG({
+ DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
+ });
+
+ // Check against the other variables.
+ for (size_t j = i + 1; j < variables.size(); ++j) {
+ unsigned jEqClass = boundedClasses.findLeader(j);
+ // Get the bound of the equivalence class or itself.
+ Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
+
+ LLVM_DEBUG({
+ DBGS() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap() << "\n";
+ });
+
+ // Compare the variables.
+ std::optional<bool> cmpResult =
+ ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
+
+ // The variables cannot be compared.
+ if (!cmpResult) {
+ LLVM_DEBUG({
+ DBGS() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged\n";
+ });
+ continue;
+ }
+
+ // Join the equivalence classes and update the bound if necessary.
+ LLVM_DEBUG({
+ DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", is lhs <= rhs: " << *cmpResult << "`\n";
+ });
+ if (*cmpResult) {
+ boundedClasses.join(eqClass, jEqClass);
+ } else {
+ // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
+ // isMin == false.
+ bound = nv;
+ boundedClasses.join(eqClass, jEqClass);
+ }
+ }
+ bounds[boundedClasses.findLeader(i)] = bound;
+ }
+
+ // Return if there's no simplification.
+ if (bounds.size() >= affineMap.getNumResults()) {
+ LLVM_DEBUG(
+ { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ return false;
+ }
+
+ // Construct the new affine map.
+ SmallVector<AffineExpr> results;
+ results.reserve(bounds.size());
+ for (auto [k, bound] : bounds)
+ results.push_back(bound->getMap().getResult(0));
+
+ affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
+ results, rewriter.getContext());
+
+ // Update the affine op.
+ rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
+ LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ return true;
+}
+
+bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
+ return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
+ return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+void mlir::affine::simplifyAffineMinMaxOps(
+ RewriterBase &rewriter, Operation *topOp,
+ SmallVectorImpl<Operation *> &modifiedOps) {
+ assert(topOp && "null-op");
+ topOp->walk([&](Operation *op) {
+ if (auto affineOp = dyn_cast<AffineMinOp>(op)) {
+ if (simplifyAffineMinMaxOp(rewriter, affineOp))
+ modifiedOps.push_back(op);
+ } else if (auto affineOp = dyn_cast<AffineMaxOp>(op)) {
+ if (simplifyAffineMinMaxOp(rewriter, affineOp))
+ modifiedOps.push_back(op);
+ }
+ });
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 87f883c2e6485..d7a6187cafb1e 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
}
ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
- ArrayRef<Value> mapOperands)
+ ValueRange mapOperands)
: Variable(map, llvm::map_to_vector(mapOperands,
[](Value v) { return Variable(v); })) {}
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
+std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
+ int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
+ auto strongCmp = [&](ComparisonOperator cmp,
+ ComparisonOperator negCmp) -> std::optional<bool> {
+ if (comparePos(lhsPos, cmp, rhsPos))
+ return true;
+ if (comparePos(lhsPos, negCmp, rhsPos))
+ return false;
+ return std::nullopt;
+ };
+ switch (cmp) {
+ case ComparisonOperator::LT:
+ return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE);
+ case ComparisonOperator::LE:
+ return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT);
+ case ComparisonOperator::GT:
+ return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE);
+ case ComparisonOperator::GE:
+ return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT);
+ case ComparisonOperator::EQ: {
+ std::optional<bool> le =
+ strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
+ if (!le)
+ return std::nullopt;
+ if (!*le)
+ return false;
+ std::optional<bool> ge =
+ strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+ if (!ge)
+ return std::nullopt;
+ if (!*ge)
+ return false;
+ return true;
+ }
+ }
+ llvm_unreachable("invalid comparison operator");
+}
+
bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs) {
@@ -763,6 +801,25 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
+std::optional<bool> ValueBoundsConstraintSet::strongCompare(
+ const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) {
+ int64_t lhsPos = -1, rhsPos = -1;
+ auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ // Keep processing as long as lhs/rhs were not processed.
+ if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
+ size_t(rhsPos) >= cstr.positionToValueDim.size())
+ return false;
+ // Keep processing as long as the strong relation cannot be proven.
+ std::optional<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+ return ordered ? true : false;
+ };
+ ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+ lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+ rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
+ return cstr.strongComparePos(lhsPos, cmp, rhsPos);
+}
+
FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
const Variable &var2) {
if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
new file mode 100644
index 0000000000000..2b6de62073e99
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK-NOT: affine.min
+ // CHECK-NOT: affine.max
+ // CHECK: return %[[V0]], %[[V1]]
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @overlapping_constraints
+func.func @overlapping_constraints() -> (index, index) {
+ %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
+ %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
+ // CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
+ %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+ %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+ return %r0, %r1 : index, index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+ %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %1 {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index bc684b53c9b61..53e67321000fd 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -454,3 +454,38 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// This test checks that by using `simplify_min_max_affine_ops` after padding
+// and tiling, it's possible to recover static tiled slices.
+
+// CHECK-LABEL: @dyn_pad_tiling
+// CHECK: %[[LHS:.*]] = tensor.pad
+// CHECK: %[[RHS:.*]] = tensor.pad
+// CHECK: scf.for
+// CHECK: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
+// CHECK: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
+func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %2 {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+ transform.affine.simplify_min_max_affine_ops %2 : !transform.any_op
+ %3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %3 {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+ transform.yield
+ }
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 78e44c6ec7a9b..6c1a5d3441530 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -836,6 +836,16 @@ void ConversionFuncOp::print(OpAsmPrinter &p) {
getArgAttrsAttrName(), getResAttrsAttrName());
}
+//===----------------------------------------------------------------------===//
+// TestValueWithBoundsOp
+//===----------------------------------------------------------------------===//
+
+void TestValueWithBoundsOp::populateBoundsForIndexValue(
+ Value v, ValueBoundsConstraintSet &cstr) {
+ cstr.bound(v) >= getMin().getSExtValue();
+ cstr.bound(v) <= getMax().getSExtValue();
+}
+
//===----------------------------------------------------------------------===//
// ReifyBoundOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 30234698bc8dd..8a4981a90831f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -31,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
// Include the attribute definitions.
@@ -2375,6 +2376,24 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
// Test ValueBoundsOpInterface
//===----------------------------------------------------------------------===//
+def TestValueWithBoundsOp : TEST_Op<"value_with_bounds", [
+ DeclareOpInterfaceMethods<ValueBoundsOpInterface, ["populateBoundsForIndexValue"]>
+ ]> {
+ let description = [{
+ Creates a value with specified [min, max] range for value bounds analysis.
+
+ Example:
+
+ ```mlir
+ %0 = test.value_with_bounds { min = 4 : index, max = 5 : index}
+ ```
+ }];
+ let arguments = (ins IndexAttr:$min, IndexAttr:$max);
+ let results = (outs Index:$result);
+ let assemblyFormat = "attr-dict";
+}
+
+
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let description = [{
Reify a bound for the given index-typed value or dimension size of a shaped
>From 9f4341d28cc98619e9b47a6636737ff7ed62a072 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Fri, 20 Jun 2025 14:17:14 -0400
Subject: [PATCH 2/3] fix nit
---
.../mlir/Dialect/Affine/TransformOps/AffineTransformOps.td | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 4659ae28ee093..a69dd665ec125 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -84,7 +84,6 @@ def SimplifyMinMaxAffineOpsOp :
This transform consumes the target handle and does not produce any results.
This transforms never produces errors.
-
}];
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);
>From 4f48eb0bda938f6f5d2815d5a521e26131460bdc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 20 Jun 2025 19:01:20 +0000
Subject: [PATCH 3/3] address reviewer comments
---
.../mlir/Interfaces/ValueBoundsOpInterface.h | 18 ++++++-------
.../Transforms/SimplifyAffineMinMax.cpp | 4 +--
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 26 ++++++++-----------
3 files changed, 22 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 39206b89ef8c6..d168735f50598 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -262,11 +262,11 @@ class ValueBoundsConstraintSet
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);
/// This function is similar to `ValueBoundsConstraintSet::compare`, except
- /// that it returns false if `!(lhs cmp rhs)`, and `std::nullopt` if the
- /// values couldn't be compared.
- static std::optional<bool> strongCompare(const Variable &lhs,
- ComparisonOperator cmp,
- const Variable &rhs);
+ /// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
+ /// relation nor its inverse relation could be proven.
+ static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs);
/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
@@ -342,13 +342,13 @@ class ValueBoundsConstraintSet
/// Return "true" if, based on the current state of the constraint system,
/// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
- /// can be proven. Otherwise it returns `std::nullopt` meaning the values are
- /// unordered with respect to the constraints.
+ /// can be proven. Otherwise, it returns `failure`, meaning that neither the
+ /// relation nor its inverse relation could be proven.
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
- std::optional<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
- int64_t rhsPos);
+ llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+ int64_t rhsPos);
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index 0ddf2ec192a3e..1586af561e174 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -78,11 +78,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
});
// Compare the variables.
- std::optional<bool> cmpResult =
+ FailureOr<bool> cmpResult =
ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
// The variables cannot be compared.
- if (!cmpResult) {
+ if (failed(cmpResult)) {
LLVM_DEBUG({
DBGS() << "-- classes: #" << i << ", #" << jEqClass
<< " cannot be merged\n";
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index d7a6187cafb1e..c9481fb5d9406 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -736,15 +736,15 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
-std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
+FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
auto strongCmp = [&](ComparisonOperator cmp,
- ComparisonOperator negCmp) -> std::optional<bool> {
+ ComparisonOperator negCmp) -> FailureOr<bool> {
if (comparePos(lhsPos, cmp, rhsPos))
return true;
if (comparePos(lhsPos, negCmp, rhsPos))
return false;
- return std::nullopt;
+ return failure();
};
switch (cmp) {
case ComparisonOperator::LT:
@@ -759,13 +759,13 @@ std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
std::optional<bool> le =
strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
if (!le)
- return std::nullopt;
+ return failure();
if (!*le)
return false;
std::optional<bool> ge =
strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
if (!ge)
- return std::nullopt;
+ return failure();
if (!*ge)
return false;
return true;
@@ -801,8 +801,9 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
-std::optional<bool> ValueBoundsConstraintSet::strongCompare(
- const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) {
+FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs) {
int64_t lhsPos = -1, rhsPos = -1;
auto stopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
@@ -811,8 +812,8 @@ std::optional<bool> ValueBoundsConstraintSet::strongCompare(
size_t(rhsPos) >= cstr.positionToValueDim.size())
return false;
// Keep processing as long as the strong relation cannot be proven.
- std::optional<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
- return ordered ? true : false;
+ FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+ return failed(ordered) ? true : false;
};
ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
@@ -822,12 +823,7 @@ std::optional<bool> ValueBoundsConstraintSet::strongCompare(
FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
const Variable &var2) {
- if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
- return true;
- if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
- ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
- return false;
- return failure();
+ return strongCompare(var1, ComparisonOperator::EQ, var2);
}
FailureOr<bool>
More information about the Mlir-commits
mailing list