[Mlir-commits] [mlir] fcd4778 - [mlir][affine][transform] Simplify affine.min/max ops with given constraints
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 13 01:29:00 PST 2023
Author: Matthias Springer
Date: 2023-01-13T10:28:51+01:00
New Revision: fcd4778bdf4f4767d7c508da59df94d320850c46
URL: https://github.com/llvm/llvm-project/commit/fcd4778bdf4f4767d7c508da59df94d320850c46
DIFF: https://github.com/llvm/llvm-project/commit/fcd4778bdf4f4767d7c508da59df94d320850c46.diff
LOG: [mlir][affine][transform] Simplify affine.min/max ops with given constraints
This transform op uses `mlir::simplifyConstrainedMinMaxOp` to simplify `affine.min` and `affine.max` ops based on the given constraints.
Differential Revision: https://reviews.llvm.org/D140997
Added:
mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
Modified:
mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
index c1d67c7869646..fbb3868ccbe92 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H
#define MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index e2b7e50ef8cd7..1ab7bd3628e7c 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -9,6 +9,7 @@
#ifndef AFFINE_TRANSFORM_OPS
#define AFFINE_TRANSFORM_OPS
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -18,4 +19,49 @@ include "mlir/IR/OpBase.td"
def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">;
+def SimplifyBoundedAffineOpsOp
+    : Op<Transform_Dialect, "affine.simplify_bounded_affine_ops",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Simplify the targeted affine.min / affine.max ops given the supplied
+    lower and upper bounds for values that may be used as target op operands.
+    Lower bounds are inclusive and upper bounds are exclusive.
+
+    Example:
+    ```
+    %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1
+    %1 = transform.structured.match ops{["gpu.lane_id"]} in %arg1
+    transform.affine.simplify_bounded_affine_ops %0 with [%1] within [0] and [32]
+
+    // Multiple bounds can be specified.
+    transform.affine.simplify_bounded_affine_ops %0 with [%1, %2] within [0, 5] and [32, 50]
+    ```
+
+    Bounded op handles (`%1` and `%2`) must be mapped to ops that have a single
+    result of index type. The sets of target ops and bounded ops must not
+    overlap.
+
+    #### Return modes
+
+    Target ops must be affine.min or affine.max ops. This transform consumes the
+    target handle and does not produce any handle. It reads the bounded op
+    handles.
+
+    TODO: Support affine.apply targets.
+    TODO: Allow mixed PDL_Operation/int64_t for lower_bounds and upper_bounds.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       Variadic<PDL_Operation>:$bounded_values,
+                       DenseI64ArrayAttr:$lower_bounds,
+                       DenseI64ArrayAttr:$upper_bounds);
+  let results = (outs);
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+    $target `with` `[` $bounded_values `]`
+        `within` $lower_bounds `and` $upper_bounds attr-dict
+  }];
+}
+
#endif // Affine_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3ac25b7f5d33c..db76278b7322d 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3255,7 +3255,6 @@ struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
AffineMap map = affineOp.getAffineMap();
if (failed(canonicalizeMapExprAndTermOrder(map)))
return failure();
-
rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
return success();
}
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index eafc6d9eafabc..24ed10eb593d3 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -7,13 +7,141 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// SimplifyBoundedAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that exactly one lower bound and exactly one upper bound is
+/// supplied for every bounded value handle.
+LogicalResult SimplifyBoundedAffineOpsOp::verify() {
+  const auto numBoundedValues = getBoundedValues().size();
+
+  const auto numLowerBounds = getLowerBounds().size();
+  if (numLowerBounds != numBoundedValues) {
+    return emitOpError() << "incorrect number of lower bounds, expected "
+                         << numBoundedValues << " but found "
+                         << numLowerBounds;
+  }
+
+  const auto numUpperBounds = getUpperBounds().size();
+  if (numUpperBounds != numBoundedValues) {
+    return emitOpError() << "incorrect number of upper bounds, expected "
+                         << numBoundedValues << " but found "
+                         << numUpperBounds;
+  }
+
+  return success();
+}
+
+namespace {
+/// Rewrite pattern that attempts to simplify an affine.min / affine.max op
+/// under a fixed set of constraints. When the simplification succeeds the op
+/// is replaced with an affine.apply op; otherwise the op is left unchanged.
+template <typename OpTy>
+struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  SimplifyAffineMinMaxOp(MLIRContext *ctx,
+                         const FlatAffineValueConstraints &constraints,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<AffineValueMap> maybeValueMap =
+        simplifyConstrainedMinMaxOp(op, constraints);
+    if (succeeded(maybeValueMap)) {
+      AffineValueMap &valueMap = *maybeValueMap;
+      rewriter.replaceOpWithNewOp<AffineApplyOp>(op, valueMap.getAffineMap(),
+                                                 valueMap.getOperands());
+      return success();
+    }
+    // Nothing could be simplified under the given constraints.
+    return failure();
+  }
+
+  /// Constraints used for the simplification. Stored by reference, so the
+  /// referenced constraint set must outlive this pattern.
+  const FlatAffineValueConstraints &constraints;
+};
+} // namespace
+
+/// Simplifies the target affine.min / affine.max payload ops under the
+/// assumption that every bounded value lies in its [lower bound, upper bound)
+/// range.
+///
+/// Emits a definite failure if a bounded op does not have exactly one
+/// index-typed result, if a target is not an affine.min / affine.max op, or if
+/// a target op is also one of the bounded ops. Returns success otherwise.
+DiagnosedSilenceableFailure
+SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
+                                  TransformState &state) {
+  // Get constraints for bounded values.
+  SmallVector<int64_t> lbs;
+  SmallVector<int64_t> ubs;
+  SmallVector<Value> boundedValues;
+  DenseSet<Operation *> boundedOps;
+  for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
+                                        getUpperBounds())) {
+    Value handle = std::get<0>(it);
+    ArrayRef<Operation *> boundedValueOps = state.getPayloadOps(handle);
+    // A handle may be mapped to several payload ops; the bounds given for the
+    // handle are recorded once per payload op.
+    for (Operation *op : boundedValueOps) {
+      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+        auto diag =
+            emitDefiniteFailure()
+            << "expected bounded value handle to point to one or multiple "
+               "single-result index-typed ops";
+        diag.attachNote(op->getLoc()) << "multiple/non-index result";
+        return diag;
+      }
+      boundedValues.push_back(op->getResult(0));
+      boundedOps.insert(op);
+      lbs.push_back(std::get<1>(it));
+      ubs.push_back(std::get<2>(it));
+    }
+  }
+
+  // Build constraint set. The three vectors are filled in lockstep above, so
+  // they always have the same length.
+  FlatAffineValueConstraints cstr;
+  for (const auto &it : llvm::zip_equal(boundedValues, lbs, ubs)) {
+    // Reuse the variable if the value was already added; otherwise append it
+    // as a new symbol.
+    unsigned pos;
+    if (!cstr.findVar(std::get<0>(it), &pos))
+      pos = cstr.appendSymbolVar(std::get<0>(it));
+    cstr.addBound(FlatAffineValueConstraints::BoundType::LB, pos,
+                  std::get<1>(it));
+    // Note: addBound bounds are inclusive, but specified UB is exclusive.
+    cstr.addBound(FlatAffineValueConstraints::BoundType::UB, pos,
+                  std::get<2>(it) - 1);
+  }
+
+  // Validate all targets before any rewriting takes place.
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  for (Operation *target : targets) {
+    if (!isa<AffineMinOp, AffineMaxOp>(target)) {
+      auto diag = emitDefiniteFailure()
+                  << "target must be affine.min or affine.max";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    if (boundedOps.contains(target)) {
+      auto diag = emitDefiniteFailure()
+                  << "target op result must not be constrained";
+      diag.attachNote(target->getLoc()) << "target/constrained op";
+      return diag;
+    }
+  }
+
+  // Transform all targets.
+  RewritePatternSet patterns(getContext());
+  // Canonicalization patterns are needed so that affine.apply ops are composed
+  // with the remaining affine.min/max ops.
+  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
+  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
+  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
+                  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  // Apply the simplification pattern to a fixpoint. The driver's result is
+  // deliberately discarded: this transform reports success either way.
+  (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the effects of this transform op: the target handle is consumed,
+/// every bounded value handle is only read, and the payload IR is modified.
+void SimplifyBoundedAffineOpsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  for (Value boundedHandle : getBoundedValues()) {
+    onlyReadsHandle(boundedHandle, effects);
+  }
+  modifiesPayload(effects);
+}
//===----------------------------------------------------------------------===//
// Transform op registration
diff --git a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
index 740563af5afed..24e2c8378bb3a 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRAffineTransformOps
MLIRAffineTransformOpsIncGen
LINK_LIBS PUBLIC
+ MLIRAffineAnalysis
MLIRAffineDialect
MLIRFuncDialect
MLIRIR
diff --git a/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
new file mode 100644
index 0000000000000..ec9559c081bcf
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect \
+// RUN: --test-transform-dialect-interpreter -verify-diagnostics \
+// RUN: --split-input-file | FileCheck %s
+
+// %0 is constrained to [0, 20) (the upper bound is exclusive). Therefore
+// 100 - s0 >= 81 > 50 and 80 + s0 <= 99 < 100, so the min is expected to
+// simplify to the constant 50 and the max to the constant 100.
+// CHECK: func @simplify_min_max()
+// CHECK-DAG: %[[c50:.*]] = arith.constant 50 : index
+// CHECK-DAG: %[[c100:.*]] = arith.constant 100 : index
+// CHECK: return %[[c50]], %[[c100]]
+func.func @simplify_min_max() -> (index, index) {
+ %0 = "test.some_op"() : () -> (index)
+ %1 = affine.min affine_map<()[s0] -> (50, 100 - s0)>()[%0]
+ %2 = affine.max affine_map<()[s0] -> (100, 80 + s0)>()[%0]
+ return %1, %2 : index, index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1
+ %1 = transform.structured.match ops{["test.some_op"]} in %arg1
+ transform.affine.simplify_bounded_affine_ops %0 with [%1] within [0] and [20]
+}
+
+// -----
+
+// Both bounded values are constrained to [0, 31) (the upper bound is
+// exclusive). Then s0 * -32 + 1023 >= 63 > 32, so %2 is expected to simplify
+// to 32. With s0 = 32 the second min becomes min(32 - s1, 1), and
+// 32 - s1 >= 2 > 1, so %4 is expected to simplify to the constant 1.
+// CHECK: func @simplify_min_sequence()
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: return %[[c1]]
+func.func @simplify_min_sequence() -> index {
+ %1 = "test.workgroup_id"() : () -> (index)
+ %2 = affine.min affine_map<()[s0] -> (s0 * -32 + 1023, 32)>()[%1]
+ %3 = "test.thread_id"() : () -> (index)
+ %4 = affine.min affine_map<()[s0, s1] -> (s0 - s1 * (s0 ceildiv 32), s0 ceildiv 32)>()[%2, %3]
+ return %4 : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["affine.min"]} in %arg1
+ %1 = transform.structured.match ops{["test.workgroup_id"]} in %arg1
+ %2 = transform.structured.match ops{["test.thread_id"]} in %arg1
+ transform.affine.simplify_bounded_affine_ops %0 with [%1, %2] within [0, 0] and [31, 31]
+}
+
+// -----
+
+// Verifier test: one lower bound is supplied although there are no bounded
+// value handles.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["affine.min"]} in %arg1
+ // expected-error at +1 {{incorrect number of lower bounds, expected 0 but found 1}}
+ transform.affine.simplify_bounded_affine_ops %0 with [] within [0] and []
+}
+
+// -----
+
+// Verifier test: one upper bound is supplied although there are no bounded
+// value handles.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["affine.min"]} in %arg1
+ // expected-error at +1 {{incorrect number of upper bounds, expected 0 but found 1}}
+ transform.affine.simplify_bounded_affine_ops %0 with [] within [] and [5]
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f59f22e8a5963..3bf5079820aa7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1201,13 +1201,16 @@ cc_library(
hdrs = glob(["include/mlir/Dialect/Affine/TransformOps/*.h"]),
includes = ["include"],
deps = [
+ ":AffineAnalysis",
":AffineDialect",
":AffineTransformOpsIncGen",
":AffineTransforms",
":AffineUtils",
":FuncDialect",
":IR",
+ ":PDLDialect",
":TransformDialect",
+ ":Transforms",
":VectorDialect",
],
)
More information about the Mlir-commits
mailing list