[Mlir-commits] [mlir] [mlir][memref]: Fix Bug in GlobalOp Verifier (PR #144900)
Jack Frankland
llvmlistbot at llvm.org
Mon Jun 23 04:06:27 PDT 2025
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/144900
>From d8ed7e06731e43ac36725c62bd963301d8c94cba Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Thu, 19 Jun 2025 14:58:48 +0100
Subject: [PATCH] [mlir][memref]: Fix Bug in GlobalOp Verifier
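When reconstructing the tensor type that corresponds to a memref, ensure we
include the memref's memory space (carried over as the tensor's encoding
attribute) if it has one. Otherwise the GlobalOp verifier rejects an initial
value whose tensor type matches a memref type that has a memory space.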
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 3 ++-
mlir/test/Dialect/MemRef/invalid.mlir | 5 +++++
mlir/test/Dialect/MemRef/ops.mlir | 3 +++
3 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..5db4ea30c25f7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -59,7 +59,8 @@ LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
if (auto memref = llvm::dyn_cast<MemRefType>(type))
- return RankedTensorType::get(memref.getShape(), memref.getElementType());
+ return RankedTensorType::get(memref.getShape(), memref.getElementType(),
+ memref.getMemorySpace());
if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
return UnrankedTensorType::get(memref.getElementType());
return NoneType::get(type.getContext());
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f72ad48245f81..9b4a0767ccf2b 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -342,6 +342,11 @@ memref.global "priate" constant @memref5 : memref<2xf32> = uninitialized
// -----
+// expected-error @+1 {{op initial value expected to be of type 'tensor<1xf16>', but was of type 'tensor<1xf16, 1 : i32>'}}
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf16, 1 : i32>, sym_name = "memref6", sym_visibility = "private", type = memref<1xf16>}> : () -> ()
+
+// -----
+
func.func @nonexistent_global_memref() {
// expected-error @+1 {{'gv' does not reference a valid global memref}}
%0 = memref.get_global @gv : memref<3xf32>
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7038a6ff744e4..1f1f7677429c8 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -174,6 +174,9 @@ memref.global "private" @memref3 : memref<2xf32> = uninitialized
// CHECK-LABEL: memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
+// CHECK-LABEL: memref.global "private" constant @memref5 : memref<1xf16, 1 : i32> = dense<1.000000e+00>
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf16, 1 : i32>, sym_name = "memref5", sym_visibility = "private", type = memref<1xf16, 1 : i32>}> : () -> ()
+
// CHECK-LABEL: func @read_global_memref
func.func @read_global_memref() {
%0 = memref.get_global @memref0 : memref<2xf32>
More information about the Mlir-commits mailing list