[Mlir-commits] [mlir] 6fc092f - [mlir][bufferization] Let bufferization.tensor_layout be any layout attr (#138567)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 6 09:21:56 PDT 2025
Author: Krzysztof Drewniak
Date: 2025-05-06T11:21:52-05:00
New Revision: 6fc092fc417e5c4e9fd78c3cc5892bacae405c72
URL: https://github.com/llvm/llvm-project/commit/6fc092fc417e5c4e9fd78c3cc5892bacae405c72
DIFF: https://github.com/llvm/llvm-project/commit/6fc092fc417e5c4e9fd78c3cc5892bacae405c72.diff
LOG: [mlir][bufferization] Let bufferization.tensor_layout be any layout attr (#138567)
The bufferization.buffer_layout function-argument attribute is unnecessarily
restricted to affine map attributes when it could reasonably be any
implementor of MemRefLayoutAttrInterface (e.g. a strided<...> layout).
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6b9253a5d71da..d8eac01c2dea0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -122,9 +122,9 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
return success();
}
if (attr.getName() == kBufferLayoutAttrName) {
- if (!llvm::isa<AffineMapAttr>(attr.getValue())) {
+ if (!llvm::isa<MemRefLayoutAttrInterface>(attr.getValue())) {
return op->emitError() << "'" << kBufferLayoutAttrName
- << "' is expected to be a affine map attribute";
+ << "' is expected to be a memref layout attribute";
}
if (!isa<FunctionOpInterface>(op))
return op->emitError() << "expected '" << kBufferLayoutAttrName
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..0b0dcc9162a9a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -63,16 +63,16 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
- auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+ auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
index, BufferizationDialect::kBufferLayoutAttrName);
if (!layoutAttr)
return memrefType;
auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
- return MemRefType::get(
- rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
- layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+ return MemRefType::get(rankedMemrefType.getShape(),
+ rankedMemrefType.getElementType(), layoutAttr,
+ rankedMemrefType.getMemorySpace());
}
/// Return the FuncOp called by `callOp`.
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 2983cd30258a5..5f95da25cbc74 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -353,6 +353,28 @@ func.func @cast_retains_buffer_layout(
// -----
+// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided(
+// CHECK-SAME: %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
+// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>>
+// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>>
+// CHECK: return %[[slice]]
+func.func @cast_retains_buffer_layout_strided(
+ %t: tensor<?xf32>
+ {bufferization.buffer_layout = strided<[1], offset: 5>},
+ %sz: index)
+ -> (tensor<10xf32>, tensor<?xf32>)
+{
+ %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+ %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>
+
+ // Note: The %casted result is folded away because its buffer is equivalent
+ // to the argument %t. Therefore, we currently lose some static type
+ // information (the static 10xf32 shape) in the caller.
+ return %casted, %slice : tensor<10xf32>, tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list