[Mlir-commits] [mlir] [mlir] [memref] add more checks to the memref.reinterpret_cast (PR #112669)
llvmlistbot at llvm.org
Thu Oct 17 00:19:55 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: donald chen (cxy-1993)
<details>
<summary>Changes</summary>
Previously, the verifier for memref.reinterpret_cast accepted input such as:
%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>
A problem arises during lowering: the actual offset of %out is %offset, but its type indicates an offset of 0. Permitting this inconsistency can produce incorrect results, because some passes may read the offset from the type of %out.
This patch fixes the problem by requiring the sizes, offset, and strides in the result type to match the op's static_sizes, static_offsets, and static_strides attributes exactly, including whether each value is static or dynamic.
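
The effect of the stricter comparison can be sketched in standalone C++. This is a simplified model, not the MLIR source: `kDynamic`, `isDynamic`, `oldCheckRejects`, and `newCheckRejects` are hypothetical stand-ins for `ShapedType::kDynamic`, `ShapedType::isDynamic`, and the two versions of the condition in `ReinterpretCastOp::verify()`. The sketch assumes the dynamic sentinel is the minimum `int64_t` value, which matches the `-9223372036854775808` printed in the new test's expected error.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Stand-in for ShapedType::kDynamic (assumed to be INT64_MIN, consistent
// with the value shown in the expected-error message of the new test).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

constexpr bool isDynamic(int64_t value) { return value == kDynamic; }

// Old condition: a mismatch is reported only when BOTH values are static,
// so a dynamic value on either side is silently accepted.
constexpr bool oldCheckRejects(int64_t resultValue, int64_t expectedValue) {
  return !isDynamic(resultValue) && !isDynamic(expectedValue) &&
         resultValue != expectedValue;
}

// New condition: the values must be identical, so static vs. dynamic
// is itself a mismatch.
constexpr bool newCheckRejects(int64_t resultValue, int64_t expectedValue) {
  return resultValue != expectedValue;
}

// Case from the description: the result type memref<10xf32> implies offset 0,
// while the op is given a dynamic offset (%offset).
static_assert(!oldCheckRejects(/*resultValue=*/0, /*expectedValue=*/kDynamic),
              "old check lets the static/dynamic mismatch through");
static_assert(newCheckRejects(/*resultValue=*/0, /*expectedValue=*/kDynamic),
              "new check rejects the static/dynamic mismatch");

// Both checks still accept matching values and reject static mismatches.
static_assert(!newCheckRejects(kDynamic, kDynamic), "dynamic matches dynamic");
static_assert(oldCheckRejects(1, 10) && newCheckRejects(1, 10),
              "static mismatch is rejected by both");
```

Under the stricter check, the example above would presumably need a result type that carries a dynamic offset, for instance `memref<10xf32, strided<[1], offset: ?>>`, so that the type and the operand agree.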
---
Full diff: https://github.com/llvm/llvm-project/pull/112669.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+3-7)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+9)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..add78e78a97a8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,8 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
- if (!ShapedType::isDynamic(resultSize) &&
- !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+ if (resultSize != expectedSize)
return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << idx;
@@ -1910,17 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
- if (!ShapedType::isDynamic(resultOffset) &&
- !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+ if (resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< expectedOffset << " instead of " << resultOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto [idx, resultStride, expectedStride] :
llvm::enumerate(resultStrides, getStaticStrides())) {
- if (!ShapedType::isDynamic(resultStride) &&
- !ShapedType::isDynamic(expectedStride) &&
- resultStride != expectedStride)
+ if (resultStride != expectedStride)
return emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << idx;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..739cf76429c045 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
// -----
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+ // expected-error @+1 {{expected result type with offset = -9223372036854775808 instead of 0}}
+ %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
``````````
</details>
https://github.com/llvm/llvm-project/pull/112669