[Mlir-commits] [mlir] 128caa1 - [mlir][bufferization] Refine tensor-buffer compatibility checks (#167705)
llvmlistbot at llvm.org
Tue Nov 18 02:18:57 PST 2025
Author: Andrei Golubev
Date: 2025-11-18T11:18:53+01:00
New Revision: 128caa1ba37fe7f216226d24e8d616bab2d68ee9
URL: https://github.com/llvm/llvm-project/commit/128caa1ba37fe7f216226d24e8d616bab2d68ee9
DIFF: https://github.com/llvm/llvm-project/commit/128caa1ba37fe7f216226d24e8d616bab2d68ee9.diff
LOG: [mlir][bufferization] Refine tensor-buffer compatibility checks (#167705)
Generally, to_tensor and to_buffer already perform sufficient
verification. However, two of the existing constraints are unnecessarily strict:
* a builtin tensor requires its buffer counterpart to always be a memref
* to_buffer on a ranked tensor is required to always return a memref
These checks are assertions (i.e. preconditions), yet they prevent an
apparently useful bufferization in which builtin tensors become custom
buffers. Lift these assertions, leaving the verification procedure
unchanged, to allow builtin -> custom bufferizations at the operation
boundary level.
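As an illustrative sketch of what this enables, here is the kind of
builtin -> custom boundary conversion that the lifted assertions previously
ruled out (reusing the !test.test_memref buffer-like type exercised by the
new ops.mlir tests below; a real use would substitute a downstream dialect's
own buffer type):

  func.func @builtin_tensor_to_custom_buffer(%t: tensor<8xf32>) -> tensor<8xf32> {
    // A builtin tensor paired with a custom (non-memref) buffer-like type at
    // the op boundary; shape/element-type compatibility is still verified
    // through verifyCompatibleBufferType.
    %b = bufferization.to_buffer %t : tensor<8xf32> to !test.test_memref<[8], f32>
    %r = bufferization.to_tensor %b : !test.test_memref<[8], f32> to tensor<8xf32>
    return %r : tensor<8xf32>
  }

The new ops.mlir tests below exercise this round-trip in both directions
(builtin tensor <-> custom buffer and custom tensor <-> builtin memref).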
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
mlir/test/Dialect/Bufferization/invalid.mlir
mlir/test/Dialect/Bufferization/ops.mlir
mlir/test/lib/Dialect/Test/TestTypes.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index e0cf353da207f..9b11270e7bbe2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -680,16 +680,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return false;
}
-// bufferization.to_buffer is not allowed to change the rank.
-static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
-#ifndef NDEBUG
- auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
- assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
- rankedTensorType.getRank()) &&
- "to_buffer would be invalid: mismatching ranks");
-#endif
-}
-
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
@@ -708,7 +698,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *bufferType);
+
return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
*bufferType, value)
.getResult();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index d6c3cd62ee742..bd177ba1afccd 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -54,9 +54,6 @@ struct BuiltinTensorExternalModel
mlir::LogicalResult verifyCompatibleBufferType(
mlir::Type tensor, BufferLikeType bufferType,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
- assert(isa<TensorType>(tensor) && "expected tensor type");
- assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
-
auto tensorType = cast<ShapedType>(tensor);
auto memrefType = cast<ShapedType>(bufferType);
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 2c8807b66de74..9884b040119d0 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
arith.constant {bufferization.manual_deallocation} 0 : index
}
+
+// -----
+
+func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
+ // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %b = bufferization.to_buffer %t
+ : tensor<1x2x3x4xf32> to memref<1x2x3xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
+ // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %t = bufferization.to_tensor %b
+ : memref<1x2x3xf32> to tensor<1x2x3x4xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
+ // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %b = bufferization.to_buffer %t
+ : tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
+ // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{shapes do not match}}
+ %t = bufferization.to_tensor %b
+ : memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
+ // expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{element types do not match}}
+ %b = bufferization.to_buffer %t
+ : tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
+ return
+}
+
+// -----
+
+func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
+ // expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
+ // expected-error @below{{element types do not match}}
+ %t2 = bufferization.to_tensor %b
+ : memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index fc6df4a09f706..b0db1bb2d0389 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
bufferization.dealloc
return %0#0, %0#1 : i1, i1
}
+
+// CHECK: func.func @test_builtin_custom_builtin_type_conversion
+// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
+func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
+ -> tensor<42xf32> {
+ // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
+ // CHECK-SAME: to !test.test_memref<[42], f32>
+ %buffer = bufferization.to_buffer %t
+ : tensor<42xf32> to !test.test_memref<[42], f32>
+
+ // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
+ // CHECK-SAME: to tensor<42xf32>
+ %tensor = bufferization.to_tensor %buffer
+ : !test.test_memref<[42], f32> to tensor<42xf32>
+
+ // CHECK: return %[[tensor]]
+ return %tensor : tensor<42xf32>
+}
+
+// CHECK: func.func @test_custom_builtin_custom_type_conversion
+// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
+// CHECK-SAME: -> !test.test_tensor<[42], f32>
+func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
+ -> !test.test_tensor<[42], f32> {
+ // CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
+ // CHECK-SAME: to memref<42xf32>
+ %buffer = bufferization.to_buffer %t
+ : !test.test_tensor<[42], f32> to memref<42xf32>
+
+ // CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
+ // CHECK-SAME: to !test.test_tensor<[42], f32>
+ %tensor = bufferization.to_tensor %buffer
+ : memref<42xf32> to !test.test_tensor<[42], f32>
+
+ // CHECK: return %[[tensor]]
+ return %tensor : !test.test_tensor<[42], f32>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 614121f1d43dd..9cf64a896d28a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -569,11 +569,17 @@ TestTensorType::getBufferType(
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
- auto testMemref = dyn_cast<TestMemrefType>(bufferType);
- if (!testMemref)
- return emitError() << "expected TestMemrefType";
+ if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) {
+ const bool valid = getShape() == testMemref.getShape() &&
+ getElementType() == testMemref.getElementType();
+ return mlir::success(valid);
+ }
+
+ if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) {
+ const bool valid = getShape() == builtinMemref.getShape() &&
+ getElementType() == builtinMemref.getElementType();
+ return mlir::success(valid);
+ }
- const bool valid = getShape() == testMemref.getShape() &&
- getElementType() == testMemref.getElementType();
- return mlir::success(valid);
+ return emitError() << "expected MemRefType or TestMemrefType";
}