[Mlir-commits] [mlir] [mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (PR #145068)

Fabian Mora llvmlistbot at llvm.org
Fri Jun 20 12:05:37 PDT 2025


https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/145068

>From b3ba28e399dec36d3cab6acf0738afae169846df Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 20 Jun 2025 16:21:52 +0000
Subject: [PATCH 1/3] [mlir][affine|ValueBounds] Add transform to simplify
 affine min max ops with ValueBoundsOpInterface

---
 .../Affine/TransformOps/AffineTransformOps.td |  37 +++++
 .../Dialect/Affine/Transforms/Transforms.h    |  19 +++
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  25 ++-
 .../TransformOps/AffineTransformOps.cpp       |  19 +++
 .../Dialect/Affine/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/SimplifyAffineMinMax.cpp       | 153 ++++++++++++++++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  59 ++++++-
 .../transform-op-simplify-min-max-ops.mlir    |  72 +++++++++
 .../test/Dialect/Linalg/transform-op-pad.mlir |  35 ++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  10 ++
 mlir/test/lib/Dialect/Test/TestOps.td         |  19 +++
 11 files changed, 447 insertions(+), 2 deletions(-)
 create mode 100644 mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
 create mode 100644 mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 70b127fd063ca..4659ae28ee093 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -63,4 +63,41 @@ def SimplifyBoundedAffineOpsOp
   }];
 }
 
+def SimplifyMinMaxAffineOpsOp :
+  Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
+    TransformOpInterface,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TransformEachOpTrait
+  ]> {
+  let description = [{
+    Simplify all the affine.min / affine.max ops being targeted or nested in the
+    target operation, using the `mlir::affine::simplifyAffineMinMaxOps`
+    transform.
+
+    Example:
+    ```
+    %0 = transform.structured.match ops{["gpu.launch", "affine.max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+    ```
+
+    #### Return modes
+
+    This transform consumes the target handle and does not produce any results.
+    This transforms never produces errors.
+
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = [{
+      $target attr-dict `:` type($target)
+  }];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // Affine_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 5c538d28c1835..b0578eb159c11 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -34,6 +34,8 @@ namespace affine {
 class AffineApplyOp;
 class AffineDelinearizeIndexOp;
 class AffineLinearizeIndexOp;
+class AffineMaxOp;
+class AffineMinOp;
 
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
 /// operations.
@@ -127,6 +129,23 @@ OpFoldResult materializeComputedBound(
     OpBuilder &b, Location loc, AffineMap boundMap,
     ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
 
+/// Tries to simplify all affine min or max operations under `topOp`. The
+/// transform works by finding disjoint sets of affine result expressions
+/// bounded by a common affine expression on the min/max operation. It populates
+/// `modifiedOps` with all the operations modified by the transform.
+///
+/// In concrete terms, given an operation like:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
+/// If `128 < s1 < s0` and `d0` can't be ordered against `128`, the op becomes:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, 128)>(%i)[%s0, %s1]`.
+void simplifyAffineMinMaxOps(RewriterBase &rewriter, Operation *topOp,
+                             SmallVectorImpl<Operation *> &modifiedOps);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
+/// Applies `simplifyAffineMinMaxOps` to a single operation and returns whether
+/// the operation was modified.
+bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 337314143c80c..39206b89ef8c6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -135,10 +135,17 @@ class ValueBoundsConstraintSet
 
     /// Construct a variable for a map and its operands.
     Variable(AffineMap map, ArrayRef<Variable> mapOperands);
-    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+    Variable(AffineMap map, ValueRange mapOperands);
 
     MLIRContext *getContext() const { return map.getContext(); }
 
+    /// Returns the affine map.
+    AffineMap getMap() const { return map; }
+
+    /// Returns the map operands.
+    ValueDimList &getOperands() { return mapOperands; }
+    const ValueDimList &getOperands() const { return mapOperands; }
+
   private:
     friend class ValueBoundsConstraintSet;
     AffineMap map;
@@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
   /// prove the relation or until it ran out of IR.
   static bool compare(const Variable &lhs, ComparisonOperator cmp,
                       const Variable &rhs);
+  /// This function is similar to `ValueBoundsConstraintSet::compare`, except
+  /// that it returns false if `!(lhs cmp rhs)`, and `std::nullopt` if the
+  /// values couldn't be compared.
+  static std::optional<bool> strongCompare(const Variable &lhs,
+                                           ComparisonOperator cmp,
+                                           const Variable &rhs);
 
   /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
@@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
   /// constraints.
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
+  /// Return "true" if, based on the current state of the constraint system,
+  /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
+  /// can be proven. Otherwise it returns `std::nullopt` meaning the values are
+  /// unordered with respect to the constraints.
+  ///
+  /// This function does not analyze any IR and does not populate any additional
+  /// constraints.
+  std::optional<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+                                       int64_t rhsPos);
+
   /// Given an affine map with a single result (and map operands), add a new
   /// column to the constraint set that represents the result of the map.
   /// Traverse additional IR starting from the map operands as needed (as long
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index c9fe4474a68fa..6bd4dd23c7af5 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -148,6 +149,24 @@ void SimplifyBoundedAffineOpsOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// SimplifyMinMaxAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SimplifyMinMaxAffineOpsOp::applyToOne(
+    TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, TransformState &state) {
+  SmallVector<Operation *> modifiedOps;
+  simplifyAffineMinMaxOps(rewriter, target, modifiedOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SimplifyMinMaxAffineOpsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 1c82822b2bd7f..c792200f4a49a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp
+  SimplifyAffineMinMax.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
new file mode 100644
index 0000000000000..0ddf2ec192a3e
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -0,0 +1,153 @@
+//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transform to simplify min/max affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/IntEqClasses.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-min-max"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Simplifies an affine min/max operation by removing the results that are
+/// provably redundant. Returns true if the operation was modified.
+template <typename AffineOp>
+static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
+  using Variable = ValueBoundsConstraintSet::Variable;
+  using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+
+  AffineMap affineMap = affineOp.getMap();
+  ValueRange operands = affineOp.getOperands();
+  static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
+
+  LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+
+  // Create a `Variable` list with values corresponding to each of the results
+  // in the affine map.
+  SmallVector<Variable> variables = llvm::map_to_vector(
+      llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
+      [&](unsigned i) {
+        return Variable(affineMap.getSliceMap(i, 1), operands);
+      });
+
+  // Get the comparison operation.
+  ComparisonOperator cmpOp =
+      isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
+
+  // Find disjoint sets bounded by a common value.
+  llvm::IntEqClasses boundedClasses(variables.size());
+  DenseMap<unsigned, Variable *> bounds;
+  for (auto &&[i, v] : llvm::enumerate(variables)) {
+    unsigned eqClass = boundedClasses.findLeader(i);
+
+    // If the class already has a bound continue.
+    if (bounds.contains(eqClass))
+      continue;
+
+    // Initialize the bound.
+    Variable *bound = &v;
+
+    LLVM_DEBUG({
+      DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+             << "`\n";
+    });
+
+    // Check against the other variables.
+    for (size_t j = i + 1; j < variables.size(); ++j) {
+      unsigned jEqClass = boundedClasses.findLeader(j);
+      // Get the bound of the equivalence class or itself.
+      Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
+
+      LLVM_DEBUG({
+        DBGS() << "- comparing with variable: #" << jEqClass
+               << ", with map: " << nv->getMap() << "\n";
+      });
+
+      // Compare the variables.
+      std::optional<bool> cmpResult =
+          ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
+
+      // The variables cannot be compared.
+      if (!cmpResult) {
+        LLVM_DEBUG({
+          DBGS() << "-- classes: #" << i << ", #" << jEqClass
+                 << " cannot be merged\n";
+        });
+        continue;
+      }
+
+      // Join the equivalence classes and update the bound if necessary.
+      LLVM_DEBUG({
+        DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
+               << ", is lhs <= rhs: " << *cmpResult << "`\n";
+      });
+      if (*cmpResult) {
+        boundedClasses.join(eqClass, jEqClass);
+      } else {
+        // In this case we have lhs >= rhs if isMin == true, or lhs <= rhs if
+        // isMin == false.
+        bound = nv;
+        boundedClasses.join(eqClass, jEqClass);
+      }
+    }
+    bounds[boundedClasses.findLeader(i)] = bound;
+  }
+
+  // Return if there's no simplification.
+  if (bounds.size() >= affineMap.getNumResults()) {
+    LLVM_DEBUG(
+        { DBGS() << "- the affine operation couldn't get simplified\n"; });
+    return false;
+  }
+
+  // Construct the new affine map.
+  SmallVector<AffineExpr> results;
+  results.reserve(bounds.size());
+  for (auto [k, bound] : bounds)
+    results.push_back(bound->getMap().getResult(0));
+
+  affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
+                             results, rewriter.getContext());
+
+  // Update the affine op.
+  rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
+  LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+  return true;
+}
+
+bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
+  return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
+  return simplifyAffineMinMaxOp(rewriter, op);
+}
+
+void mlir::affine::simplifyAffineMinMaxOps(
+    RewriterBase &rewriter, Operation *topOp,
+    SmallVectorImpl<Operation *> &modifiedOps) {
+  assert(topOp && "null-op");
+  topOp->walk([&](Operation *op) {
+    if (auto affineOp = dyn_cast<AffineMinOp>(op)) {
+      if (simplifyAffineMinMaxOp(rewriter, affineOp))
+        modifiedOps.push_back(op);
+    } else if (auto affineOp = dyn_cast<AffineMaxOp>(op)) {
+      if (simplifyAffineMinMaxOp(rewriter, affineOp))
+        modifiedOps.push_back(op);
+    }
+  });
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 87f883c2e6485..d7a6187cafb1e 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
 }
 
 ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
-                                             ArrayRef<Value> mapOperands)
+                                             ValueRange mapOperands)
     : Variable(map, llvm::map_to_vector(mapOperands,
                                         [](Value v) { return Variable(v); })) {}
 
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
   return isEmpty;
 }
 
+std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
+    int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
+  auto strongCmp = [&](ComparisonOperator cmp,
+                       ComparisonOperator negCmp) -> std::optional<bool> {
+    if (comparePos(lhsPos, cmp, rhsPos))
+      return true;
+    if (comparePos(lhsPos, negCmp, rhsPos))
+      return false;
+    return std::nullopt;
+  };
+  switch (cmp) {
+  case ComparisonOperator::LT:
+    return strongCmp(ComparisonOperator::LT, ComparisonOperator::GE);
+  case ComparisonOperator::LE:
+    return strongCmp(ComparisonOperator::LE, ComparisonOperator::GT);
+  case ComparisonOperator::GT:
+    return strongCmp(ComparisonOperator::GT, ComparisonOperator::LE);
+  case ComparisonOperator::GE:
+    return strongCmp(ComparisonOperator::GE, ComparisonOperator::LT);
+  case ComparisonOperator::EQ: {
+    std::optional<bool> le =
+        strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
+    if (!le)
+      return std::nullopt;
+    if (!*le)
+      return false;
+    std::optional<bool> ge =
+        strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+    if (!ge)
+      return std::nullopt;
+    if (!*ge)
+      return false;
+    return true;
+  }
+  }
+  llvm_unreachable("invalid comparison operator");
+}
+
 bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
                                                   ComparisonOperator cmp,
                                                   const Variable &rhs) {
@@ -763,6 +801,25 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
   return cstr.comparePos(lhsPos, cmp, rhsPos);
 }
 
+std::optional<bool> ValueBoundsConstraintSet::strongCompare(
+    const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs were not processed.
+    if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
+        size_t(rhsPos) >= cstr.positionToValueDim.size())
+      return false;
+    // Keep processing as long as the strong relation cannot be proven.
+    std::optional<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+    return ordered ? true : false;
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+  rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
+  return cstr.strongComparePos(lhsPos, cmp, rhsPos);
+}
+
 FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
                                                    const Variable &var2) {
   if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
new file mode 100644
index 0000000000000..2b6de62073e99
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt  %s  --transform-interpreter | FileCheck %s
+
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+  %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK-NOT: affine.min
+  // CHECK-NOT: affine.max
+  // CHECK: return %[[V0]], %[[V1]]
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+  // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+  %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+  %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+  %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+  return %r0, %r1 : index, index
+}
+
+// CHECK: @overlapping_constraints
+func.func @overlapping_constraints() -> (index, index) {
+  %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
+  %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
+  %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
+  // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
+  // CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+  // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+  // CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
+  %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+  %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+  return %r0, %r1 : index, index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+    %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %1 {
+      transform.apply_patterns.canonicalization
+    } {apply_cse} : !transform.any_op
+    transform.yield 
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index bc684b53c9b61..53e67321000fd 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -454,3 +454,38 @@ module attributes {transform.with_named_sequence} {
     transform.yield 
   }
 }
+
+// -----
+
+// This test checks that by using `simplify_min_max_affine_ops` after padding
+// and tiling, it's possible to recover static tiled slices.
+
+// CHECK-LABEL: @dyn_pad_tiling
+// CHECK: %[[LHS:.*]] = tensor.pad
+// CHECK: %[[RHS:.*]] = tensor.pad
+// CHECK: scf.for
+// CHECK: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
+// CHECK: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
+func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %2 {
+      transform.apply_patterns.canonicalization
+    } {apply_cse} : !transform.any_op
+    transform.affine.simplify_min_max_affine_ops %2 : !transform.any_op
+    %3 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %3 {
+      transform.apply_patterns.canonicalization
+    } {apply_cse} : !transform.any_op
+    transform.yield 
+  }
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 78e44c6ec7a9b..6c1a5d3441530 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -836,6 +836,16 @@ void ConversionFuncOp::print(OpAsmPrinter &p) {
       getArgAttrsAttrName(), getResAttrsAttrName());
 }
 
+//===----------------------------------------------------------------------===//
+// TestValueWithBoundsOp
+//===----------------------------------------------------------------------===//
+
+void TestValueWithBoundsOp::populateBoundsForIndexValue(
+    Value v, ValueBoundsConstraintSet &cstr) {
+  cstr.bound(v) >= getMin().getSExtValue();
+  cstr.bound(v) <= getMax().getSExtValue();
+}
+
 //===----------------------------------------------------------------------===//
 // ReifyBoundOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 30234698bc8dd..8a4981a90831f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -31,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ValueBoundsOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 
 // Include the attribute definitions.
@@ -2375,6 +2376,24 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
 // Test ValueBoundsOpInterface
 //===----------------------------------------------------------------------===//
 
+def TestValueWithBoundsOp : TEST_Op<"value_with_bounds", [
+    DeclareOpInterfaceMethods<ValueBoundsOpInterface, ["populateBoundsForIndexValue"]>
+  ]> {
+  let description = [{
+    Creates a value with specified [min, max] range for value bounds analysis.
+
+    Example:
+
+    ```mlir
+    %0 = test.value_with_bounds { min = 4 : index, max = 5 : index}
+    ```
+  }];
+  let arguments = (ins IndexAttr:$min, IndexAttr:$max);
+  let results = (outs Index:$result);
+  let assemblyFormat = "attr-dict";
+}
+
+
 def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
   let description = [{
     Reify a bound for the given index-typed value or dimension size of a shaped

>From 9f4341d28cc98619e9b47a6636737ff7ed62a072 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Fri, 20 Jun 2025 14:17:14 -0400
Subject: [PATCH 2/3] fix nit

---
 .../mlir/Dialect/Affine/TransformOps/AffineTransformOps.td       | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 4659ae28ee093..a69dd665ec125 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -84,7 +84,6 @@ def SimplifyMinMaxAffineOpsOp :
 
     This transform consumes the target handle and does not produce any results.
     This transforms never produces errors.
-
   }];
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs);

>From 4f48eb0bda938f6f5d2815d5a521e26131460bdc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 20 Jun 2025 19:01:20 +0000
Subject: [PATCH 3/3] address reviewer comments

---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  | 18 ++++++-------
 .../Transforms/SimplifyAffineMinMax.cpp       |  4 +--
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 26 ++++++++-----------
 3 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 39206b89ef8c6..d168735f50598 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -262,11 +262,11 @@ class ValueBoundsConstraintSet
   static bool compare(const Variable &lhs, ComparisonOperator cmp,
                       const Variable &rhs);
   /// This function is similar to `ValueBoundsConstraintSet::compare`, except
-  /// that it returns false if `!(lhs cmp rhs)`, and `std::nullopt` if the
-  /// values couldn't be compared.
-  static std::optional<bool> strongCompare(const Variable &lhs,
-                                           ComparisonOperator cmp,
-                                           const Variable &rhs);
+  /// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
+  /// relation nor its inverse relation could be proven.
+  static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
+                                             ComparisonOperator cmp,
+                                             const Variable &rhs);
 
   /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
@@ -342,13 +342,13 @@ class ValueBoundsConstraintSet
 
   /// Return "true" if, based on the current state of the constraint system,
   /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
-  /// can be proven. Otherwise it returns `std::nullopt` meaning the values are
-  /// unordered with respect to the constraints.
+  /// can be proven. Otherwise, it returns `failure` if neither the relation nor
+  /// its inverse relation could be proven.
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  std::optional<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
-                                       int64_t rhsPos);
+  llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+                                         int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
   /// column to the constraint set that represents the result of the map.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index 0ddf2ec192a3e..1586af561e174 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -78,11 +78,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
       });
 
       // Compare the variables.
-      std::optional<bool> cmpResult =
+      FailureOr<bool> cmpResult =
           ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
 
       // The variables cannot be compared.
-      if (!cmpResult) {
+      if (failed(cmpResult)) {
         LLVM_DEBUG({
           DBGS() << "-- classes: #" << i << ", #" << jEqClass
                  << " cannot be merged\n";
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index d7a6187cafb1e..c9481fb5d9406 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -736,15 +736,15 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
   return isEmpty;
 }
 
-std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
+FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
     int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
   auto strongCmp = [&](ComparisonOperator cmp,
-                       ComparisonOperator negCmp) -> std::optional<bool> {
+                       ComparisonOperator negCmp) -> FailureOr<bool> {
     if (comparePos(lhsPos, cmp, rhsPos))
       return true;
     if (comparePos(lhsPos, negCmp, rhsPos))
       return false;
-    return std::nullopt;
+    return failure();
   };
   switch (cmp) {
   case ComparisonOperator::LT:
@@ -759,13 +759,13 @@ std::optional<bool> ValueBoundsConstraintSet::strongComparePos(
-    std::optional<bool> le =
+    FailureOr<bool> le =
         strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
-    if (!le)
-      return std::nullopt;
+    if (failed(le))
+      return failure();
     if (!*le)
       return false;
-    std::optional<bool> ge =
+    FailureOr<bool> ge =
         strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
-    if (!ge)
-      return std::nullopt;
+    if (failed(ge))
+      return failure();
     if (!*ge)
       return false;
     return true;
@@ -801,8 +801,9 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
   return cstr.comparePos(lhsPos, cmp, rhsPos);
 }
 
-std::optional<bool> ValueBoundsConstraintSet::strongCompare(
-    const Variable &lhs, ComparisonOperator cmp, const Variable &rhs) {
+FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
+                                                        ComparisonOperator cmp,
+                                                        const Variable &rhs) {
   int64_t lhsPos = -1, rhsPos = -1;
   auto stopCondition = [&](Value v, std::optional<int64_t> dim,
                            ValueBoundsConstraintSet &cstr) {
@@ -811,8 +812,8 @@ std::optional<bool> ValueBoundsConstraintSet::strongCompare(
         size_t(rhsPos) >= cstr.positionToValueDim.size())
       return false;
     // Keep processing as long as the strong relation cannot be proven.
-    std::optional<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
-    return ordered ? true : false;
+    FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+    return succeeded(ordered);
   };
   ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
   lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
@@ -822,12 +823,7 @@ std::optional<bool> ValueBoundsConstraintSet::strongCompare(
 
 FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
                                                    const Variable &var2) {
-  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
-    return true;
-  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
-      ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
-    return false;
-  return failure();
+  return strongCompare(var1, ComparisonOperator::EQ, var2);
 }
 
 FailureOr<bool>



More information about the Mlir-commits mailing list