[Mlir-commits] [mlir] [mlir][memref] extract_strided_metadata for zero-sized memref (PR #74835)

Guray Ozen llvmlistbot at llvm.org
Fri Dec 8 05:34:34 PST 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/74835

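Relax the size precondition of computeSuffixProduct in IndexingUtils from strictly positive to nonnegative, so that the utilities (as exercised by expand-strided-metadata on memref.extract_strided_metadata) support zero-sized memrefs such as memref<0xf16, 3>.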
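To see why a size of 0 is harmless for stride computation, here is a minimal standalone sketch of a suffix product over sizes. The helper name suffixProduct and the use of std::vector are illustrative assumptions, not the MLIR API; it only mirrors the idea behind computeSuffixProduct with the relaxed `s >= 0` check from this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not the MLIR API): computes row-major
// strides as the suffix product of `sizes`, i.e. strides[i] is the product
// of sizes[i+1..n-1], with the innermost stride fixed at 1.
std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  // Mirrors the relaxed precondition: zero is allowed, negatives are not.
  for (int64_t s : sizes)
    assert(s >= 0 && "sizes must be nonnegative");
  // Start with every stride at 1; an empty input yields an empty result.
  std::vector<int64_t> strides(sizes.size(), 1);
  // Walk from the second-innermost dimension outward.
  for (int64_t r = static_cast<int64_t>(sizes.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}
```

For example, suffixProduct({0}) yields {1}, matching the single dimension of memref<0xf16, 3>, and suffixProduct({2, 0, 3}) yields {0, 3, 1}: a zero-sized dimension zeroes the strides of the dimensions outside it, while the innermost stride stays 1. Nothing in the computation divides by a size, so only negative sizes need to be rejected.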

From d3390efbfd9aa4567bfb94fdbc78eb955c7acd9b Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 8 Dec 2023 14:34:08 +0100
Subject: [PATCH] [mlir][memref] extract_strided_metadata for zero-sized memref

Fix to support zero-sized memrefs in utils
---
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      |  2 +-
 .../MemRef/expand-strided-metadata.mlir       | 20 +++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index f4e29539214b4..bb8a0d5912d7c 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -70,7 +70,7 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s > 0; }) &&
+  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
          "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index ab0c78a8ba766..28b7004300594 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1494,3 +1494,23 @@ func.func @extract_strided_metadata_of_cast_unranked(
       index, index,
       index, index
 }
+
+
+// -----
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+
+// CHECK-LABEL: func @zero_sized_memref
+func.func @zero_sized_memref(%arg0: f32) -> (memref<f16, 3>, index, index, index) {
+  %c0 = arith.constant 0 : index
+  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+
+  // CHECK: %[[BASE:.*]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  // CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [], strides: [] : memref<0xf16, 3> to memref<f16, 3>
+  // CHECK: return %[[CAST]]
+
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %dynamicMem : memref<0xf16, 3> -> memref<f16, 3>, index, index, index
+  return %base_buffer, %offset,
+    %sizes, %strides :
+      memref<f16,3>, index,
+      index, index
+}
\ No newline at end of file