[Mlir-commits] [mlir] 5c60a08 - [mlir][Linalg] Support tensor.parallel_insert_slice in transform.insert_slice_to_copy

Nicolas Vasilache llvmlistbot at llvm.org
Fri Apr 14 06:11:36 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-14T06:11:29-07:00
New Revision: 5c60a08c696c0420ddc5fdad5b8e50a7528cb3bf

URL: https://github.com/llvm/llvm-project/commit/5c60a08c696c0420ddc5fdad5b8e50a7528cb3bf
DIFF: https://github.com/llvm/llvm-project/commit/5c60a08c696c0420ddc5fdad5b8e50a7528cb3bf.diff

LOG: [mlir][Linalg] Support tensor.parallel_insert_slice in transform.insert_slice_to_copy

Differential Revision: https://reviews.llvm.org/D148333

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index af76660d988f2..9366ce7a241b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2030,7 +2030,7 @@ def InsertSliceToCopyOp :
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::tensor::InsertSliceOp target,
+        ::mlir::Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 39f7802a688a8..8e667022966d0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -38,6 +38,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -3214,18 +3215,27 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
+template <typename OpTy>
+DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
+                                 transform::ApplyToEachResultList &results,
+                                 transform::TransformState &state) {
+  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
+                                tensor::ParallelInsertSliceOp>() &&
+                "wrong op type");
 
-DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
-    tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
+  if (auto copySource =
+          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
     results.push_back(copySource);
     return DiagnosedSilenceableFailure::success();
   }
 
-  TrackingListener listener(state, *this);
-  IRRewriter rewriter(target->getContext(), &listener);
-  rewriter.setInsertionPoint(target);
+  // If we are inside an InParallel region, temporarily set the insertion point
+  // outside: only tensor.parallel_insert_slice ops are allowed in there.
+  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
+    rewriter.setInsertionPoint(
+        target->template getParentOfType<scf::InParallelOp>());
+  }
+
   Value extracted = rewriter.create<tensor::ExtractSliceOp>(
       target.getLoc(), target.getDest(), target.getMixedOffsets(),
       target.getMixedSizes(), target.getMixedStrides());
@@ -3233,7 +3243,9 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
                      .create<linalg::CopyOp>(target.getLoc(),
                                              target.getSource(), extracted)
                      .getResult(0);
-  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+  // Reset the insertion point.
+  rewriter.setInsertionPoint(target);
+  rewriter.replaceOpWithNewOp<OpTy>(
       target, copied, target.getDest(), target.getMixedOffsets(),
       target.getMixedSizes(), target.getMixedStrides());
 
@@ -3241,6 +3253,25 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
+    Operation *targetOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // All IR changes go through a rewriter whose listener is a TrackingListener
+  // bound to `state`; dispatch on the concrete slice-insertion op type.
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(targetOp->getContext(), &listener);
+  rewriter.setInsertionPoint(targetOp);
+  return llvm::TypeSwitch<Operation *, DiagnosedSilenceableFailure>(targetOp)
+      .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+          [&](auto target) { return doit(rewriter, target, results, state); })
+      .Default([&](Operation *) {
+        DiagnosedSilenceableFailure diag = emitSilenceableError()
+            << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
+        diag.attachNote(targetOp->getLoc()) << "target op";
+        return diag;
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
index 7c2461c52b2a2..e6b2d2b0c4c3e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @insert_slice_to_copy
     // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
@@ -108,3 +108,30 @@ transform.sequence failures(propagate) {
   transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
 }
 
+// -----
+
+// CHECK-LABEL: func @parallel_insert_slice_to_copy
+func.func @parallel_insert_slice_to_copy(%out : tensor<?x?xf32>, %sz0: index, %sz1: index) {
+  %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %out) -> (tensor<?x?xf32>) {
+    %t = "make_me_a_tensor"() : () -> (tensor<?x?xf32>)
+    //      CHECK: %[[T:.*]] = "make_me_a_tensor"
+    //      CHECK: %[[E:.*]] = tensor.extract_slice
+    //      CHECK: %[[C:.*]] = linalg.copy ins(%[[T]] : {{.*}}) outs(%[[E]] : {{.*}})
+    //      CHECK: scf.forall.in_parallel
+    //      CHECK:   tensor.parallel_insert_slice %[[C]] into
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %t into %arg2[0, 0] [%sz0, %sz1] [1, 1]
+        : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.insert_slice_to_copy %0
+    : (!transform.any_op) -> !transform.any_op
+  transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}


        


More information about the Mlir-commits mailing list