[Mlir-commits] [mlir] 6b1f653 - [mlir][linalg][bufferize] tensor.cast may require a copy
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 7 06:30:09 PDT 2021
Author: Matthias Springer
Date: 2021-10-07T22:24:05+09:00
New Revision: 6b1f653c94c0d5de8bb954286bf144f129fdb7ff
URL: https://github.com/llvm/llvm-project/commit/6b1f653c94c0d5de8bb954286bf144f129fdb7ff
DIFF: https://github.com/llvm/llvm-project/commit/6b1f653c94c0d5de8bb954286bf144f129fdb7ff.diff
LOG: [mlir][linalg][bufferize] tensor.cast may require a copy
Differential Revision: https://reviews.llvm.org/D110806
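
When the result of the tensor.cast is not bufferized in place, the source buffer can no longer be reused directly: a new buffer is allocated and the source is copied into it before the memref.cast is emitted. The IR below is a rough, hand-written sketch of that lowering (the function name, the size computation and the omitted dealloc are illustrative assumptions, not the literal output of the pass):

  func @cast_out_of_place(%A : memref<?xf32>) -> memref<4xf32> {
    // Allocate a buffer of the same size as the source and copy the data,
    // so that the cast result no longer aliases %A.
    %c0 = constant 0 : index
    %d = memref.dim %A, %c0 : memref<?xf32>
    %alloc = memref.alloc(%d) : memref<?xf32>
    linalg.copy(%A, %alloc) : memref<?xf32>, memref<?xf32>
    // The cast itself becomes a memref.cast of the fresh buffer.
    %r = memref.cast %alloc : memref<?xf32> to memref<4xf32>
    return %r : memref<4xf32>
  }

The copy is skipped when the cast source is a linalg.init_tensor, since its contents are undefined and need not be preserved.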
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 0e3ce47eca3e0..6f5ddd35032de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1615,7 +1615,27 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Type sourceType = lookup(bvm, castOp.source()).getType();
+  // If castOp is not inPlace, allocate a new buffer.
+  auto inPlace = getInPlace(castOp->getResult(0));
+  Value newBuffer;
+  if (inPlace != InPlaceSpec::True) {
+    Location loc = castOp.getLoc();
+    // Allocate a copy of `castOp.source()`; it will become the result buffer.
+    newBuffer = createNewAllocDeallocPairForShapedValue(b, loc, castOp.source(),
+                                                        aliasInfo);
+    if (!isInitTensorOp(castOp.source())) {
+      // Set insertion point now that potential alloc/dealloc are introduced.
+      b.setInsertionPoint(castOp);
+      b.create<CopyOp>(loc, lookup(bvm, castOp.source()), newBuffer);
+    }
+  } else {
+    // InPlace write will result in memref.tensor_load(x) which must
+    // canonicalize away with one of its uses.
+    newBuffer = lookup(bvm, castOp.source());
+    assert(newBuffer && "missing buffer");
+  }
+
+  Type sourceType = newBuffer.getType();
   auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
   auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
   assert(rankedMemRefType || unrankedMemRefType);
@@ -1629,8 +1649,7 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
           : ArrayRef<AffineMap>{};
   Type memRefType = getContiguousOrUnrankedMemRefType(
       castOp.getResult().getType(), affineMaps, memorySpace);
-  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
-                                       lookup(bvm, castOp.source()));
+  Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType, newBuffer);
   aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
   map(bvm, castOp.getResult(), res);
   return success();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index ac9735dedec3c..66e54e5971cca 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -737,3 +737,21 @@ func @matmul(
}
return %0 : tensor<128x192xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
+// CHECK: %[[cast:.*]] = memref.cast %[[alloc]]
+func @tensor_cast_not_in_place(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32>, %idx: index)
+  -> (tensor<?xf32>)
+{
+  %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
+  %r1 = tensor.insert_slice %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+  return %r1 : tensor<?xf32>
+}
+
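
Conversely, when the analysis marks the cast result as in place (i.e. no later write conflicts with its source, unlike the tensor.insert_slice into %A in the test above), no allocation or copy is emitted and the op lowers to a plain memref.cast of the source buffer. A minimal sketch under the same assumptions as the example above:

  func @cast_in_place(%A : memref<?xf32>) -> memref<4xf32> {
    // In-place case: reuse the source buffer directly.
    %r = memref.cast %A : memref<?xf32> to memref<4xf32>
    return %r : memref<4xf32>
  }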