[Mlir-commits] [mlir] 5c60a08 - [mlir][Linalg] Support tensor.parallel_insert_slice in transform.insert_slice_to_copy
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Apr 14 06:11:36 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-14T06:11:29-07:00
New Revision: 5c60a08c696c0420ddc5fdad5b8e50a7528cb3bf
URL: https://github.com/llvm/llvm-project/commit/5c60a08c696c0420ddc5fdad5b8e50a7528cb3bf
DIFF: https://github.com/llvm/llvm-project/commit/5c60a08c696c0420ddc5fdad5b8e50a7528cb3bf.diff
LOG: [mlir][Linalg] Support tensor.parallel_insert_slice in transform.insert_slice_to_copy
Differential Revision: https://reviews.llvm.org/D148333
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index af76660d988f2..9366ce7a241b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2030,7 +2030,7 @@ def InsertSliceToCopyOp :
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::tensor::InsertSliceOp target,
+ ::mlir::Operation *target, // tensor.insert_slice or tensor.parallel_insert_slice
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 39f7802a688a8..8e667022966d0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -38,6 +38,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include <type_traits>
using namespace mlir;
using namespace mlir::linalg;
@@ -3214,18 +3215,32 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
+/// Materializes the value inserted by `target` with a linalg.copy into a
+/// tensor.extract_slice of the destination, then recreates `target` (same
+/// OpTy, same offsets/sizes/strides) with the copy's result as its source.
+/// If the source is already produced by a linalg.copy, that op is pushed to
+/// `results` and the IR is left unchanged.
+template <typename OpTy>
+static DiagnosedSilenceableFailure
+doit(RewriterBase &rewriter, OpTy target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
+ tensor::ParallelInsertSliceOp>() &&
+ "wrong op type");
-DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
- tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
+ if (auto copySource =
+ target.getSource().template getDefiningOp<linalg::CopyOp>()) {
results.push_back(copySource);
return DiagnosedSilenceableFailure::success();
}
- TrackingListener listener(state, *this);
- IRRewriter rewriter(target->getContext(), &listener);
- rewriter.setInsertionPoint(target);
+ // A tensor.parallel_insert_slice must stay in the region of its immediate
+ // parallel-combining parent (e.g. scf.forall.in_parallel), where no other ops
+ // are allowed: emit the extract_slice and copy right before that parent op.
+ if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>)
+ rewriter.setInsertionPoint(target->getParentOp());
+
Value extracted = rewriter.create<tensor::ExtractSliceOp>(
target.getLoc(), target.getDest(), target.getMixedOffsets(),
target.getMixedSizes(), target.getMixedStrides());
@@ -3233,7 +3248,9 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
.create<linalg::CopyOp>(target.getLoc(),
target.getSource(), extracted)
.getResult(0);
- rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+ // Recreate the insertion exactly where `target` was.
+ rewriter.setInsertionPoint(target);
+ rewriter.replaceOpWithNewOp<OpTy>(
target, copied, target.getDest(), target.getMixedOffsets(),
target.getMixedSizes(), target.getMixedStrides());
@@ -3241,6 +3258,25 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
+ Operation *targetOp, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ // Bail out early on ops other than (parallel_)insert_slice.
+ if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(targetOp)) {
+ DiagnosedSilenceableFailure diag = emitSilenceableError()
+ << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+ TrackingListener listener(state, *this);
+ IRRewriter rewriter(targetOp->getContext(), &listener);
+ rewriter.setInsertionPoint(targetOp);
+ if (auto insertOp = dyn_cast<tensor::InsertSliceOp>(targetOp))
+ return doit(rewriter, insertOp, results, state);
+ auto parallelInsertOp = cast<tensor::ParallelInsertSliceOp>(targetOp);
+ return doit(rewriter, parallelInsertOp, results, state);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
index 7c2461c52b2a2..e6b2d2b0c4c3e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @insert_slice_to_copy
// CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
@@ -108,3 +108,32 @@ transform.sequence failures(propagate) {
transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
}
+// -----
+
+// CHECK-LABEL: func @parallel_insert_slice_to_copy
+func.func @parallel_insert_slice_to_copy(%out : tensor<?x?xf32>, %sz0: index, %sz1: index) {
+ %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %out) -> (tensor<?x?xf32>) {
+ %t = "make_me_a_tensor"() : () -> (tensor<?x?xf32>)
+
+ // The extract_slice and copy must be created before the in_parallel
+ // terminator, and the new parallel_insert_slice must insert the copy.
+ // CHECK: tensor.extract_slice
+ // CHECK: %[[COPY:.*]] = linalg.copy
+ // CHECK: scf.forall.in_parallel
+ // CHECK: tensor.parallel_insert_slice %[[COPY]] into
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %t into %arg2[0, 0] [%sz0, %sz1] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ }
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.insert_slice_to_copy %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}
More information about the Mlir-commits mailing list