[Mlir-commits] [mlir] [mlir] [memref] add more checks to the memref.reinterpret_cast (PR #112669)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 17 00:19:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

The memref.reinterpret_cast operation previously accepted input such as:

  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
         : memref<?xf32> to memref<10xf32>

This causes a problem: during lowering, the actual offset of %out is %offset, but its type indicates an offset of 0. Permitting this inconsistency can produce incorrect results, because some passes may read the offset from the type of %out.

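This patch fixes the problem by requiring the offset, sizes, and strides in the result type to match the corresponding operands exactly, so a dynamic operand now requires a dynamic entry in the result type.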
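
The effect on the verifier can be sketched in standalone C++. This is a minimal illustration, not the code in MemRefOps.cpp: `kDynamic`, `isDynamic`, `oldCheckRejects`, and `newCheckRejects` are hypothetical names used only here, and `kDynamic` stands in for the dynamic sentinel whose value (-9223372036854775808) appears in the new test's expected error.

```cpp
#include <cstdint>
#include <limits>

// Stand-in for the dynamic sentinel (assumption: modeled on the value
// -9223372036854775808 printed in the new test's expected error).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

constexpr bool isDynamic(int64_t value) { return value == kDynamic; }

// Old behavior: a mismatch was reported only when both the value in the
// result type and the value from the static attribute were static.
constexpr bool oldCheckRejects(int64_t resultValue, int64_t expectedValue) {
  return !isDynamic(resultValue) && !isDynamic(expectedValue) &&
         resultValue != expectedValue;
}

// New behavior: any difference is a mismatch, including static vs. dynamic.
constexpr bool newCheckRejects(int64_t resultValue, int64_t expectedValue) {
  return resultValue != expectedValue;
}
```

For the example above, the result type carries offset 0 while the operand offset is dynamic, so `oldCheckRejects(0, kDynamic)` is false and `newCheckRejects(0, kDynamic)` is true. Under the new check, the result type would presumably have to declare the offset as dynamic, for instance `memref<10xf32, strided<[1], offset: ?>>`.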

---
Full diff: https://github.com/llvm/llvm-project/pull/112669.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+3-7) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+9) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..add78e78a97a8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,8 +1892,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
              << " in dim = " << idx;
@@ -1910,17 +1909,14 @@ LogicalResult ReinterpretCastOp::verify() {
 
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+  if (resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
-        resultStride != expectedStride)
+    if (resultStride != expectedStride)
       return emitError("expected result type with stride = ")
              << expectedStride << " instead of " << resultStride
              << " in dim = " << idx;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..739cf76429c045 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
 
 // -----
 
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = -9223372036854775808 instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
   %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]

``````````

</details>


https://github.com/llvm/llvm-project/pull/112669

