[llvm-branch-commits] [mlir] [mlir][Transform] Mapping update rules for `apply_conversion_patterns` (PR #84140)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Mar 9 19:15:31 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84140

>From e2b6b753bad33cbd03b79d3b9b4c2f0cabfbab8d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 8 Mar 2024 02:09:29 +0000
Subject: [PATCH] [mlir][Transform] Specify mapping update rules for
 `apply_conversion_patterns`

---
 .../Transform/IR/TransformInterfaces.h        |  51 +++++--
 .../mlir/Dialect/Transform/IR/TransformOps.td |  11 ++
 .../Transform/IR/TransformInterfaces.cpp      |  46 +-----
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 142 +++++++++++++++++-
 mlir/test/Dialect/Transform/ops-invalid.mlir  |  22 +++
 .../Transform/test-pattern-application.mlir   |  39 +++++
 6 files changed, 256 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 32724ff4b98e8e..5db1a2c28fd414 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1026,7 +1026,7 @@ class TrackingListener : public RewriterBase::Listener,
   /// Return the transform op in which this TrackingListener is used.
   TransformOpInterface getTransformOp() const { return transformOp; }
 
-private:
+protected:
   friend class TransformRewriter;
 
   void notifyOperationErased(Operation *op) override;
@@ -1034,6 +1034,7 @@ class TrackingListener : public RewriterBase::Listener,
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
   using Listener::notifyOperationReplaced;
 
+private:
   /// The transform op in which this TrackingListener is used.
   TransformOpInterface transformOp;
 
@@ -1047,23 +1048,48 @@ class TrackingListener : public RewriterBase::Listener,
 /// A specialized listener that keeps track of cases in which no replacement
 /// payload could be found. The error state of this listener must be checked
 /// before the end of its lifetime.
-class ErrorCheckingTrackingListener : public TrackingListener {
+template <typename TrackingListenerTy>
+class ErrorCheckingTrackingListener : public TrackingListenerTy {
 public:
-  using transform::TrackingListener::TrackingListener;
+  using TrackingListenerTy::TrackingListenerTy;
 
-  ~ErrorCheckingTrackingListener() override;
+  ~ErrorCheckingTrackingListener() override {
+    // Any recorded failure must have been consumed via checkAndResetError()
+    // before this listener is destroyed; otherwise the error would be
+    // silently dropped.
+    assert(!failed() && "listener state was not checked");
+  }
 
   /// Check and return the current error state of this listener. Afterwards,
   /// resets the error state to "success".
-  DiagnosedSilenceableFailure checkAndResetError();
+  DiagnosedSilenceableFailure checkAndResetError() {
+    DiagnosedSilenceableFailure previous = std::move(status);
+    errorCounter = 0;
+    status = DiagnosedSilenceableFailure::success();
+    return previous;
+  }
 
   /// Return "true" if this tracking listener had a failure.
-  bool failed() const;
+  bool failed() const { return !this->status.succeeded(); }
 
 protected:
-  void
-  notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
-                                   DiagnosedSilenceableFailure &&diag) override;
+  void notifyPayloadReplacementNotFound(
+      Operation *op, ValueRange values,
+      DiagnosedSilenceableFailure &&diag) override {
+    // Merge earlier diags (oldest first) with the new one into the listener.
+    SmallVector<Diagnostic> diags;
+    if (!status.succeeded())
+      status.takeDiagnostics(diags);
+    diag.takeDiagnostics(diags);
+    status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
+
+    // The new diag must be last: attachNote() targets the last diagnostic.
+    status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
+    for (auto &&[index, value] : llvm::enumerate(values))
+      status.attachNote(value.getLoc())
+          << "[" << errorCounter << "] replacement value " << index;
+    ++errorCounter;
+  }
 
 private:
   /// The error state of this listener. "Success" indicates that no error
@@ -1082,8 +1108,9 @@ class TransformRewriter : public RewriterBase {
   friend class TransformState;
 
   /// Create a new TransformRewriter.
-  explicit TransformRewriter(MLIRContext *ctx,
-                             ErrorCheckingTrackingListener *listener);
+  explicit TransformRewriter(
+      MLIRContext *ctx,
+      ErrorCheckingTrackingListener<TrackingListener> *listener);
 
 public:
   /// Return "true" if the tracking listener had failures.
@@ -1106,7 +1133,7 @@ class TransformRewriter : public RewriterBase {
                                                Operation *replacement);
 
 private:
-  ErrorCheckingTrackingListener *const listener;
+  ErrorCheckingTrackingListener<TrackingListener> *const listener;
 };
 
 /// This trait is supposed to be attached to Transform dialect operations that
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 1766e4bb875f32..686a51bf7f9d36 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -203,6 +203,16 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
     lower ops to different ops (from a different dialect). More details can be
     found at the documentation site of `TrackingListener`.
 
+    The way op handles are updated can be customized with `find_replacements`
+    (requires `preserve_handles`). If set, replacement ops are *not* deduced
+    from the replacement SSA values. Instead, the dictionary maps the name of
+    a tracked op to the name of its replacement op. E.g.,
+    `{"arith.mulf" = "llvm.fmul"}` specifies that the replacement op for a
+    tracked "arith.mulf" must be an "llvm.fmul" op that was created in the
+    same pattern that replaced the "arith.mulf" op. If there is no mapping for
+    a replaced tracked op, no such op, or multiple such ops, a tracking
+    listener failure is produced.
+
     This transform produces a silenceable failure if the dialect conversion was
     unsuccessful or the tracking listener failed to find a replacement op.
   }];
@@ -212,6 +222,7 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
                        OptionalAttr<StrArrayAttr>:$illegal_ops,
                        OptionalAttr<StrArrayAttr>:$legal_dialects,
                        OptionalAttr<StrArrayAttr>:$illegal_dialects,
+                       OptionalAttr<DictionaryAttr>:$find_replacements,
                        UnitAttr:$partial_conversion,
                        UnitAttr:$preserve_handles);
   let results = (outs);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index fe2eea535ffdcf..92f59c47018f64 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -935,8 +935,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     }
     return true;
   };
-  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
-                                                            config);
+  transform::ErrorCheckingTrackingListener<transform::TrackingListener>
+      trackingListener(*this, transform, config);
   transform::TransformRewriter rewriter(transform->getContext(),
                                         &trackingListener);
 
@@ -1214,11 +1214,10 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
     Operation *&result, Operation *op, ValueRange newValues) const {
   assert(op->getNumResults() == newValues.size() &&
          "invalid number of replacement values");
-  SmallVector<Value> values(newValues.begin(), newValues.end());
-
   DiagnosedSilenceableFailure diag = emitSilenceableFailure(
       getTransformOp(), "tracking listener failed to find replacement op "
                         "during application of this transform op");
+  SmallVector<Value> values(newValues.begin(), newValues.end());
 
   do {
     // If the replacement values belong to different ops, drop the mapping.
@@ -1349,49 +1348,12 @@ void transform::TrackingListener::notifyOperationReplaced(
   (void)replacePayloadOp(op, replacement);
 }
 
-transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
-  // The state of the ErrorCheckingTrackingListener must be checked and reset
-  // if there was an error. This is to prevent errors from accidentally being
-  // missed.
-  assert(status.succeeded() && "listener state was not checked");
-}
-
-DiagnosedSilenceableFailure
-transform::ErrorCheckingTrackingListener::checkAndResetError() {
-  DiagnosedSilenceableFailure s = std::move(status);
-  status = DiagnosedSilenceableFailure::success();
-  errorCounter = 0;
-  return s;
-}
-
-bool transform::ErrorCheckingTrackingListener::failed() const {
-  return !status.succeeded();
-}
-
-void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
-    Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
-
-  // Merge potentially existing diags and store the result in the listener.
-  SmallVector<Diagnostic> diags;
-  diag.takeDiagnostics(diags);
-  if (!status.succeeded())
-    status.takeDiagnostics(diags);
-  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
-
-  // Report more details.
-  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
-  for (auto &&[index, value] : llvm::enumerate(values))
-    status.attachNote(value.getLoc())
-        << "[" << errorCounter << "] replacement value " << index;
-  ++errorCounter;
-}
-
 //===----------------------------------------------------------------------===//
 // TransformRewriter
 //===----------------------------------------------------------------------===//
 
 transform::TransformRewriter::TransformRewriter(
-    MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
+    MLIRContext *ctx, ErrorCheckingTrackingListener<TrackingListener> *listener)
     : RewriterBase(ctx), listener(listener) {
   setListener(listener);
 }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ca80899ab07341..b73fceee7aba59 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -493,6 +493,125 @@ void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
 // ApplyConversionPatternsOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// A specialized tracking listener for dialect conversions. It can be
+/// configured with a "replacement mapping" that specifies how replacement ops
+/// for replaced tracked operations should be determined.
+class ConversionTrackingListener : public transform::TrackingListener {
+public:
+  ConversionTrackingListener(
+      transform::TransformState &state, transform::TransformOpInterface op,
+      transform::TrackingListenerConfig config,
+      const DenseMap<StringRef, StringRef> *replacementMapping)
+      : transform::TrackingListener(state, op, config),
+        replacementMapping(replacementMapping) {}
+
+  /// Instead of deducing the replacement op from the replacement values, the
+  /// replacement op is chosen among all ops that were created during the
+  /// current pattern application. E.g., a mapping of "arith.mulsi_extended ->
+  /// llvm.mul" indicates that tracked arith.mulsi_extended ops should be
+  /// updated to llvm.mul ops, assuming that an llvm.mul op was created in the
+  /// same pattern that replaced the arith.mulsi_extended op. A silenceable
+  /// failure is returned if there is no mapping or not exactly one candidate.
+  ///
+  /// If no replacement mapping is set, fall back to the original mechanism of
+  /// `TrackingListener`.
+  DiagnosedSilenceableFailure
+  findReplacementOp(Operation *&result, Operation *op,
+                    ValueRange newValues) const override;
+
+protected:
+  void notifyOperationErased(Operation *op) override;
+
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
+
+  void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override;
+
+  /// The root op of the pattern that is currently being applied or "nullptr" if
+  /// no pattern application is running.
+  Operation *rootOp = nullptr;
+
+  /// All ops that have been created during the current pattern application.
+  /// This list is maintained only if "replacementMapping" is set.
+  SmallVector<Operation *> createdOps;
+
+  /// Maps tracked op names to replacement op names. Not owned; must outlive
+  /// this listener. If set to "nullptr", the default lookup mechanism (i.e.,
+  /// op deduced from the replacement values) is used.
+  const DenseMap<StringRef, StringRef> *replacementMapping = nullptr;
+};
+} // namespace
+
+void ConversionTrackingListener::notifyOperationErased(Operation *op) {
+  // Let the base TrackingListener handle the erasure first.
+  TrackingListener::notifyOperationErased(op);
+
+  // An erased op can no longer serve as a replacement candidate.
+  if (auto *pos = llvm::find(createdOps, op); pos != createdOps.end())
+    createdOps.erase(pos);
+}
+
+void ConversionTrackingListener::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
+  if (replacementMapping && !previous.isSet()) // Ignore moved ops.
+    createdOps.push_back(op);
+}
+
+void ConversionTrackingListener::notifyPatternBegin(const Pattern &pattern,
+                                                    Operation *op) {
+  assert(rootOp == nullptr && "expected that no other pattern is in progress");
+  rootOp = op;
+}
+
+void ConversionTrackingListener::notifyPatternEnd(const Pattern &pattern,
+                                                  LogicalResult status) {
+  createdOps.clear();
+  rootOp = nullptr;
+}
+
+DiagnosedSilenceableFailure
+ConversionTrackingListener::findReplacementOp(Operation *&result, Operation *op,
+                                              ValueRange newValues) const {
+  if (!replacementMapping)
+    return TrackingListener::findReplacementOp(result, op, newValues);
+
+  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
+      getTransformOp(),
+      "conversion tracking listener failed to find replacement op during "
+      "application of this transform op");
+  // `op` is the replaced op; the candidate loop below must not shadow it.
+  StringRef opName = op->getName().getStringRef();
+  auto it = replacementMapping->find(opName);
+  if (it == replacementMapping->end()) {
+    diag.attachNote(op->getLoc())
+        << "no mapping specified for '" << opName << "'";
+    return diag;
+  }
+  StringRef replacementOpName = it->second;
+  Operation *replacementOp = nullptr;
+  for (Operation *candidate : createdOps) {
+    if (candidate->getName().getStringRef() == replacementOpName) {
+      if (replacementOp) {
+        diag.attachNote(candidate->getLoc())
+            << "multiple '" << replacementOpName
+            << "' replacement candidates found for '" << opName << "'";
+        return diag;
+      }
+      replacementOp = candidate;
+    }
+  }
+  if (!replacementOp) {
+    diag.attachNote(op->getLoc()) << "no replacement found for '" << opName
+                                  << "', expected '" << replacementOpName << "'";
+    return diag;
+  }
+  result = replacementOp;
+  return DiagnosedSilenceableFailure::success();
+}
+
 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
@@ -523,6 +642,15 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
     for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
       conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
 
+  // Extract op replacement rules from attribute.
+  DenseMap<StringRef, StringRef> replacementMapping;
+  if (getFindReplacements()) {
+    DictionaryAttr mappingAttr = cast<DictionaryAttr>(*getFindReplacements());
+    for (auto it : mappingAttr)
+      replacementMapping[it.getName()] =
+          cast<StringAttr>(it.getValue()).getValue();
+  }
+
   // Gather all specified patterns.
   RewritePatternSet patterns(ctx);
   // Need to keep the converters alive until after pattern application because
@@ -569,7 +697,9 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
   // name.
   TrackingListenerConfig trackingConfig;
   trackingConfig.requireMatchingReplacementOpName = false;
-  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
+  ErrorCheckingTrackingListener<ConversionTrackingListener> trackingListener(
+      state, *this, trackingConfig,
+      replacementMapping.empty() ? nullptr : &replacementMapping);
   ConversionConfig conversionConfig;
   if (getPreserveHandles())
     conversionConfig.listener = &trackingListener;
@@ -658,6 +788,16 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
       }
     }
   }
+  if (getFindReplacements()) {
+    if (!getPreserveHandles())
+      return emitOpError() << "find_replacements requires preserve_handles";
+    auto mapping = cast<DictionaryAttr>(*getFindReplacements());
+    for (auto it : mapping) {
+      if (!isa<StringAttr>(it.getValue()))
+        return emitOpError() << "expected find_replacements to contain only "
+                                "StringAttr values";
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 73a5f36af92952..729645aca2f918 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -771,3 +771,25 @@ module attributes { transform.with_named_sequence } {
     transform.yield %arg0 : !transform.any_op
   }
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected find_replacements to contain only StringAttr values}}
+  transform.apply_conversion_patterns to %arg0 {
+  } {legal_dialects = ["func", "llvm"], preserve_handles,
+     find_replacements = {"arith.muli" = 3}} : !transform.any_op
+  transform.yield
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{find_replacements requires preserve_handles}}
+  transform.apply_conversion_patterns to %arg0 {
+  } {legal_dialects = ["func", "llvm"],
+     find_replacements = {"arith.muli" = "llvm.mul"}} : !transform.any_op
+  transform.yield
+}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index fa8a555af92188..7ac2838bb95ccd 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -447,3 +447,42 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// A tracked "arith.mulsi_extended" is lowered to several LLVM ops. With
+// find_replacements = {"arith.mulsi_extended" = "llvm.mul"}, the handle must
+// be updated to the newly created "llvm.mul" op (checked via the annotation).
+
+// CHECK-LABEL: func @dialect_conversion_find_replacements(
+//  CHECK-SAME:     %[[arg0:.*]]: vector<4xi32>, %[[arg1:.*]]: vector<4xi32>)
+//       CHECK:   %[[VAL0:.*]] = llvm.sext %[[arg0]] : vector<4xi32> to vector<4xi64>
+//       CHECK:   %[[VAL1:.*]] = llvm.sext %[[arg1]] : vector<4xi32> to vector<4xi64>
+//       CHECK:   %[[VAL2:.*]] = llvm.mul %[[VAL0]], %[[VAL1]]  {annotated} : vector<4xi64>
+//       CHECK:   %[[VAL3:.*]] = llvm.trunc %[[VAL2]] : vector<4xi64> to vector<4xi32>
+//       CHECK:   %[[VAL4:.*]] = llvm.mlir.constant(dense<32> : vector<4xi64>) : vector<4xi64>
+//       CHECK:   %[[VAL5:.*]] = llvm.lshr %[[VAL2]], %[[VAL4]]  : vector<4xi64>
+//       CHECK:   %[[VAL6:.*]] = llvm.trunc %[[VAL5]] : vector<4xi64> to vector<4xi32>
+//       CHECK:   return %[[VAL3]], %[[VAL6]] : vector<4xi32>, vector<4xi32>
+func.func @dialect_conversion_find_replacements(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  %c:2 = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
+  return %c#0, %c#1 : vector<4xi32>, vector<4xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["arith.mulsi_extended"]} in %0 : (!transform.any_op) -> !transform.any_op
+    // The handle %1 is expected to be remapped to the new llvm.mul op.
+    transform.apply_conversion_patterns to %0 {
+      transform.apply_conversion_patterns.dialect_to_llvm "arith"
+    } with type_converter {
+      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+    } {legal_dialects = ["func", "llvm"], preserve_handles,
+       find_replacements = {"arith.mulsi_extended" = "llvm.mul"}}
+        : !transform.any_op
+    // Annotate %1; the CHECK lines verify that it now points to llvm.mul.
+    transform.annotate %1 "annotated" : !transform.any_op
+    transform.yield
+  }
+}



More information about the llvm-branch-commits mailing list