[Mlir-commits] [mlir] [mlir] [memref] add more checks to the memref.reinterpret_cast (PR #112669)
donald chen
llvmlistbot at llvm.org
Thu Oct 17 00:19:19 PDT 2024
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/112669
The memref.reinterpret_cast operation previously accepted input like:
%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>
A problem arises: during lowering, the actual offset of %out is %offset, but its type indicates an offset of 0. Permitting this inconsistency can produce incorrect results, because some passes may read the offset from the type of %out.
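This patch fixes this by requiring the sizes, offset, and strides in the result
type to match the static_sizes, static_offsets, and static_strides attributes
exactly, so a dynamic operand now requires a dynamic entry in the result type.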
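The effect of the verifier change can be illustrated with a small standalone sketch. It is not the MLIR code itself: kDynamic, isDynamic, oldCheckRejects, and newCheckRejects are hypothetical stand-ins, and kDynamic is assumed to be INT64_MIN, which is consistent with the -9223372036854775808 printed in the expected-error line of the new test.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical stand-in for ShapedType::kDynamic (assumed to be INT64_MIN).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for ShapedType::isDynamic.
constexpr bool isDynamic(int64_t v) { return v == kDynamic; }

// Old check: reports a mismatch only when both values are static and differ,
// so a static result entry paired with a dynamic operand slips through.
constexpr bool oldCheckRejects(int64_t result, int64_t expected) {
  return !isDynamic(result) && !isDynamic(expected) && result != expected;
}

// New check: any difference is a mismatch, including static vs. dynamic.
constexpr bool newCheckRejects(int64_t result, int64_t expected) {
  return result != expected;
}
```

For the example above, the result type carries offset 0 while the operand list carries a dynamic offset, so oldCheckRejects(0, kDynamic) is false and newCheckRejects(0, kDynamic) is true. A result type whose offset is also dynamic still passes the new check, since both sides then hold the same sentinel value.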
From 831847df0f5af9627a9e325b6027cbb4d0d8b3ac Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Thu, 17 Oct 2024 03:05:56 +0000
Subject: [PATCH] [mlir] [memref] add more checks to the
memref.reinterpret_cast
The memref.reinterpret_cast operation previously accepted input like:
%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
: memref<?xf32> to memref<10xf32>
A problem arises: during lowering, the actual offset of %out is %offset, but
its type indicates an offset of 0. Permitting this inconsistency can produce
incorrect results, because some passes may read the offset from the type of
%out.
This patch fixes this by requiring the sizes, offset, and strides in the result
type to match the static_sizes, static_offsets, and static_strides attributes
exactly, so a dynamic operand now requires a dynamic entry in the result type.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++-------
mlir/test/Dialect/MemRef/invalid.mlir | 9 +++++++++
2 files changed, 12 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..add78e78a97a8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,8 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
- if (!ShapedType::isDynamic(resultSize) &&
- !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+ if (resultSize != expectedSize)
return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << idx;
@@ -1910,17 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
- if (!ShapedType::isDynamic(resultOffset) &&
- !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+ if (resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< expectedOffset << " instead of " << resultOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto [idx, resultStride, expectedStride] :
llvm::enumerate(resultStrides, getStaticStrides())) {
- if (!ShapedType::isDynamic(resultStride) &&
- !ShapedType::isDynamic(expectedStride) &&
- resultStride != expectedStride)
+ if (resultStride != expectedStride)
return emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << idx;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..739cf76429c045 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
// -----
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+ // expected-error @+1 {{expected result type with offset = -9223372036854775808 instead of 0}}
+ %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]