[Mlir-commits] [mlir] 13593dc - [mlir][tensor][bufferize] Fix tensor.insert_slice regression

Matthias Springer llvmlistbot at llvm.org
Sat Nov 26 10:20:52 PST 2022


Author: Matthias Springer
Date: 2022-11-26T19:14:33+01:00
New Revision: 13593dc9dc5a8f587402bf1e5f180f2c0fc750ee

URL: https://github.com/llvm/llvm-project/commit/13593dc9dc5a8f587402bf1e5f180f2c0fc750ee
DIFF: https://github.com/llvm/llvm-project/commit/13593dc9dc5a8f587402bf1e5f180f2c0fc750ee.diff

LOG: [mlir][tensor][bufferize] Fix tensor.insert_slice regression

This reverts D132662 (apart from overall cleanups), which introduced an overly aggressive optimization for tensor.insert_slice bufferization. Instead, bufferizesToMemoryRead is improved to handle some of these cases. The remaining cases can still bufferize efficiently if the canonicalizer is run before bufferization.

Differential Revision: https://reviews.llvm.org/D138745

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6ce05c257d94f..528d83f76e050 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -615,8 +615,8 @@ static bool areEquivalentSlices(const AnalysisState &state,
   return true;
 }
 
-/// Return true if `value` is originating from the InsertSliceOp's destination
-/// or an ExtractSliceOp that matches the given InsertSliceOp.
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
 template <typename OpTy>
 static bool matchesInsertDestination(const AnalysisState &state, Value value,
                                      OpTy insertSliceOp) {
@@ -630,15 +630,6 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
   if (llvm::all_of(state.findValueInReverseUseDefChain(value, matchesSlice),
                    matchesSlice))
     return true;
-
-  // Look for equivalent values.
-  auto isEquivalent = [&](Value val) {
-    return state.areEquivalentBufferizedValues(val, insertSliceOp.getDest());
-  };
-  if (llvm::all_of(state.findValueInReverseUseDefChain(
-                       value, isEquivalent, /*followEquivalentOnly=*/true),
-                   isEquivalent))
-    return true;
   return false;
 }
 
@@ -727,6 +718,36 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
 struct InsertSliceOpInterface
     : public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
                                                      tensor::InsertSliceOp> {
+  /// Return true if `opOperand` is read when this op bufferizes. The source
+  /// is always read. The destination is read unless the inserted slice
+  /// provably covers the entire destination: all offsets are 0, all strides
+  /// are 1 and all sizes equal the static destination sizes.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+
+    // The source (operand #0) is always read.
+    if (&opOperand == &insertSliceOp->getOpOperand(0) /*src*/)
+      return true;
+
+    // Otherwise this must be the destination (operand #1).
+    assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+
+    // Dest is not read if it is entirely overwritten. E.g.:
+    // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+    // Non-constant sizes and dynamic destination dimensions do not match
+    // below, so such cases are conservatively treated as reads.
+    RankedTensorType destType = insertSliceOp.getDestType();
+    bool allOffsetsZero =
+        llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
+          return isConstantIntValue(ofr, 0);
+        });
+    // Bind the enumerate element by const reference so this compiles whether
+    // llvm::enumerate yields references or temporaries.
+    bool sizesMatchDestSizes = llvm::all_of(
+        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
+          return getConstantIntValue(it.value()) ==
+                 destType.getDimSize(it.index());
+        });
+    bool allStridesOne =
+        llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return isConstantIntValue(ofr, 1);
+        });
+    return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+  }
+
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index a1dffec50d763..405f112782a74 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -126,12 +126,18 @@ func.func @insert_slice_fun_not_inplace(
 
 // -----
 
-// CHECK-LABEL: func @tensor_cast_in_place(
-//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>
+// This test case could bufferize in-place with a better analysis. However, it
+// is simpler to let the canonicalizer fold away the tensor.insert_slice.
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+//       CHECK:   %[[alloc:.*]] = memref.alloc
+//       CHECK:   memref.copy %[[A]], %[[alloc]]
 //       CHECK:   %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
-//       CHECK:   memref.copy %[[A]], %[[subview]]
-func.func @tensor_cast_in_place(
-    %A : tensor<?xf32> {bufferization.writable = true}, %idx: index)
+//       CHECK:   memref.copy %[[alloc]], %[[subview]]
+func.func @tensor_cast_not_in_place(
+    %A : tensor<?xf32> {bufferization.writable = true},
+    %B : tensor<?xf32> {bufferization.writable = false}, %idx: index)
   -> (tensor<?xf32>)
 {
   %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -241,13 +247,16 @@ func.func @dealloc_pad_buffer(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
 
 // -----
 
+// This test case could bufferize in-place with a better analysis. However, it
+// is simpler to let the canonicalizer fold away the tensor.insert_slice.
+
 // CHECK-LABEL: func @insert_equivalent_tensor
 func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
+  // Expected output: a new buffer allocation and a memref.copy, i.e. this
+  // case is no longer bufferized fully in place.
-  // CHECK-NOT: memref.alloc
+  // CHECK: memref.alloc
   %cst = arith.constant 4.200000e+01 : f32
   // CHECK: linalg.fill
   %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
-  // CHECK-NOT: memref.copy
+  // CHECK: memref.copy
   %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
   return %1 : tensor<10xf32>
 }
@@ -279,3 +288,45 @@ func.func @pad_memory_space(%t: tensor<?xf32>, %h1: index, %f: f32, %pos: index)
   // CHECK-DAG: memref.dealloc %[[padded_alloc]]
   return %2 : f32
 }
+
+// -----
+
+// Elements 5-9 of the result %2 keep the original contents of %t, and the
+// fill result %1 is never written back into %t. Therefore the linalg.fill must
+// not bufferize in place on %t: it gets a new allocation, and only %b is
+// copied into %t.
+
+// CHECK-LABEL: func @insert_slice_regression(
+//  CHECK-SAME:   %[[t:.*]]: memref<10xf32,{{.*}}>, %[[b:.*]]: memref<5xf32
+func.func @insert_slice_regression(%t: tensor<10xf32>, %b: tensor<5xf32>) -> tensor<10xf32> {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32>
+  // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<10xf32>)
+  %1 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+
+  // Read %1 so that it is not removed by DCE.
+  %vec = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<10xf32>
+  vector.print %vec : vector<10xf32>
+
+  // Write back a different value (not %1).
+  // CHECK: %[[subview:.*]] = memref.subview %[[t]][0] [5] [1]
+  // CHECK: memref.copy %[[b]], %[[subview]]
+  %2 = tensor.insert_slice %b into %t[0][5][1] : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
+
+// -----
+
+// The insert_slice overwrites all of %t ([0][10][1] into tensor<10xf32>), so
+// it does not read %t, and the linalg.fill can bufferize in place on %t.
+
+// CHECK-LABEL: func @insert_slice_full_overwrite(
+//  CHECK-SAME:   %[[t:.*]]: memref<10xf32,{{.*}}>, %[[b:.*]]: memref<10xf32,{{.*}}>
+func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -> tensor<10xf32> {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: linalg.fill {{.*}} outs(%[[t]] : memref<10xf32,{{.*}}>)
+  %1 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+
+  // Read %1 so that it is not removed by DCE.
+  %vec = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<10xf32>
+  vector.print %vec : vector<10xf32>
+
+  // Write back a different value (not %1).
+  // CHECK: memref.copy %[[b]], %[[t]]
+  %2 = tensor.insert_slice %b into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}


        


More information about the Mlir-commits mailing list