[Mlir-commits] [mlir] [mlir][bufferization] Let bufferization.tensor_layout be any layout attr (PR #138567)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon May 5 12:31:26 PDT 2025


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/138567

The `bufferization.buffer_layout` attribute is unnecessarily restricted to affine map attributes when it could reasonably be any implementor of `MemRefLayoutAttrInterface` (for example, a `strided<...>` layout).

>From a492cf9a002bbd3116d513930c3917f23f42e3d7 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 5 May 2025 19:25:17 +0000
Subject: [PATCH] [mlir][bufferization] Let bufferization.tensor_layout be any
 layout attr

The `bufferization.buffer_layout` attribute is unnecessarily restricted
to affine map attributes when it could reasonably be any implementor
of `MemRefLayoutAttrInterface`.
---
 .../Bufferization/IR/BufferizationDialect.cpp |  4 ++--
 .../FuncBufferizableOpInterfaceImpl.cpp       |  8 +++----
 .../Dialect/Tensor/one-shot-bufferize.mlir    | 22 +++++++++++++++++++
 3 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6b9253a5d71da..d8eac01c2dea0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -122,9 +122,9 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
     return success();
   }
   if (attr.getName() == kBufferLayoutAttrName) {
-    if (!llvm::isa<AffineMapAttr>(attr.getValue())) {
+    if (!llvm::isa<MemRefLayoutAttrInterface>(attr.getValue())) {
       return op->emitError() << "'" << kBufferLayoutAttrName
-                             << "' is expected to be a affine map attribute";
+                             << "' is expected to be a memref layout attribute";
     }
     if (!isa<FunctionOpInterface>(op))
       return op->emitError() << "expected '" << kBufferLayoutAttrName
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..0b0dcc9162a9a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -63,16 +63,16 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
   BaseMemRefType memrefType = options.functionArgTypeConverterFn(
       tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
 
-  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
       index, BufferizationDialect::kBufferLayoutAttrName);
   if (!layoutAttr)
     return memrefType;
 
   auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(
-      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
-      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+  return MemRefType::get(rankedMemrefType.getShape(),
+                         rankedMemrefType.getElementType(), layoutAttr,
+                         rankedMemrefType.getMemorySpace());
 }
 
 /// Return the FuncOp called by `callOp`.
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 2983cd30258a5..5f95da25cbc74 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -353,6 +353,28 @@ func.func @cast_retains_buffer_layout(
 
 // -----
 
+// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided(
+//  CHECK-SAME:     %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
+//       CHECK:   %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>>
+//       CHECK:   %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>>
+//       CHECK:   return %[[slice]]
+func.func @cast_retains_buffer_layout_strided(
+    %t: tensor<?xf32>
+        {bufferization.buffer_layout = strided<[1], offset: 5>},
+    %sz: index)
+  -> (tensor<10xf32>, tensor<?xf32>)
+{
+  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
+  %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>
+
+  // Note: The %casted return type is folded away because both buffers are
+  // equivalent. Therefore, we currently lose some static type information
+  // in the caller.
+  return %casted, %slice : tensor<10xf32>, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
 func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
   %c0 = arith.constant 0 : index



More information about the Mlir-commits mailing list