[Mlir-commits] [mlir] 504a751 - [mlir][linalg][transform] Add structured.replace op

Matthias Springer llvmlistbot at llvm.org
Thu Dec 1 00:04:44 PST 2022


Author: Matthias Springer
Date: 2022-12-01T09:04:35+01:00
New Revision: 504a7516a1d1b459375a551d0b4fa0201428680b

URL: https://github.com/llvm/llvm-project/commit/504a7516a1d1b459375a551d0b4fa0201428680b
DIFF: https://github.com/llvm/llvm-project/commit/504a7516a1d1b459375a551d0b4fa0201428680b.diff

LOG: [mlir][linalg][transform] Add structured.replace op

This op is useful for debugging/experiments and allows users to replace ops (without arguments + IsolatedFromAbove) with the given op in the region of the transform op.

Differential Revision: https://reviews.llvm.org/D139026

Added: 
    mlir/test/Dialect/Linalg/transform-op-replace.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index ac1af6c93cc93..fd3a33cc0fb14 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
 
 namespace mlir {
 class TilingInterface;

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 64946bc5e81e6..3ea0a66625776 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
+include "mlir/IR/RegionKindInterface.td"
 
 def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
@@ -387,6 +388,31 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
   }];
 }
 
+def ReplaceOp : Op<Transform_Dialect, "structured.replace",
+    [IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
+  let description = [{
+    Replace all `target` payload ops with the single op that is contained in
+    this op's region. All targets must have zero operands and, if they have
+    regions, must be isolated from above. The same restrictions apply to the
+    op in the region.
+
+    This op is for debugging/experiments only.
+
+    #### Return modes
+
+    This operation consumes the `target` handle and produces a handle to the
+    replacement ops. It fails definitively if a target has operands or has
+    regions without being isolated from above.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$replacement);
+  let regions = (region SizedRegion<1>:$bodyRegion);
+  let assemblyFormat = "$target attr-dict-with-keyword regions";
+  let hasVerifier = 1;
+}
+
 def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e6123a4f17749..96a386e484e5a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
@@ -883,6 +884,78 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
+/// Replaces each payload op associated with the `target` handle by a clone of
+/// the single op in this op's body region and maps the `replacement` result to
+/// the clones. All targets are validated before the payload IR is modified;
+/// an invalid target results in a definite failure.
+DiagnosedSilenceableFailure
+transform::ReplaceOp::apply(TransformResults &transformResults,
+                            TransformState &state) {
+  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+
+  // Check for invalid targets up front so that a failure leaves the payload IR
+  // untouched.
+  for (Operation *target : payload) {
+    if (target->getNumOperands() > 0)
+      return emitDefiniteFailure() << "expected target without operands";
+    // Ops without regions need not have the IsolatedFromAbove trait.
+    if (!target->hasTrait<IsIsolatedFromAbove>() && target->getNumRegions() > 0)
+      return emitDefiniteFailure()
+             << "expected target that is isolated from above";
+  }
+
+  // Clone and replace. The verifier ensures that the body holds exactly one op.
+  IRRewriter rewriter(getContext());
+  Operation *pattern = &getBodyRegion().front().front();
+  SmallVector<Operation *> replacements;
+  for (Operation *target : payload) {
+    // Skip targets that are this transform op or nested within it, such as the
+    // op in the body region.
+    if (getOperation()->isAncestor(target))
+      continue;
+    rewriter.setInsertionPoint(target);
+    Operation *replacement = rewriter.clone(*pattern);
+    // NOTE(review): this presumably requires `target` and `replacement` to have
+    // the same number of results, which is not checked above -- confirm.
+    rewriter.replaceOp(target, replacement->getResults());
+    replacements.push_back(replacement);
+  }
+  transformResults.set(getReplacement().cast<OpResult>(), replacements);
+  return DiagnosedSilenceableFailure(success());
+}
+
+/// Declares that the `target` handle is consumed, the `replacement` handle is
+/// produced, and the payload IR is modified.
+void transform::ReplaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getReplacement(), effects);
+  modifiesPayload(effects);
+}
+
+/// Verifies that the body region has a single block with exactly one op, and
+/// that this op has no operands and, if it has regions, is isolated from above.
+LogicalResult transform::ReplaceOp::verify() {
+  if (!getBodyRegion().hasOneBlock())
+    return emitOpError() << "expected one block";
+  if (std::distance(getBodyRegion().front().begin(),
+                    getBodyRegion().front().end()) != 1)
+    return emitOpError() << "expected one operation in block";
+  Operation *replacement = &getBodyRegion().front().front();
+  if (replacement->getNumOperands() > 0)
+    return replacement->emitOpError()
+           << "expected replacement without operands";
+  if (!replacement->hasTrait<IsIsolatedFromAbove>() &&
+      replacement->getNumRegions() > 0)
+    return replacement->emitOpError()
+           << "expected op that is isolated from above";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScalarizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
new file mode 100644
index 0000000000000..f8ed7712b68e0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s -allow-unregistered-dialect -verify-diagnostics --split-input-file | FileCheck %s
+// The payload function @bar is replaced by a clone of @foo.
+// CHECK: func.func @foo() {
+// CHECK:   "dummy_op"() : () -> ()
+// CHECK: }
+// CHECK-NOT: func.func @bar
+func.func @bar() {
+  "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+  transform.structured.replace %0 {
+    func.func @foo() {
+      "dummy_op"() : () -> ()
+    }
+  }
+}
+
+// -----
+// A target op that has operands is rejected when the transform is applied.
+func.func @bar(%arg0: i1) {
+  "another_op"(%arg0) : (i1) -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["another_op"]} in %arg1
+  // expected-error @+1 {{expected target without operands}}
+  transform.structured.replace %0 {
+    "dummy_op"() : () -> ()
+  }
+}
+
+// -----
+// A replacement op that has operands is rejected by the op verifier.
+func.func @bar() {
+  "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["another_op"]} in %arg1
+  transform.structured.replace %0 {
+  ^bb0(%a: i1):
+    // expected-error @+1 {{expected replacement without operands}}
+    "dummy_op"(%a) : (i1) -> ()
+  }
+}


        


More information about the Mlir-commits mailing list