[Mlir-commits] [mlir] c716558 - [mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (#145068)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jun 22 21:05:24 PDT 2025
Author: Fabian Mora
Date: 2025-06-23T06:05:20+02:00
New Revision: c7165587e49605452f96249412f123b47b78bb81
URL: https://github.com/llvm/llvm-project/commit/c7165587e49605452f96249412f123b47b78bb81
DIFF: https://github.com/llvm/llvm-project/commit/c7165587e49605452f96249412f123b47b78bb81.diff
LOG: [mlir][affine|ValueBounds] Add transform to simplify affine min max ops with ValueBoundsOpInterface (#145068)
This commit makes the following changes:
- Expose `map` and `mapOperands` in
`ValueBoundsConstraintSet::Variable`, so that the class can be used by
subclasses of `ValueBoundsConstraintSet`. Otherwise subclasses cannot
access those members.
- Add `ValueBoundsConstraintSet::strongCompare`. This method is similar
to `ValueBoundsConstraintSet::compare` except that it returns false when
the inverse comparison holds, and `llvm::failure()` if neither the
relation nor its inverse relation could be proven.
- Add `simplifyAffineMinOp`, `simplifyAffineMaxOp`, and
`simplifyAffineMinMaxOps` to simplify those operations using
`ValueBoundsConstraintSet`.
- Adds the `SimplifyMinMaxAffineOpsOp` transform op that uses
`simplifyAffineMinMaxOps`.
- Add the `test.value_with_bounds` op to test unknown values with a min
max range using `ValueBoundsOpInterface`.
- Adds tests verifying the transform.
Example:
```mlir
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {min = 0 : index, max = 192 : index}
%1 = test.value_with_bounds {min = 128 : index, max = 384 : index}
%2 = test.value_with_bounds {min = 256 : index, max = 512 : index}
%r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
%r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
return %r0, %r1 : index, index
}
// Result of applying `simplifyAffineMinMaxOps` to `func.func`
#map1 = affine_map<()[s0, s1] -> (s1, s0)>
func.func @overlapping_constraints() -> (index, index) {
%0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
%1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
%2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
%3 = affine.min #map1()[%0, %1]
%4 = affine.max #map1()[%1, %2]
return %3, %4 : index, index
}
```
---------
Co-authored-by: Nicolas Vasilache <Nico.Vasilache at amd.com>
Added:
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
Modified:
mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Linalg/transform-op-pad.mlir
mlir/test/lib/Dialect/Test/TestOpDefs.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index 70b127fd063ca..2969b4238dd67 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -63,4 +63,35 @@ def SimplifyBoundedAffineOpsOp
}];
}
+def SimplifyMinMaxAffineOpsOp :
+ Op<Transform_Dialect, "affine.simplify_min_max_affine_ops", [
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+ ]> {
+ let description = [{
+ Simplify the targeted `affine.min` / `affine.max` ops using the
+ `mlir::affine::simplifyAffineMinMaxOps` transform.
+
+ Example:
+ ```
+ %0 = transform.structured.match ops{["affine.max"]} in %arg1
+ transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+ ```
+
+ #### Return modes
+
+ This transform consumes the target handle and does not produce any results.
+ This transform definitely fails if any of the targeted operations is not an
+ `affine.min` or `affine.max` operation, or if the canonicalization patterns
+ failed to converge.
+ This transform silently fails if none of the operations were simplified.
+ Otherwise, it succeeds.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+ let assemblyFormat = [{
+ $target attr-dict `:` type($target)
+ }];
+}
+
#endif // Affine_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 5c538d28c1835..272054448374e 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -34,6 +34,8 @@ namespace affine {
class AffineApplyOp;
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;
+class AffineMaxOp;
+class AffineMinOp;
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
@@ -127,6 +129,37 @@ OpFoldResult materializeComputedBound(
OpBuilder &b, Location loc, AffineMap boundMap,
ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands);
+/// This transform tries to simplify the affine min operation `op`, by finding a
+/// common lower bound for a set of expressions in the affine map results. It
+/// returns whether the transform updated `op`'s affine map.
+///
+/// In concrete terms, given an operation like:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
+/// If `d0 < 128` and `128 < s1 < s0`, the transform will update `op` to:
+/// `affine.min affine_map<(d0)[s0, s1] -> (d0)>(%i)[%s0, %s1]`.
+bool simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op);
+
+/// This transform tries to simplify the affine max operation `op`, by finding a
+/// common upper bound for a set of expressions in the affine map results. It
+/// returns whether the transform updated `op`'s affine map.
+///
+/// In concrete terms, given an operation like:
+/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s1, s0, 128)>(%i)[%s0, %s1]`
+/// If `d0 > 128` and `s0 > s1 > 128`, the transform will update `op` to:
+/// `affine.max affine_map<(d0)[s0, s1] -> (d0, s0)>(%i)[%s0, %s1]`.
+bool simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op);
+
+/// This transform applies `simplifyAffineMinOp` and `simplifyAffineMaxOp` to
+/// all the `affine.min` or `affine.max` operations in `ops`. After
+/// simplification, it invokes the `affine.min/max` canonicalization patterns on
+/// `ops`.
+///
+/// This transform returns failure if the greedy pattern rewriter failed to
+/// converge during canonicalization, otherwise it returns success. If provided,
+/// `modified` is set to `true` if the IR was modified in any way.
+LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter,
+ ArrayRef<Operation *> ops,
+ bool *modified = nullptr);
} // namespace affine
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 337314143c80c..d168735f50598 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -135,10 +135,17 @@ class ValueBoundsConstraintSet
/// Construct a variable for a map and its operands.
Variable(AffineMap map, ArrayRef<Variable> mapOperands);
- Variable(AffineMap map, ArrayRef<Value> mapOperands);
+ Variable(AffineMap map, ValueRange mapOperands);
MLIRContext *getContext() const { return map.getContext(); }
+ /// Returns the affine map.
+ AffineMap getMap() const { return map; }
+
+ /// Returns the map operands.
+ ValueDimList &getOperands() { return mapOperands; }
+ const ValueDimList &getOperands() const { return mapOperands; }
+
private:
friend class ValueBoundsConstraintSet;
AffineMap map;
@@ -254,6 +261,12 @@ class ValueBoundsConstraintSet
/// prove the relation or until it ran out of IR.
static bool compare(const Variable &lhs, ComparisonOperator cmp,
const Variable &rhs);
+ /// This function is similar to `ValueBoundsConstraintSet::compare`, except
+ /// that it returns false if `!(lhs cmp rhs)`, and `failure` if neither the
+ /// relation nor its inverse relation could be proven.
+ static llvm::FailureOr<bool> strongCompare(const Variable &lhs,
+ ComparisonOperator cmp,
+ const Variable &rhs);
/// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
@@ -327,6 +340,16 @@ class ValueBoundsConstraintSet
/// constraints.
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+ /// Return "true" if, based on the current state of the constraint system,
+ /// "lhs cmp rhs" was proven to hold. It returns "false" if "!(lhs cmp rhs)"
+ /// can be proven. Otherwise, it returns `failure` if neither the relation nor
+ /// its inverse relation could be proven.
+ ///
+ /// This function does not analyze any IR and does not populate any additional
+ /// constraints.
+ llvm::FailureOr<bool> strongComparePos(int64_t lhsPos, ComparisonOperator cmp,
+ int64_t rhsPos);
+
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index c9fe4474a68fa..b1e40d9b289ec 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -112,7 +113,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
}
if (boundedOps.contains(target)) {
auto diag = emitDefiniteFailure()
- << "target op result must not be constrainted";
+ << "target op result must not be constrained";
diag.attachNote(target->getLoc()) << "target/constrained op";
return diag;
}
@@ -148,6 +149,42 @@ void SimplifyBoundedAffineOpsOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// SimplifyMinMaxAffineOpsOp
+//===----------------------------------------------------------------------===//
+/// Runs `simplifyAffineMinMaxOps` on the payload ops of the target handle.
+///
+/// Produces a definite failure if a payload op is neither `affine.min` nor
+/// `affine.max`, or if the simplification reports failure; produces a
+/// silenceable error if nothing was modified; succeeds otherwise.
+DiagnosedSilenceableFailure
+SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
+                                 TransformResults &results,
+                                 TransformState &state) {
+  // Gather the payload, bailing out on the first op of an unexpected kind.
+  SmallVector<Operation *> payloadOps;
+  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
+    if (isa<AffineMinOp, AffineMaxOp>(payloadOp)) {
+      payloadOps.push_back(payloadOp);
+      continue;
+    }
+    auto diag = emitDefiniteFailure()
+                << "target must be affine.min or affine.max";
+    diag.attachNote(payloadOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  bool anyChange = false;
+  LogicalResult status =
+      mlir::affine::simplifyAffineMinMaxOps(rewriter, payloadOps, &anyChange);
+  if (failed(status))
+    return emitDefiniteFailure()
+           << "affine.min/max simplification did not converge";
+  if (anyChange)
+    return DiagnosedSilenceableFailure::success();
+  return emitSilenceableError()
+         << "the transform failed to simplify any of the target operations";
+}
+
+/// Registers the effects of this transform op on `effects`: the target handle
+/// is consumed and the payload is modified.
+void SimplifyMinMaxAffineOpsOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 1c82822b2bd7f..c792200f4a49a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
+ SimplifyAffineMinMax.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
new file mode 100644
index 0000000000000..c992badcfa493
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -0,0 +1,174 @@
+//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transform to simplify min/max affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/IntEqClasses.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "affine-min-max"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Simplifies an affine min/max operation by proving there's a lower or upper
+/// bound.
+///
+/// Every result of the op's affine map becomes its own single-result
+/// `Variable`. Results are compared with
+/// `ValueBoundsConstraintSet::strongCompare` (`LT` for `affine.min`, `GT` for
+/// `affine.max`); results that can be ordered are merged into one equivalence
+/// class, and only the bound recorded for each class is kept in the new map.
+///
+/// Returns true if the op's map was replaced by a map with fewer results, and
+/// false if the op was left untouched.
+template <typename AffineOp>
+static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
+ using Variable = ValueBoundsConstraintSet::Variable;
+ using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+
+ AffineMap affineMap = affineOp.getMap();
+ ValueRange operands = affineOp.getOperands();
+ static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
+
+ LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+
+ // Create a `Variable` list with values corresponding to each of the results
+ // in the affine map. Each variable wraps the one-result slice of the map
+ // together with all of the op's operands.
+ SmallVector<Variable> variables = llvm::map_to_vector(
+ llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
+ [&](unsigned i) {
+ return Variable(affineMap.getSliceMap(i, 1), operands);
+ });
+
+ // Get the comparison operation: `<` when looking for the minimum, `>` when
+ // looking for the maximum.
+ ComparisonOperator cmpOp =
+ isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
+
+ // Find disjoint sets bounded by a common value. `bounds` maps the leader of
+ // an equivalence class to the variable currently acting as its bound.
+ llvm::IntEqClasses boundedClasses(variables.size());
+ DenseMap<unsigned, Variable *> bounds;
+ for (auto &&[i, v] : llvm::enumerate(variables)) {
+ unsigned eqClass = boundedClasses.findLeader(i);
+
+ // If the class already has a bound continue.
+ if (bounds.contains(eqClass))
+ continue;
+
+ // Initialize the bound with the variable itself.
+ Variable *bound = &v;
+
+ LLVM_DEBUG({
+ DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
+ });
+
+ // Check against the other variables.
+ for (size_t j = i + 1; j < variables.size(); ++j) {
+ unsigned jEqClass = boundedClasses.findLeader(j);
+ // Skip if the class is the same.
+ if (jEqClass == eqClass)
+ continue;
+
+ // Get the bound of the equivalence class or itself.
+ Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
+
+ LLVM_DEBUG({
+ DBGS() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap() << "\n";
+ });
+
+ // Compare the variables. The result is true if `bound cmpOp nv` holds,
+ // false if the negated relation holds, and failure if neither is proven.
+ FailureOr<bool> cmpResult =
+ ValueBoundsConstraintSet::strongCompare(*bound, cmpOp, *nv);
+
+ // The variables cannot be compared.
+ if (failed(cmpResult)) {
+ LLVM_DEBUG({
+ DBGS() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged\n";
+ });
+ continue;
+ }
+
+ // Join the equivalent classes and update the bound if necessary.
+ LLVM_DEBUG({
+ DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
+ });
+ if (*cmpResult) {
+ boundedClasses.join(eqClass, jEqClass);
+ } else {
+ // In this case we have lhs >= rhs if isMin == true, or lhs <= rhs if
+ // isMin == false, so the other variable takes over as the bound.
+ bound = nv;
+ boundedClasses.join(eqClass, jEqClass);
+ }
+ }
+ // The leader is queried again here instead of reusing `eqClass`
+ // (presumably because the joins above can change the leader of the class).
+ bounds[boundedClasses.findLeader(i)] = bound;
+ }
+
+ // Return if there's no simplification, i.e. there are at least as many bounds
+ // as map results.
+ if (bounds.size() >= affineMap.getNumResults()) {
+ LLVM_DEBUG(
+ { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ return false;
+ }
+
+ // Construct the new affine map from the single result of each bound.
+ SmallVector<AffineExpr> results;
+ results.reserve(bounds.size());
+ for (auto [k, bound] : bounds)
+ results.push_back(bound->getMap().getResult(0));
+
+ // The number of dims and symbols is kept; only the result list shrinks.
+ affineMap = AffineMap::get(affineMap.getNumDims(), affineMap.getNumSymbols(),
+ results, rewriter.getContext());
+
+ // Update the affine op.
+ rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
+ LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ return true;
+}
+
+/// Public entry point for `affine.min`; forwards to the shared min/max
+/// implementation and returns whether `op`'s map was updated.
+bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
+  const bool updated = simplifyAffineMinMaxOp<AffineMinOp>(rewriter, op);
+  return updated;
+}
+
+/// Public entry point for `affine.max`; forwards to the shared min/max
+/// implementation and returns whether `op`'s map was updated.
+bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
+  const bool updated = simplifyAffineMinMaxOp<AffineMaxOp>(rewriter, op);
+  return updated;
+}
+
+/// Simplifies every `affine.min` / `affine.max` op in `ops` and then runs the
+/// canonicalization patterns of both ops on `ops` until a fixpoint is reached.
+///
+/// Ops in `ops` that are neither `affine.min` nor `affine.max` are skipped by
+/// the simplification step. Returns failure if the greedy driver did not
+/// converge, success otherwise. If `modified` is non-null, it is set to true
+/// when either the simplification step or the canonicalization changed the IR,
+/// and to false otherwise; it is written on both the success and failure paths.
+LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
+                                                    ArrayRef<Operation *> ops,
+                                                    bool *modified) {
+  bool simplified = false;
+  for (Operation *op : ops) {
+    // `dyn_cast` is used for both kinds so that unrelated ops are skipped
+    // instead of tripping the assertion inside `cast`.
+    if (auto minOp = dyn_cast<AffineMinOp>(op))
+      simplified = simplifyAffineMinOp(rewriter, minOp) || simplified;
+    else if (auto maxOp = dyn_cast<AffineMaxOp>(op))
+      simplified = simplifyAffineMaxOp(rewriter, maxOp) || simplified;
+  }
+
+  RewritePatternSet patterns(rewriter.getContext());
+  AffineMaxOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
+  AffineMinOp::getCanonicalizationPatterns(patterns, rewriter.getContext());
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  // Canonicalize to a fixpoint. The driver reports only its own changes through
+  // the out-parameter, so it gets a dedicated flag instead of overwriting the
+  // result of the simplification step above.
+  bool canonicalized = false;
+  LogicalResult converged = applyOpPatternsGreedily(
+      ops, frozenPatterns,
+      GreedyRewriteConfig()
+          .setListener(
+              static_cast<RewriterBase::Listener *>(rewriter.getListener()))
+          .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
+      &canonicalized);
+  if (modified)
+    *modified = simplified || canonicalized;
+  return converged;
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 87f883c2e6485..c9481fb5d9406 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -146,7 +146,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
}
ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
- ArrayRef<Value> mapOperands)
+ ValueRange mapOperands)
: Variable(map, llvm::map_to_vector(mapOperands,
[](Value v) { return Variable(v); })) {}
@@ -736,6 +736,44 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
return isEmpty;
}
+/// Three-way variant of `comparePos`: returns true if `lhs cmp rhs` is proven,
+/// false if the negated relation is proven, and failure if neither is.
+FailureOr<bool> ValueBoundsConstraintSet::strongComparePos(
+    int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos) {
+  // Tries to prove `relation` first and its negation second.
+  auto proveOrRefute = [&](ComparisonOperator relation,
+                           ComparisonOperator negation) -> FailureOr<bool> {
+    if (comparePos(lhsPos, relation, rhsPos))
+      return true;
+    if (comparePos(lhsPos, negation, rhsPos))
+      return false;
+    return failure();
+  };
+
+  switch (cmp) {
+  case ComparisonOperator::LT:
+    return proveOrRefute(ComparisonOperator::LT, ComparisonOperator::GE);
+  case ComparisonOperator::LE:
+    return proveOrRefute(ComparisonOperator::LE, ComparisonOperator::GT);
+  case ComparisonOperator::GT:
+    return proveOrRefute(ComparisonOperator::GT, ComparisonOperator::LE);
+  case ComparisonOperator::GE:
+    return proveOrRefute(ComparisonOperator::GE, ComparisonOperator::LT);
+  case ComparisonOperator::EQ: {
+    // Equality is decided from the two non-strict halves: it is refuted as soon
+    // as one half is refuted, and undecided as soon as one half is undecided.
+    FailureOr<bool> lessOrEqual =
+        strongComparePos(lhsPos, ComparisonOperator::LE, rhsPos);
+    if (failed(lessOrEqual))
+      return failure();
+    if (!*lessOrEqual)
+      return false;
+    FailureOr<bool> greaterOrEqual =
+        strongComparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+    if (failed(greaterOrEqual))
+      return failure();
+    return *greaterOrEqual;
+  }
+  }
+  llvm_unreachable("invalid comparison operator");
+}
+
bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
ComparisonOperator cmp,
const Variable &rhs) {
@@ -763,14 +801,29 @@ bool ValueBoundsConstraintSet::compare(const Variable &lhs,
return cstr.comparePos(lhsPos, cmp, rhsPos);
}
+/// Builds a fresh constraint set for `lhs` and `rhs` and decides `lhs cmp rhs`
+/// with `strongComparePos`: returns true if the relation is proven, false if
+/// its negation is proven, and failure if neither could be proven.
+FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs,
+                                                        ComparisonOperator cmp,
+                                                        const Variable &rhs) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs were not processed. While a position
+    // is still -1, the cast to `size_t` makes it larger than any valid size.
+    if (size_t(lhsPos) >= cstr.positionToValueDim.size() ||
+        size_t(rhsPos) >= cstr.positionToValueDim.size())
+      return false;
+    // Keep processing as long as the strong relation cannot be proven, i.e.
+    // stop only once either the relation or its negation has been decided.
+    FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos);
+    return succeeded(ordered);
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+  rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
+  return cstr.strongComparePos(lhsPos, cmp, rhsPos);
+}
+
FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
const Variable &var2) {
- if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
- return true;
- if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
- ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
- return false;
- return failure();
+ return strongCompare(var1, ComparisonOperator::EQ, var2);
}
FailureOr<bool>
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
new file mode 100644
index 0000000000000..948f434f3fa5e
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-min-max-ops.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-DAG: #[[MAP_0:.*]] = affine_map<()[s0] -> (32, s0)>
+// CHECK-DAG: #[[MAP_1:.*]] = affine_map<()[s0, s1] -> (s1, s0)>
+// CHECK-DAG: #[[MAP_2:.*]] = affine_map<()[s0] -> (256, s0)>
+
+// CHECK: @min_max_full_simplify
+func.func @min_max_full_simplify() -> (index, index) {
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK-NOT: affine.min
+ // CHECK-NOT: affine.max
+ // CHECK: return %[[V0]], %[[V1]]
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 192, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @min_only_simplify
+func.func @min_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_0]]()[%[[V0]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ %0 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 32, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @max_only_simplify
+func.func @max_only_simplify() -> (index, index) {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_2]]()[%[[V1]]]
+ %0 = test.value_with_bounds {max = 128 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 512 : index, min = 0 : index}
+ %r0 = affine.min affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ %r1 = affine.max affine_map<()[s0, s1] -> (s0, 256, s1)>()[%0, %1]
+ return %r0, %r1 : index, index
+}
+
+// CHECK: @overlapping_constraints
+func.func @overlapping_constraints() -> (index, index) {
+ %0 = test.value_with_bounds {max = 192 : index, min = 0 : index}
+ %1 = test.value_with_bounds {max = 384 : index, min = 128 : index}
+ %2 = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 192 : index, min = 0 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 384 : index, min = 128 : index}
+ // CHECK: %[[V2:.*]] = test.value_with_bounds {max = 512 : index, min = 256 : index}
+ // CHECK: affine.min #[[MAP_1]]()[%[[V0]], %[[V1]]]
+ // CHECK: affine.max #[[MAP_1]]()[%[[V1]], %[[V2]]]
+ %r0 = affine.min affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+ %r1 = affine.max affine_map<()[s0, s1, s2] -> (s0, s1, s2)>()[%0, %1, %2]
+ return %r0, %r1 : index, index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.affine.simplify_min_max_affine_ops %0 : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index bc684b53c9b61..f91eb9c30a51a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -454,3 +454,38 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// This test checks that by using `simplify_min_max_affine_ops` after padding
+// and tiling, it's possible to recover static tiled slices.
+
+// CHECK-LABEL: @dyn_pad_tiling
+// CHECK: %[[LHS:.*]] = tensor.pad
+// CHECK: %[[RHS:.*]] = tensor.pad
+// CHECK: scf.for
+// CHECK-DAG: tensor.extract_slice %[[LHS]][0, %{{.*}}] [%{{.*}}, 32]
+// CHECK-DAG: tensor.extract_slice %[[RHS]][0, %{{.*}}] [%{{.*}}, 32]
+func.func @dyn_pad_tiling(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad, %copy = transform.structured.pad %0 pad_to_multiple_of [32] use_prescribed_tensor_shapes {padding_dimensions = [2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %tiled_linalg_op, %loops = transform.structured.tile_using_for %padded tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.apply_registered_pass "resolve-shaped-type-result-dims" to %1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %2 {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+ %3 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.affine.simplify_min_max_affine_ops %3 : !transform.any_op
+ transform.apply_patterns to %2 {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+ transform.yield
+ }
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 78e44c6ec7a9b..6c1a5d3441530 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -836,6 +836,16 @@ void ConversionFuncOp::print(OpAsmPrinter &p) {
getArgAttrsAttrName(), getResAttrsAttrName());
}
+//===----------------------------------------------------------------------===//
+// TestValueWithBoundsOp
+//===----------------------------------------------------------------------===//
+
+/// Bounds `v` from below by the op's `min` attribute and from above by its
+/// `max` attribute, both read as signed integers.
+void TestValueWithBoundsOp::populateBoundsForIndexValue(
+ Value v, ValueBoundsConstraintSet &cstr) {
+ // The comparison results are discarded: the statements are used for their
+ // effect on `cstr` (presumably recording the inequality as a constraint).
+ cstr.bound(v) >= getMin().getSExtValue();
+ cstr.bound(v) <= getMax().getSExtValue();
+}
+
//===----------------------------------------------------------------------===//
// ReifyBoundOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 30234698bc8dd..8a4981a90831f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -31,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
// Include the attribute definitions.
@@ -2375,6 +2376,24 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
// Test ValueBoundsOpInterface
//===----------------------------------------------------------------------===//
+// Test-only op with no operands that produces a single `index` result. Its
+// `min` and `max` index attributes describe the range reported through
+// `ValueBoundsOpInterface::populateBoundsForIndexValue`.
+// NOTE(review): nothing in this definition checks that `min <= max`.
+def TestValueWithBoundsOp : TEST_Op<"value_with_bounds", [
+ DeclareOpInterfaceMethods<ValueBoundsOpInterface, ["populateBoundsForIndexValue"]>
+ ]> {
+ let description = [{
+ Creates a value with specified [min, max] range for value bounds analysis.
+
+ Example:
+
+ ```mlir
+ %0 = test.value_with_bounds { min = 4 : index, max = 5 : index}
+ ```
+ }];
+ let arguments = (ins IndexAttr:$min, IndexAttr:$max);
+ let results = (outs Index:$result);
+ let assemblyFormat = "attr-dict";
+}
+
+
def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let description = [{
Reify a bound for the given index-typed value or dimension size of a shaped
More information about the Mlir-commits
mailing list