[Mlir-commits] [mlir] 8f4da2c - [mlir][affine] Fix min simplification in makeComposedAffineApply (#145376)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 24 04:55:16 PDT 2025
Author: Fabian Mora
Date: 2025-06-24T07:55:12-04:00
New Revision: 8f4da2cbf055ec7b9b66d757afcba1b942385874
URL: https://github.com/llvm/llvm-project/commit/8f4da2cbf055ec7b9b66d757afcba1b942385874
DIFF: https://github.com/llvm/llvm-project/commit/8f4da2cbf055ec7b9b66d757afcba1b942385874.diff
LOG: [mlir][affine] Fix min simplification in makeComposedAffineApply (#145376)
This patch fixes a bug discovered in the
`affine::makeComposedFoldedAffineApply` function when `composeAffineMin
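== true`. The bug happened because the simplification assumed that the
symbols appearing in the `affine.apply` op corresponded to the symbols of
the `affine.min` op, which is not always the case. For example, in the
snippet below `s1` of the `affine.min` map is bound to `%1`, whereas `s1`
of the `affine.apply` map is bound to `%0`; the old code therefore folded
`%2 ceildiv %0` to `1`, even though it evaluates to `2` when `%1 = 64` and
`%0 = 32`: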
```mlir
#map = affine_map<()[s0, s1] -> (s1)>
#map1 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
module {
func.func @min_max_full_simplify() -> index {
%0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
%1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
%2 = affine.min #map()[%0, %1]
%3 = affine.apply #map1()[%2, %0]
return %3 : index
}
}
```
This patch also introduces the `transform.test.make_composed_folded_affine_apply`
transform operation to exercise this simplification, and adds tests checking
both the cases that must simplify and the cases that must not.
---------
Co-authored-by: Nicolas Vasilache <nico.vasilache at amd.com>
Added:
mlir/test/Transforms/make-composed-folded-affine-apply.mlir
Modified:
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/test/lib/Transforms/TestTransformsOps.cpp
mlir/test/lib/Transforms/TestTransformsOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3b4d51d914d86..f577883085608 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1046,59 +1046,85 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
map.getContext());
}
-/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`.
-/// Assuming that the quantity is of the form:
-/// `affine_min(f(x, y), symbolic_cst)`.
-/// This function checks that `0 < affine_min(f(x, y), symbolic_cst)` and
-/// proceeds with replacing the patterns:
+/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
+/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
+/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
/// ```
-/// dimOrSym.ceildiv(symbolic_cst)
-/// (dimOrSym + symbolic_cst - 1).floordiv(symbolic_cst)
+/// dimOrSym.ceildiv(x_k)
+/// (dimOrSym + x_k - 1).floordiv(x_k)
/// ```
-/// by `1`.
+/// by `1` for all `k` in `1, ..., n`. This is sound because
+/// `0 < dimOrSym <= x_k`, which implies `dimOrSym ceildiv x_k == 1`.
///
-/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
-/// inject static information that may not be statically discoverable.
+/// Only the `x_k` whose dims (resp. symbols) are all bound to SSA values that
+/// also appear in `dims` (resp. `syms`), the apply op dim (resp. symbol)
+/// operands, can be matched; the other results of `minOp` are ignored.
///
-/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
-/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
-static LogicalResult replaceAffineMinBoundingBoxExpression(
- AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map,
- bool affineMinKnownToBeNonNegative = false) {
- auto affineMinMap = minOp.getAffineMap();
- if (!affineMinKnownToBeNonNegative) {
- ValueRange values = minOp->getOperands();
- for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
- AffineMap row = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
- FailureOr<int64_t> lowerBound =
- ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::LB, {row, values},
- /*stopCondition=*/nullptr,
- /*closedUB=*/true);
- if (failed(lowerBound) || lowerBound.value() <= 0)
- return failure();
+/// Warning: ValueBoundsConstraintSet::compare is used to prove that every
+/// result of `minOp` is positive.
+static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
+ AffineExpr dimOrSym,
+ AffineMap *map,
+ ValueRange dims,
+ ValueRange syms) {
+ AffineMap affineMinMap = minOp.getAffineMap();
+
+ // Check the value is positive.
+ for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
+ // Compare each expression in the minimum against 0.
+ if (!ValueBoundsConstraintSet::compare(
+ getAsIndexOpFoldResult(minOp.getContext(), 0),
+ ValueBoundsConstraintSet::ComparisonOperator::LT,
+ ValueBoundsConstraintSet::Variable(affineMinMap.getSliceMap(i, 1),
+ minOp.getOperands())))
+ return failure();
+ }
+
+ // Map each dim (resp. symbol) of minOp to the apply op dim (resp. symbol)
+ // bound to the same SSA value, if any.
+ DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
+ SmallVector<unsigned> unmappedDims, unmappedSyms;
+ for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
+ auto it = llvm::find(dims, dim);
+ if (it == dims.end()) {
+ unmappedDims.push_back(i);
+ continue;
}
+ dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
+ getAffineDimExpr(it.getIndex(), minOp.getContext());
+ }
+ for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
+ auto it = llvm::find(syms, sym);
+ if (it == syms.end()) {
+ unmappedSyms.push_back(i);
+ continue;
+ }
+ dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
+ getAffineSymbolExpr(it.getIndex(), minOp.getContext());
}
- AffineMap initialMap = *map;
- for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
- auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
- AffineExpr expr = m.getResult(0);
- if (!expr.isSymbolicOrConstant())
+ // Create the replacement map.
+ DenseMap<AffineExpr, AffineExpr> repl;
+ AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
+ for (AffineExpr expr : affineMinMap.getResults()) {
+ // Skip results that cannot be expressed in terms of the apply map dims and
+ // symbols.
+ if (llvm::any_of(unmappedDims,
+ [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
+ llvm::any_of(unmappedSyms,
+ [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
continue;
- DenseMap<AffineExpr, AffineExpr> repl;
+ AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
+
// dimOrSym.ceilDiv(expr) -> 1
- repl[dimOrSym.ceilDiv(expr)] = getAffineConstantExpr(1, minOp.getContext());
+ repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
// (dimOrSym + expr - 1).floorDiv(expr) -> 1
- repl[(dimOrSym + expr - 1).floorDiv(expr)] =
- getAffineConstantExpr(1, minOp.getContext());
- auto newMap = map->replace(repl);
- if (newMap == *map)
- continue;
- *map = newMap;
+ repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
}
-
+ AffineMap initialMap = *map;
+ *map = initialMap.replace(repl, initialMap.getNumDims(),
+ initialMap.getNumSymbols());
return success(*map != initialMap);
}
@@ -1127,11 +1153,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
if (!v)
return failure();
- auto minOp = v.getDefiningOp<AffineMinOp>();
- if (minOp && replaceAffineMin) {
+ if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
: getAffineSymbolExpr(pos, ctx);
- return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map);
+ return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
+ syms);
}
auto affineApply = v.getDefiningOp<AffineApplyOp>();
diff --git a/mlir/test/Transforms/make-composed-folded-affine-apply.mlir b/mlir/test/Transforms/make-composed-folded-affine-apply.mlir
new file mode 100644
index 0000000000000..138426827fa15
--- /dev/null
+++ b/mlir/test/Transforms/make-composed-folded-affine-apply.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt --transform-interpreter %s | FileCheck %s
+
+#map = affine_map<()[s0, s1] -> (s0, s1, 128)>
+#map1 = affine_map<()[s0, s1] -> (s0 ceildiv 128 + s0 ceildiv s1)>
+#map2 = affine_map<()[s0, s1, s2] -> (s0, s1 + s2)>
+#map3 = affine_map<()[s0, s1, s2, s3] -> (3 * (s0 ceildiv s3) + s0 ceildiv (s1 + s2))>
+#map4 = affine_map<()[s0, s1] -> (s1)>
+#map5 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map6 = affine_map<()[s0, s1] -> (s0, s1, -128)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 ceildiv 128 + s0 ceildiv s1)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+
+// These tests check the `affine::makeComposedFoldedAffineApply` function when
+// `composeAffineMin == true`.
+
+// Check the apply gets simplified.
+// CHECK-LABEL: @apply_simplification_1
+func.func @apply_simplification_1() -> index {
+ %0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %2 = affine.min #map()[%0, %1]
+ // CHECK-NOT: affine.apply
+ // CHECK: arith.constant 2 : index
+ %3 = affine.apply #map1()[%2, %1]
+ return %3 : index
+}
+
+// Check the simplification can match non-trivial affine expressions like s1 + s2.
+// CHECK-LABEL: @apply_simplification_2
+func.func @apply_simplification_2() -> index {
+ %0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %2 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %3 = affine.min #map2()[%0, %1, %2]
+ // CHECK-NOT: affine.apply
+ // CHECK: arith.constant 4 : index
+ %4 = affine.apply #map3()[%3, %1, %2, %0]
+ return %4 : index
+}
+
+// Check there's no simplification.
+// The apply cannot be simplified because `s1 = %0` doesn't appear in the results of the input min.
+// CHECK-LABEL: @no_simplification_0
+func.func @no_simplification_0() -> index {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 64 : index, min = 16 : index}
+ // CHECK: %[[V2:.*]] = affine.min #{{.*}}()[%[[V0]], %[[V1]]]
+ // CHECK: %[[V3:.*]] = affine.apply #[[MAP5]]()[%[[V2]], %[[V0]]]
+ // CHECK: return %[[V3]] : index
+ %0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %1 = test.value_with_bounds {max = 64 : index, min = 16 : index}
+ %2 = affine.min #map4()[%0, %1]
+ %3 = affine.apply #map5()[%2, %0]
+ return %3 : index
+}
+
+// The apply cannot be simplified because the min cannot be proven to be greater than 0.
+// CHECK-LABEL: @no_simplification_1
+func.func @no_simplification_1() -> index {
+ // CHECK: %[[V0:.*]] = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ // CHECK: %[[V1:.*]] = test.value_with_bounds {max = 64 : index, min = 16 : index}
+ // CHECK: %[[V2:.*]] = affine.min #{{.*}}()[%[[V0]], %[[V1]]]
+ // CHECK: %[[V3:.*]] = affine.apply #[[MAP1]]()[%[[V2]], %[[V1]]]
+ // CHECK: return %[[V3]] : index
+ %0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
+ %1 = test.value_with_bounds {max = 64 : index, min = 16 : index}
+ %2 = affine.min #map6()[%0, %1]
+ %3 = affine.apply #map1()[%2, %1]
+ return %3 : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["affine.apply"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.test.make_composed_folded_affine_apply %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index c05b32bed9b94..9a5632bb99c06 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -11,9 +11,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/RegionUtils.h"
#define GET_OP_CLASSES
@@ -56,6 +58,42 @@ transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// Test affine functionality.
+//===----------------------------------------------------------------------===//
+/// Replaces `affineApplyOp` with the value computed by
+/// `affine::makeComposedFoldedAffineApply` with `composeAffineMin` enabled,
+/// materializing a constant result as an index `arith.constant`. The op
+/// defining the replacement value is appended to `results`.
+DiagnosedSilenceableFailure
+transform::TestMakeComposedFoldedAffineApply::applyToOne(
+ TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
+ ApplyToEachResultList &results, TransformState &state) {
+ Location loc = affineApplyOp.getLoc();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, affineApplyOp.getAffineMap(),
+ getAsOpFoldResult(affineApplyOp.getOperands()),
+ /*composeAffineMin=*/true);
+ Value result;
+ if (auto v = dyn_cast<Value>(ofr)) {
+ result = v;
+ } else {
+ // Report a silenceable failure instead of aborting on an empty optional
+ // if the folded attribute is not an integer constant.
+ std::optional<int64_t> cst = getConstantIntValue(ofr);
+ if (!cst)
+ return emitDefaultSilenceableFailure(affineApplyOp);
+ result = rewriter.create<arith::ConstantIndexOp>(loc, *cst);
+ }
+ // Note: getDefiningOp() is null if `result` is a block argument.
+ results.push_back(result.getDefiningOp());
+ rewriter.replaceOp(affineApplyOp, result);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Extension
//===----------------------------------------------------------------------===//
namespace {
class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index 495579b452dfc..9b0a26082490c 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -59,5 +59,35 @@ def TestMoveValueDefns :
}];
}
+//===----------------------------------------------------------------------===//
+// Test affine functionality.
+//===----------------------------------------------------------------------===//
+
+def TestMakeComposedFoldedAffineApply :
+ Op<Transform_Dialect, "test.make_composed_folded_affine_apply",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Rewrites each targeted `affine.apply` op using the
+ `affine::makeComposedFoldedAffineApply` API with `composeAffineMin` set to
+ true. A constant result is materialized as an index `arith.constant`. The
+ returned handle holds the ops defining the replacement values.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$op);
+ let results = (outs TransformHandleTypeInterface:$composed);
+ let assemblyFormat = [{
+ $op attr-dict `:` functional-type(operands, results)
+ }];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::affine::AffineApplyOp affineApplyOp,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
#endif // TEST_TRANSFORM_OPS
More information about the Mlir-commits
mailing list