[Mlir-commits] [mlir] 504a751 - [mlir][linalg][transform] Add structured.replace op
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 1 00:04:44 PST 2022
Author: Matthias Springer
Date: 2022-12-01T09:04:35+01:00
New Revision: 504a7516a1d1b459375a551d0b4fa0201428680b
URL: https://github.com/llvm/llvm-project/commit/504a7516a1d1b459375a551d0b4fa0201428680b
DIFF: https://github.com/llvm/llvm-project/commit/504a7516a1d1b459375a551d0b4fa0201428680b.diff
LOG: [mlir][linalg][transform] Add structured.replace op
This op is useful for debugging and experiments. It allows users to replace ops (that have no operands and are isolated from above) with the op given in the region of the transform op.
Differential Revision: https://reviews.llvm.org/D139026
Added:
mlir/test/Dialect/Linalg/transform-op-replace.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index ac1af6c93cc93..fd3a33cc0fb14 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
namespace mlir {
class TilingInterface;
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 64946bc5e81e6..3ea0a66625776 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
+include "mlir/IR/RegionKindInterface.td"
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
@@ -387,6 +388,28 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
}];
}
+def ReplaceOp : Op<Transform_Dialect, "structured.replace",
+ [IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
+ let description = [{
+ Replace all `target` payload ops with the single op that is contained in
+ this op's region. All targets must have zero operands and, if they have
+ regions, must be isolated from above. Targets that are nested inside this
+ transform op are left untouched.
+
+ The region must contain exactly one op, which itself must have no operands
+ and, if it has regions, must be isolated from above.
+
+ This op is for debugging/experiments only.
+
+ #### Return modes
+
+ This operation consumes the `target` handle and produces the `replacement`
+ handle, which is associated with the newly created ops. A definite failure
+ is produced if any target has operands or has regions without being
+ isolated from above.
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$replacement);
+ // Holds the single op that is cloned in place of every target.
+ let regions = (region SizedRegion<1>:$bodyRegion);
+ let assemblyFormat = "$target attr-dict-with-keyword regions";
+ let hasVerifier = 1;
+}
+
def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e6123a4f17749..96a386e484e5a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
@@ -883,6 +884,64 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
return DiagnosedSilenceableFailure(success());
}
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
+/// Replaces every payload op associated with `target` by a fresh clone of the
+/// single op contained in this op's body region.
+///
+/// All targets are validated before any IR is modified, so a failure leaves
+/// the payload untouched. A definite failure is emitted if a target has
+/// operands, has regions without being isolated from above, or (unless it is
+/// nested in this transform op) has a different number of results than the
+/// replacement op. Targets nested inside this transform op are skipped. The
+/// `replacement` result is associated with the created clones.
+DiagnosedSilenceableFailure
+transform::ReplaceOp::apply(TransformResults &transformResults,
+                            TransformState &state) {
+  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  Operation *pattern = &getBodyRegion().front().front();
+
+  // Check for invalid targets before modifying any IR.
+  for (Operation *target : payload) {
+    if (target->getNumOperands() > 0)
+      return emitDefiniteFailure() << "expected target without operands";
+    if (!target->hasTrait<IsIsolatedFromAbove>() && target->getNumRegions() > 0)
+      return emitDefiniteFailure()
+             << "expected target that is isolated from above";
+    // replaceOp expects one replacement value per target result; diagnose a
+    // mismatch instead of tripping an assertion. Nested targets are never
+    // replaced (see below), so they are exempt from this check.
+    if (!getOperation()->isAncestor(target) &&
+        target->getNumResults() != pattern->getNumResults())
+      return emitDefiniteFailure()
+             << "expected target with the same number of results as the "
+                "replacement op";
+  }
+
+  // Clone and replace.
+  IRRewriter rewriter(getContext());
+  SmallVector<Operation *> replacements;
+  for (Operation *target : payload) {
+    // Skip ops nested in this transform op, e.g. the pattern op itself when
+    // the match also covers the transform IR.
+    if (getOperation()->isAncestor(target))
+      continue;
+    rewriter.setInsertionPoint(target);
+    Operation *replacement = rewriter.clone(*pattern);
+    rewriter.replaceOp(target, replacement->getResults());
+    replacements.push_back(replacement);
+  }
+  transformResults.set(getReplacement().cast<OpResult>(), replacements);
+  return DiagnosedSilenceableFailure(success());
+}
+
+/// Declares the side effects of this op: the `target` handle is consumed, the
+/// `replacement` handle is produced, and the payload IR is modified.
+void transform::ReplaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  Value targetHandle = getTarget();
+  Value replacementHandle = getReplacement();
+  consumesHandle(targetHandle, effects);
+  producesHandle(replacementHandle, effects);
+  modifiesPayload(effects);
+}
+
+/// Verifies that the body region holds exactly one block containing exactly
+/// one op, and that this op has no operands and, if it has regions, is
+/// isolated from above.
+LogicalResult transform::ReplaceOp::verify() {
+  if (!getBodyRegion().hasOneBlock())
+    return emitOpError() << "expected one block";
+  auto &body = getBodyRegion().front();
+  if (std::distance(body.begin(), body.end()) != 1)
+    return emitOpError() << "expected one operation in block";
+  Operation *replacement = &body.front();
+  if (replacement->getNumOperands() > 0)
+    return replacement->emitOpError()
+           << "expected replacement without operands";
+  if (!replacement->hasTrait<IsIsolatedFromAbove>() &&
+      replacement->getNumRegions() > 0)
+    return replacement->emitOpError()
+           << "expected op that is isolated from above";
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// ScalarizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
new file mode 100644
index 0000000000000..f8ed7712b68e0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s -allow-unregistered-dialect -verify-diagnostics --split-input-file | FileCheck %s
+
+// Every matched func.func outside the transform op is replaced by a clone of
+// @foo. The match presumably also covers the func.func inside the replace
+// region; targets nested in the replace op are skipped, so that one stays.
+// CHECK: func.func @foo() {
+// CHECK: "dummy_op"() : () -> ()
+// CHECK: }
+// CHECK-NOT: func.func @bar
+func.func @bar() {
+ "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ transform.structured.replace %0 {
+ func.func @foo() {
+ "dummy_op"() : () -> ()
+ }
+ }
+}
+
+// -----
+
+// A target with operands is rejected when the transform is applied.
+func.func @bar(%arg0: i1) {
+ "another_op"(%arg0) : (i1) -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["another_op"]} in %arg1
+ // expected-error @+1 {{expected target without operands}}
+ transform.structured.replace %0 {
+ "dummy_op"() : () -> ()
+ }
+}
+
+// -----
+
+// The op verifier rejects a replacement op that has operands.
+func.func @bar() {
+ "another_op"() : () -> ()
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["another_op"]} in %arg1
+ transform.structured.replace %0 {
+ ^bb0(%a: i1):
+ // expected-error @+1 {{expected replacement without operands}}
+ "dummy_op"(%a) : (i1) -> ()
+ }
+}
More information about the Mlir-commits
mailing list