[Mlir-commits] [mlir] ad2f9f6 - [mlir] Fix subtensor_insert bufferization.

Sean Silva llvmlistbot at llvm.org
Thu Nov 12 14:57:05 PST 2020


Author: Sean Silva
Date: 2020-11-12T14:56:09-08:00
New Revision: ad2f9f67451cbb5e3af9760222f802da82f8024e

URL: https://github.com/llvm/llvm-project/commit/ad2f9f67451cbb5e3af9760222f802da82f8024e
DIFF: https://github.com/llvm/llvm-project/commit/ad2f9f67451cbb5e3af9760222f802da82f8024e.diff

LOG: [mlir] Fix subtensor_insert bufferization.

The bufferization was incorrect in the presence of a tensor operand with
multiple uses.

The bufferization of subtensor_insert was writing into the converted memref
for its destination operand, but there is no guarantee that that memref is
safe to write into. When the original destination tensor has other uses, they
all observe the same converted memref, so writing into it in-place clobbers
the value seen by those uses, violating the tensor-level semantics.

For now, the fix is to conservatively clone the converted destination memref
and write into the clone, since the original memref could be aliased or could
point into constant memory. I left some comments in a TODO about ways forward
on this; I will be actively working on this problem in the coming days.
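At the memref level, the emitted code roughly follows this pattern (a 1-D
sketch with illustrative names; see the updated CHECK lines in
bufferize.mlir for the exact 2-D IR):

  %c0 = constant 0 : index
  %dim0 = dim %dest, %c0 : memref<?xf32>
  %copy = alloc(%dim0) : memref<?xf32>
  linalg.copy(%dest, %copy) : memref<?xf32>, memref<?xf32>
  // The subview and the copy of the inserted value then target %copy, and
  // the result tensor is loaded from %copy instead of from %dest.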

Differential Revision: https://reviews.llvm.org/D91371

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
new file mode 100644
index 000000000000..fbb026c12138
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
+// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %const = constant dense<10.0> : tensor<2xf32>
+  %insert_val = constant dense<20.0> : tensor<1xf32>
+
+  // Both of these subtensor_insert ops insert into the same original tensor
+  // value `%const`. This can easily cause bugs if at the memref level
+  // we attempt to write in-place into the memref that %const has been
+  // converted into.
+  %inserted_at_position_0 = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
+  %inserted_at_position_1 = subtensor_insert %insert_val into %const[1][1][1] : tensor<1xf32> into tensor<2xf32>
+
+  %unranked_at_position_0 = tensor_cast %inserted_at_position_0 : tensor<2xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked_at_position_0) : (tensor<*xf32>) -> ()
+
+  //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
+  // CHECK-NEXT: [20, 10]
+
+  %unranked_at_position_1 = tensor_cast %inserted_at_position_1 : tensor<2xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked_at_position_1) : (tensor<*xf32>) -> ()
+
+  //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
+  // CHECK-NEXT: [10, 20]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 3672b80730b8..b78a26281c66 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -40,6 +40,19 @@ static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
   return b.create<IndexCastOp>(loc, val, b.getIndexType());
 }
 
+static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
+  auto memrefType = memref.getType().cast<MemRefType>();
+  SmallVector<Value, 4> dynOperands;
+  for (auto dim : llvm::enumerate(memrefType.getShape())) {
+    if (dim.value() == TensorType::kDynamicSize) {
+      dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
+    }
+  }
+  auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
+  b.create<linalg::CopyOp>(loc, memref, alloc);
+  return alloc;
+}
+
 static LogicalResult
 allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                           linalg::GenericOpAdaptor &adaptor,
@@ -65,19 +78,10 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
     // results.
     // TODO: update this assumption because the reality is more complex
     // under linalg on tensor based transformations.
-    bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
-    if (foldedInitTensor) {
-      Value initTensor = linalgOp.getInitTensor(resultIndex);
-      Value initBuffer = adaptor.init_tensors()[resultIndex];
-      SmallVector<Value, 4> dynOperands;
-      for (auto dim : llvm::enumerate(tensorShape)) {
-        if (dim.value() == TensorType::kDynamicSize) {
-          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
-        }
-      }
-      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
-      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
-      resultBuffers.push_back(alloc);
+    bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
+    if (hasInitTensor) {
+      resultBuffers.push_back(
+          cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
       continue;
     }
 
@@ -303,7 +307,10 @@ class SubTensorInsertOpConverter
     Value sourceMemRef = adaptor.source();
     assert(sourceMemRef.getType().isa<MemRefType>());
 
-    Value destMemRef = adaptor.dest();
+    // For now, be conservative and copy the converted input memref.
+    // In general, the converted input memref here could be aliased or could
+    // point into constant memory, so mutating it would lead to miscompilations.
+    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
     assert(destMemRef.getType().isa<MemRefType>());
 
     // Take a subview to copy the small memref.

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index c85951577674..ef79f911d3e8 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -199,20 +199,33 @@ func @bufferize_subtensor_insert(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %
   //      CHECK: %[[IDX:.*]] = call @make_index() : () -> index
   %i0 = call @make_index() : () -> index
 
+
   //  CHECK-DAG: %[[M0:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
   //  CHECK-DAG: %[[SM0:.*]] = tensor_to_memref %[[ST0]] : memref<2x3xf32>
-  // CHECK-NEXT: %[[SUBVIEW0:.*]] = subview %[[M0]][0, 0] [2, 3] [1, 1]
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[DIM0:.*]] = dim %[[M0]], %[[C0]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[DIM1:.*]] = dim %[[M0]], %[[C1]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[M0_COPY:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+  // CHECK-NEXT: linalg.copy(%[[M0]], %[[M0_COPY]]) : memref<?x?xf32>, memref<?x?xf32>
+  // CHECK-NEXT: %[[SUBVIEW0:.*]] = subview %[[M0_COPY]][0, 0] [2, 3] [1, 1]
   // CHECK-SAME:   memref<?x?xf32> to memref<2x3xf32, #[[$MAP0]]>
   // CHECK-NEXT: linalg.copy(%[[SM0]], %[[SUBVIEW0]]) : memref<2x3xf32>, memref<2x3xf32, #[[$MAP0]]>
-  // CHECK-NEXT: %[[RT0:.*]] = tensor_load %[[M0]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[RT0:.*]] = tensor_load %[[M0_COPY]] : memref<?x?xf32>
   %t0 = subtensor_insert %st0 into %t[0, 0][2, 3][1, 1] : tensor<2x3xf32> into tensor<?x?xf32>
 
   //  CHECK-DAG: %[[M1:.*]] = tensor_to_memref %[[T]] : memref<?x?xf32>
   //  CHECK-DAG: %[[SM1:.*]] = tensor_to_memref %[[ST1]] : memref<2x?xf32>
-  // CHECK-NEXT: %[[SUBVIEW1:.*]] = subview %[[M1]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2]
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[DIM0:.*]] = dim %[[M1]], %[[C0]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[DIM1:.*]] = dim %[[M1]], %[[C1]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[M1_COPY:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+  // CHECK-NEXT: linalg.copy(%[[M1]], %[[M1_COPY]]) : memref<?x?xf32>, memref<?x?xf32>
+  // CHECK-NEXT: %[[SUBVIEW1:.*]] = subview %[[M1_COPY]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2]
   // CHECK-SAME:   memref<?x?xf32> to memref<2x?xf32, #[[$MAP1]]>
   // CHECK-NEXT: linalg.copy(%[[SM1]], %[[SUBVIEW1]]) : memref<2x?xf32>, memref<2x?xf32, #[[$MAP1]]>
-  // CHECK-NEXT: %[[RT1:.*]] = tensor_load %[[M1]] : memref<?x?xf32>
+  // CHECK-NEXT: %[[RT1:.*]] = tensor_load %[[M1_COPY]] : memref<?x?xf32>
   %t1 = subtensor_insert %st1 into %t[0, %i0][2, %i0][1, 2] : tensor<2x?xf32> into tensor<?x?xf32>
 
   //     CHECK: return %[[RT0]], %[[RT1]]
