[Mlir-commits] [mlir] fcd4778 - [mlir][affine][transform] Simplify affine.min/max ops with given constraints

Matthias Springer llvmlistbot at llvm.org
Fri Jan 13 01:29:00 PST 2023


Author: Matthias Springer
Date: 2023-01-13T10:28:51+01:00
New Revision: fcd4778bdf4f4767d7c508da59df94d320850c46

URL: https://github.com/llvm/llvm-project/commit/fcd4778bdf4f4767d7c508da59df94d320850c46
DIFF: https://github.com/llvm/llvm-project/commit/fcd4778bdf4f4767d7c508da59df94d320850c46.diff

LOG: [mlir][affine][transform] Simplify affine.min/max ops with given constraints

This transform op uses `mlir::simplifyConstrainedMinMaxOp` to simplify `affine.min` and `affine.max` ops based on the given constraints.

Differential Revision: https://reviews.llvm.org/D140997

Added: 
    mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
    mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
    mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
index c1d67c7869646..fbb3868ccbe92 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H
 #define MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"

diff  --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
index e2b7e50ef8cd7..1ab7bd3628e7c 100644
--- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
+++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef AFFINE_TRANSFORM_OPS
 #define AFFINE_TRANSFORM_OPS
 
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -18,4 +19,49 @@ include "mlir/IR/OpBase.td"
 
 def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">;
 
+def SimplifyBoundedAffineOpsOp
+    : Op<Transform_Dialect, "affine.simplify_bounded_affine_ops",
+         [DeclareOpInterfaceMethods<TransformOpInterface>,
+          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Simplify the targeted affine.min / affine.max ops given the supplied
+    lower and upper bounds for values that may be used as target op operands.
+
+    Example:
+    ```
+    %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1
+    %1 = transform.structured.match ops{["gpu.lane_id"]} in %arg1
+    transform.affine.simplify_bounded_affine_ops %0 with [%1] within [0] and [32]
+
+    // Multiple bounds can be specified.
+    transform.affine.simplify_bounded_affine_ops %0 with [%1, %2] within [0, 5] and [32, 50]
+    ```
+
+    Bounded op handles (`%1` and `%2`) must be mapped to ops that have a
+    single result of index type. Lower bounds are inclusive, upper bounds are
+    exclusive. The sets of target ops and bounded ops must not overlap.
+
+    #### Return modes
+
+    The transform fails definitely if a target op is not an affine.min /
+    affine.max op or is also a bounded op. It consumes the target handle, only
+    reads the bounded op handles and does not produce any handle.
+
+    TODO: Support affine.apply targets.
+    TODO: Allow mixed PDL_Operation/int64_t for lower_bounds and upper_bounds.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       Variadic<PDL_Operation>:$bounded_values,
+                       DenseI64ArrayAttr:$lower_bounds,
+                       DenseI64ArrayAttr:$upper_bounds);
+  let results = (outs);
+  let hasVerifier = 1;
+
+  let assemblyFormat = [{
+      $target `with` `[` $bounded_values `]`
+          `within` $lower_bounds `and` $upper_bounds attr-dict
+  }];
+}
+
 #endif // Affine_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3ac25b7f5d33c..db76278b7322d 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3255,7 +3255,6 @@ struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
     AffineMap map = affineOp.getAffineMap();
     if (failed(canonicalizeMapExprAndTermOrder(map)))
       return failure();
-
     rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
     return success();
   }

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index eafc6d9eafabc..24ed10eb593d3 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -7,13 +7,141 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// SimplifyBoundedAffineOpsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SimplifyBoundedAffineOpsOp::verify() {
+  // Every bounded value handle needs exactly one lower and one upper bound.
+  size_t numBounded = getBoundedValues().size();
+  if (size_t numLbs = getLowerBounds().size(); numLbs != numBounded)
+    return emitOpError() << "incorrect number of lower bounds, expected "
+                         << numBounded << " but found " << numLbs;
+  if (size_t numUbs = getUpperBounds().size(); numUbs != numBounded)
+    return emitOpError() << "incorrect number of upper bounds, expected "
+                         << numBounded << " but found " << numUbs;
+  return success();
+}
+
+namespace {
+/// Simplify affine.min / affine.max ops with the given constraints: each op is
+/// either rewritten to affine.apply or left unchanged (pattern failure).
+template <typename OpTy>
+struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
+  /// Stores a copy of `constraints`, so the pattern can never dangle.
+  SimplifyAffineMinMaxOp(MLIRContext *ctx,
+                         const FlatAffineValueConstraints &constraints,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<AffineValueMap> simplified =
+        simplifyConstrainedMinMaxOp(op, constraints);
+    if (failed(simplified))
+      return failure();
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
+                                               simplified->getOperands());
+    return success();
+  }
+
+  FlatAffineValueConstraints constraints;
+};
+} // namespace
+
+/// Rewrites every targeted affine.min / affine.max op using constraints
+/// derived from the bounded value handles and their [lb, ub) bounds.
+DiagnosedSilenceableFailure
+SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
+                                  TransformState &state) {
+  // Get constraints for bounded values. A handle may map to several ops; each
+  // of their results receives the bounds specified for that handle.
+  SmallVector<int64_t> lbs;
+  SmallVector<int64_t> ubs;
+  SmallVector<Value> boundedValues;
+  DenseSet<Operation *> boundedOps;
+  for (auto [handle, lb, ub] : llvm::zip_equal(
+           getBoundedValues(), getLowerBounds(), getUpperBounds())) {
+    for (Operation *op : state.getPayloadOps(handle)) {
+      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+        auto diag =
+            emitDefiniteFailure()
+            << "expected bounded value handle to point to one or multiple "
+               "single-result index-typed ops";
+        diag.attachNote(op->getLoc()) << "multiple/non-index result";
+        return diag;
+      }
+      boundedValues.push_back(op->getResult(0));
+      boundedOps.insert(op);
+      lbs.push_back(lb);
+      ubs.push_back(ub);
+    }
+  }
+
+  // Build constraint set. Values bounded by several handles get all bounds.
+  FlatAffineValueConstraints cstr;
+  for (auto [value, lb, ub] : llvm::zip_equal(boundedValues, lbs, ubs)) {
+    unsigned pos;
+    if (!cstr.findVar(value, &pos))
+      pos = cstr.appendSymbolVar(value);
+    cstr.addBound(FlatAffineValueConstraints::BoundType::LB, pos, lb);
+    // Note: addBound bounds are inclusive, but specified UB is exclusive.
+    cstr.addBound(FlatAffineValueConstraints::BoundType::UB, pos, ub - 1);
+  }
+
+  // Check all targets before rewriting anything.
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  for (Operation *target : targets) {
+    if (!isa<AffineMinOp, AffineMaxOp>(target)) {
+      auto diag = emitDefiniteFailure()
+                  << "target must be affine.min or affine.max";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    if (boundedOps.contains(target)) {
+      auto diag = emitDefiniteFailure()
+                  << "target op result must not be constrained";
+      diag.attachNote(target->getLoc()) << "target/constrained op";
+      return diag;
+    }
+  }
+
+  RewritePatternSet patterns(getContext());
+  // Canonicalization patterns are needed so that affine.apply ops are composed
+  // with the remaining affine.min/max ops.
+  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
+  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
+  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
+                  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  // Apply the simplification pattern to a fixpoint. Non-convergence is not an
+  // error: targets that could not be simplified are simply left unchanged.
+  (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SimplifyBoundedAffineOpsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Targets are rewritten (handle consumed); bounded handles are only read.
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getBoundedValues(), effects);
+  modifiesPayload(effects);
+}
 
 //===----------------------------------------------------------------------===//
 // Transform op registration

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
index 740563af5afed..24e2c8378bb3a 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/TransformOps/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRAffineTransformOps
   MLIRAffineTransformOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineAnalysis
   MLIRAffineDialect
   MLIRFuncDialect
   MLIRIR

diff  --git a/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
new file mode 100644
index 0000000000000..ec9559c081bcf
--- /dev/null
+++ b/mlir/test/Dialect/Affine/transform-op-simplify-bounded-affine-ops.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt  %s -allow-unregistered-dialect \
+// RUN:     --test-transform-dialect-interpreter -verify-diagnostics \
+// RUN:     --split-input-file | FileCheck %s
+
+// Upper bounds are exclusive, so %0 is in [0, 19]. Hence 100 - %0 >= 81 > 50
+// and 80 + %0 <= 99 < 100: both ops fold to constants.
+//     CHECK: func @simplify_min_max()
+// CHECK-DAG:   %[[c50:.*]] = arith.constant 50 : index
+// CHECK-DAG:   %[[c100:.*]] = arith.constant 100 : index
+//     CHECK:   return %[[c50]], %[[c100]]
+func.func @simplify_min_max() -> (index, index) {
+  %0 = "test.some_op"() : () -> (index)
+  %1 = affine.min affine_map<()[s0] -> (50, 100 - s0)>()[%0]
+  %2 = affine.max affine_map<()[s0] -> (100, 80 + s0)>()[%0]
+  return %1, %2 : index, index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min", "affine.max"]} in %arg1
+  %1 = transform.structured.match ops{["test.some_op"]} in %arg1
+  transform.affine.simplify_bounded_affine_ops %0 with [%1] within [0] and [20]
+}
+
+// -----
+
+// %1 and %3 are in [0, 30]. Then 1023 - 32 * %1 >= 63 > 32, so %2 is 32, and
+// 32 - %3 * (32 ceildiv 32) >= 2 > 1, so %4 is 1.
+// CHECK: func @simplify_min_sequence()
+// CHECK:   %[[c1:.*]] = arith.constant 1 : index
+// CHECK:   return %[[c1]]
+func.func @simplify_min_sequence() -> index {
+  %1 = "test.workgroup_id"() : () -> (index)
+  %2 = affine.min affine_map<()[s0] -> (s0 * -32 + 1023, 32)>()[%1]
+  %3 = "test.thread_id"() : () -> (index)
+  %4 = affine.min affine_map<()[s0, s1] -> (s0 - s1 * (s0 ceildiv 32), s0 ceildiv 32)>()[%2, %3]
+  return %4 : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min"]} in %arg1
+  %1 = transform.structured.match ops{["test.workgroup_id"]} in %arg1
+  %2 = transform.structured.match ops{["test.thread_id"]} in %arg1
+  transform.affine.simplify_bounded_affine_ops %0 with [%1, %2] within [0, 0] and [31, 31]
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min"]} in %arg1
+  // expected-error @+1 {{incorrect number of lower bounds, expected 0 but found 1}}
+  transform.affine.simplify_bounded_affine_ops %0 with [] within [0] and []
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["affine.min"]} in %arg1
+  // expected-error @+1 {{incorrect number of upper bounds, expected 0 but found 1}}
+  transform.affine.simplify_bounded_affine_ops %0 with [] within [] and [5]
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f59f22e8a5963..3bf5079820aa7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1201,13 +1201,16 @@ cc_library(
     hdrs = glob(["include/mlir/Dialect/Affine/TransformOps/*.h"]),
     includes = ["include"],
     deps = [
+        ":AffineAnalysis",
         ":AffineDialect",
         ":AffineTransformOpsIncGen",
         ":AffineTransforms",
         ":AffineUtils",
         ":FuncDialect",
         ":IR",
+        ":PDLDialect",
         ":TransformDialect",
+        ":Transforms",
         ":VectorDialect",
     ],
 )


        


More information about the Mlir-commits mailing list