[Mlir-commits] [mlir] 6b1f653 - [mlir][linalg][bufferize] tensor.cast may require a copy

Matthias Springer llvmlistbot at llvm.org
Thu Oct 7 06:30:09 PDT 2021


Author: Matthias Springer
Date: 2021-10-07T22:24:05+09:00
New Revision: 6b1f653c94c0d5de8bb954286bf144f129fdb7ff

URL: https://github.com/llvm/llvm-project/commit/6b1f653c94c0d5de8bb954286bf144f129fdb7ff
DIFF: https://github.com/llvm/llvm-project/commit/6b1f653c94c0d5de8bb954286bf144f129fdb7ff.diff

LOG: [mlir][linalg][bufferize] tensor.cast may require a copy

Differential Revision: https://reviews.llvm.org/D110806

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 0e3ce47eca3e0..6f5ddd35032de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1615,7 +1615,27 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Type sourceType = lookup(bvm, castOp.source()).getType();
+  // If castOp is not inPlace, allocate a new buffer.
+  auto inPlace = getInPlace(castOp->getResult(0));
+  Value newBuffer;
+  if (inPlace != InPlaceSpec::True) {
+    Location loc = castOp.getLoc();
+    // Alloc a copy for `writeOp.source()`, it will become the result buffer.
+    newBuffer = createNewAllocDeallocPairForShapedValue(b, loc, castOp.source(),
+                                                        aliasInfo);
+    if (!isInitTensorOp(castOp.source())) {
+      // Set insertion point now that potential alloc/dealloc are introduced.
+      b.setInsertionPoint(castOp);
+      b.create<CopyOp>(loc, lookup(bvm, castOp.source()), newBuffer);
+    }
+  } else {
+    // InPlace write will result in memref.tensor_load(x) which must
+    // canonicalize away with one of it uses.
+    newBuffer = lookup(bvm, castOp.source());
+    assert(newBuffer && "missing buffer");
+  }
+
+  Type sourceType = newBuffer.getType();
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
   assert(rankedMemRefType || unrankedMemRefType);
@@ -1629,8 +1649,7 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
           : ArrayRef<AffineMap>{};
   Type memRefType = getContiguousOrUnrankedMemRefType(
       castOp.getResult().getType(), affineMaps, memorySpace);
-  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
-                                       lookup(bvm, castOp.source()));
+  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType, newBuffer);
   aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
   map(bvm, castOp.getResult(), res);
   return success();

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index ac9735dedec3c..66e54e5971cca 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -737,3 +737,21 @@ func @matmul(
   }
   return %0 : tensor<128x192xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+//       CHECK:   %[[alloc:.*]] = memref.alloc
+//       CHECK:   linalg.copy(%[[A]], %[[alloc]])
+//       CHECK:   %[[cast:.*]] = memref.cast %[[alloc]]
+func @tensor_cast_not_in_place(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32>, %idx: index)
+  -> (tensor<?xf32>)
+{
+  %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
+  %r1 = tensor.insert_slice %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1 : tensor<?xf32>
+}
+


        


More information about the Mlir-commits mailing list