[Mlir-commits] [mlir] [mlir] [memref] add more checks to the memref.reinterpret_cast (PR #112669)

donald chen llvmlistbot at llvm.org
Thu Oct 17 01:16:34 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/112669

>From 831847df0f5af9627a9e325b6027cbb4d0d8b3ac Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Thu, 17 Oct 2024 03:05:56 +0000
Subject: [PATCH 1/2] [mlir] [memref] add more checks to the
 memref.reinterpret_cast verifier

Operation memref.reinterpret_cast used to accept input like:

  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
         : memref<?xf32> to memref<10xf32>

A problem arises: during lowering, the actual offset of %out is %offset, but its
type indicates a static offset of 0. Permitting this inconsistency can produce
incorrect results, because some passes may read the offset from the type of
%out instead of from the operand.

This patch fixes this by requiring the result type's offset, sizes, and strides
to agree with the corresponding operands of the operation.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++-------
 mlir/test/Dialect/MemRef/invalid.mlir    |  9 +++++++++
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..add78e78a97a8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,8 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
              << " in dim = " << idx;
@@ -1910,17 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
 
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+  if (resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
-        resultStride != expectedStride)
+    if (resultStride != expectedStride)
       return emitError("expected result type with stride = ")
              << expectedStride << " instead of " << resultStride
              << " in dim = " << idx;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..739cf76429c045 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
 
 // -----
 
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = -9223372036854775808 instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
   %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]

>From fe54602f46208bab60fd3e2df6b825ffb683dfa3 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Thu, 17 Oct 2024 08:15:33 +0000
Subject: [PATCH 2/2] allow dynamic values in the result type when the operands are static

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index add78e78a97a8e..985aa98eb07ccd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,7 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (resultSize != expectedSize)
+    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
              << " in dim = " << idx;
@@ -1909,14 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
 
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (resultOffset != expectedOffset)
+  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (resultStride != expectedStride)
+    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
       return emitError("expected result type with stride = ")
              << expectedStride << " instead of " << resultStride
              << " in dim = " << idx;



More information about the Mlir-commits mailing list