[Mlir-commits] [mlir] [mlir] [linalg] Fix bufferize error in tensor.parallel_insert_slice op (PR #98312)

donald chen llvmlistbot at llvm.org
Thu Jul 11 02:59:01 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/98312

>From 871c7411cecb8e8efddf21b87e1c00f861dd54aa Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Wed, 10 Jul 2024 12:55:56 +0000
Subject: [PATCH] [mlir] [linalg] Fix bufferize error in
 tensor.parallel_insert_slice op

The tensor.parallel_insert_slice op has implicit in-place behavior. In the
"copy-before-write" bufferization mode, the resolveConflicts function generates
a bufferize.copy, which makes the result incorrect. This patch fixes the issue.
---
 .../BufferizableOpInterfaceImpl.cpp           | 16 +++++++++++----
 mlir/test/Dialect/Tensor/bufferize.mlir       | 20 +++++++++++++++++++
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..87464ccb71720 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
-        mixedSizes, mixedStrides);
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+        mixedOffsets, mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
-        mixedOffsets, mixedSizes, mixedStrides));
+        extractSliceOp.getType().getShape(),
+        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+        mixedStrides));
   }
 };
 
@@ -997,6 +998,13 @@ struct ParallelInsertSliceOpInterface
     rewriter.eraseOp(op);
     return success();
   }
+
+  /// The tensor.parallel_insert_slice op has implicit in-place behavior, so
+  /// no copy should be created to resolve conflicts.
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    return success();
+  }
 };
 
 /// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index e85d9e740adf4..3a3c8af15e6e4 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -626,3 +626,23 @@ func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf
   return %0 : tensor<?x3x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_copy_before_write
+func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: tensor<4xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 4 : index
+
+  // CHECK: scf.forall {{.*}} {
+  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> {
+      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32>
+      scf.forall.in_parallel {
+        // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
+        // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
+          tensor<1xf32> into tensor<4xf32>
+      }
+  }
+  // CHECK: }
+  return
+}



More information about the Mlir-commits mailing list