[Mlir-commits] [mlir] b6ae3f8 - [mlir][tensor][bufferize] Implement getBufferType for CastOp
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 1 05:24:27 PST 2023
Author: Matthias Springer
Date: 2023-02-01T14:24:10+01:00
New Revision: b6ae3f88731c8a82668bd5a992a5ae9b41e716a9
URL: https://github.com/llvm/llvm-project/commit/b6ae3f88731c8a82668bd5a992a5ae9b41e716a9
DIFF: https://github.com/llvm/llvm-project/commit/b6ae3f88731c8a82668bd5a992a5ae9b41e716a9.diff
LOG: [mlir][tensor][bufferize] Implement getBufferType for CastOp
This interface method is used to compute the buffer type of a value during bufferization; it was missing for tensor::CastOp. The method is needed during loop bufferization.
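For context, loop bufferization queries the buffer type of tensors yielded from a loop body; when such a tensor is produced by a tensor.cast, that query goes through this interface method. A condensed sketch of the pattern, abridged from the regression test added below (%t stands for the loop-carried tensor and %c0 for an index constant):

  %cast = tensor.cast %t : tensor<2xindex> to tensor<?xindex>
  %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex>
  %cast_0 = tensor.cast %inserted : tensor<?xindex> to tensor<2xindex>
  scf.yield %cast_0 : tensor<2xindex>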
Also fix a bug where a cast from an unranked tensor to a ranked tensor type did not always apply a fully dynamic layout map on the result memref.
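For illustration, the updated tensor.cast_from_unranked test below captures the fixed behavior: since no static offset or strides can be inferred from an unranked source, the ranked result memref now carries a fully dynamic strided layout (%m and %c are placeholder names; the test uses FileCheck captures):

  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  // bufferizes roughly to:
  %m = bufferization.to_memref %arg0 : memref<*xf32>
  %c = memref.cast %m : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>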
Differential Revision: https://reviews.llvm.org/D143063
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9c5b7d520ba1d..14401a98c8e2b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,6 +51,39 @@ struct CastOpInterface
    return BufferRelation::Equivalent;
  }
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto castOp = cast<tensor::CastOp>(op);
+    auto maybeSrcBufferType =
+        bufferization::getBufferType(castOp.getSource(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
+
+    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
+    // type in case the input is an unranked tensor type.
+
+    // Case 1: Casting an unranked tensor
+    if (castOp.getSource().getType().isa<UnrankedTensorType>()) {
+      // When casting to a ranked tensor, we cannot infer any static offset or
+      // strides from the source. Assume fully dynamic.
+      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+    }
+
+    // Case 2: Casting to an unranked tensor type
+    if (castOp.getType().isa<UnrankedTensorType>()) {
+      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+    }
+
+    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
+    // change.
+    auto rankedResultType = castOp.getType().cast<RankedTensorType>();
+    return MemRefType::get(
+        rankedResultType.getShape(), rankedResultType.getElementType(),
+        maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
+  }
+
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);
@@ -60,25 +93,19 @@ struct CastOpInterface
        getBuffer(rewriter, castOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();
-    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
-    TensorType resultTensorType =
-        castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout;
-    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
-      if (resultTensorType.isa<RankedTensorType>())
-        layout = rankedMemRefType.getLayout();
-
-    // Compute the new memref type.
-    Type resultMemRefType = getMemRefType(castOp.getResult(), options, layout,
-                                          sourceMemRefType.getMemorySpace());
+    // Compute the new type.
+    auto resultMemRefType =
+        bufferization::getBufferType(castOp.getResult(), options);
+    if (failed(resultMemRefType))
+      return failure();
    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
-                                             resultMemRefType) &&
+                                             *resultMemRefType) &&
           "CallOp::bufferize: cast incompatible");
-    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
-                                                 *resultBuffer);
+    replaceOpWithNewBufferizedOp<memref::CastOp>(
+        rewriter, op, *resultMemRefType, *resultBuffer);
    return success();
  }
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 90c88d86a11c1..587eed843c71f 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -925,3 +925,26 @@ func.func @non_block_argument_yield() {
  }
  return
}
+
+// -----
+
+// This is a regression test. Make sure that bufferization succeeds.
+
+// CHECK-LABEL: func @regression_cast_in_loop(
+func.func @regression_cast_in_loop() -> tensor<2xindex> {
+  %false = arith.constant false
+  %c0 = arith.constant 0 : index
+  %0 = bufferization.alloc_tensor() : tensor<2xindex>
+  // CHECK: scf.while (%{{.*}} = %{{.*}}) : (memref<2xindex>) -> memref<2xindex>
+  %1 = scf.while (%arg0 = %0) : (tensor<2xindex>) -> tensor<2xindex> {
+    scf.condition(%false) %arg0 : tensor<2xindex>
+  } do {
+  // CHECK: ^bb0(%{{.*}}: memref<2xindex>):
+  ^bb0(%arg0: tensor<2xindex>):
+    %cast = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>
+    %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex>
+    %cast_0 = tensor.cast %inserted : tensor<?xindex> to tensor<2xindex>
+    scf.yield %cast_0 : tensor<2xindex>
+  }
+  return %1 : tensor<2xindex>
+}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 4948b0dccf976..cbcc1e3d339b6 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -40,8 +40,8 @@ func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
// CHECK-LABEL: func @tensor.cast_from_unranked(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
-// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32, strided<[?], offset: ?>>
// CHECK: return %[[RET]] : tensor<2xf32>
func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 59fde562bd876..25164a4ba870c 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -347,3 +347,26 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
  %1 = tensor.dim %t, %c0 : tensor<?xf32>
  return %0, %1 : tensor<?xf32>, index
}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)>
+// CHECK-LABEL: func.func @cast_retains_buffer_layout(
+// CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
+// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]>
+// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>>
+// CHECK: return %[[slice]]
+func.func @cast_retains_buffer_layout(
+    %t: tensor<?xf32>
+        {bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>},
+    %sz: index)
+  -> (tensor<10xf32>, tensor<?xf32>)
+{
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>
+
+  // Note: The %casted return type is folded away because both buffers are
+  // equivalent. Therefore, we currently lose some static type information
+  // in the caller.
+  return %casted, %slice : tensor<10xf32>, tensor<?xf32>
+}