[Mlir-commits] [mlir] [mlir][memref]: Fix Bug in GlobalOp Verifier (PR #144900)
Jack Frankland
llvmlistbot at llvm.org
Wed Jun 25 06:03:44 PDT 2025
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/144900
>From 3a6403a44d10df42df60c7c2d27aafacba678b4a Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Thu, 19 Jun 2025 14:58:48 +0100
Subject: [PATCH] [mlir][memref]: Fix Bug in GlobalOp Verifier
When comparing the type of the initializer in a `memref::GlobalOp`
against its result type, only consider the element type and the shape.
Other attributes, such as the memory space, should be ignored: comparing
them between tensors and memrefs does not make sense, and constructing a
memref in a specific memory space from a tensor that has no such
attribute should be valid.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 22 +++++++++++++++++-----
mlir/test/Dialect/MemRef/invalid.mlir | 10 ++++++++++
mlir/test/Dialect/MemRef/ops.mlir | 3 +++
3 files changed, 30 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d56b32193765e..372e83a98ee52 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1567,11 +1567,23 @@ LogicalResult GlobalOp::verify() {
// Check that the type of the initial value is compatible with the type of
// the global variable.
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
- Type initType = elementsAttr.getType();
- Type tensorType = getTensorTypeFromMemRefType(memrefType);
- if (initType != tensorType)
- return emitOpError("initial value expected to be of type ")
- << tensorType << ", but was of type " << initType;
+ // Check the element types match.
+ auto initElementType =
+ cast<TensorType>(elementsAttr.getType()).getElementType();
+ auto memrefElementType = memrefType.getElementType();
+
+ if (initElementType != memrefElementType)
+ return emitOpError("initial value element expected to be of type ")
+ << memrefElementType << ", but was of type " << initElementType;
+
+ // Check the shapes match, given that memref globals can only produce
+ // statically shaped memrefs and elements literal type must have a static
+ // shape we can assume both types are shaped.
+ auto initShape = elementsAttr.getShapedType().getShape();
+ auto memrefShape = memrefType.getShape();
+ if (initShape != memrefShape)
+ return emitOpError("initial value shape expected to be ")
+ << memrefShape << " but was " << initShape;
}
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f908efb638446..8e394b2ac04c8 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -342,6 +342,16 @@ memref.global "priate" constant @memref5 : memref<2xf32> = uninitialized
// -----
+// expected-error @+1 {{op initial value element expected to be of type 'f16', but was of type 'f32'}}
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf32>, sym_name = "memref6", sym_visibility = "private", type = memref<1xf16>}> : () -> ()
+
+// -----
+
+// expected-error @+1 {{op initial value shape expected to be 1, 2 but was 2, 2}}
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<2x2xf16>, sym_name = "memref7", sym_visibility = "private", type = memref<1x2xf16>}> : () -> ()
+
+// -----
+
func.func @nonexistent_global_memref() {
// expected-error @+1 {{'gv' does not reference a valid global memref}}
%0 = memref.get_global @gv : memref<3xf32>
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 13fdf3cf13510..e11de7bec2d0a 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -174,6 +174,9 @@ memref.global "private" @memref3 : memref<2xf32> = uninitialized
// CHECK-LABEL: memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
+// CHECK-LABEL: memref.global "private" constant @memref5 : memref<1xf16, 42 : i32> = dense<1.000000e+00>
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf16>, sym_name = "memref5", sym_visibility = "private", type = memref<1xf16, 42 : i32>}> : () -> ()
+
// CHECK-LABEL: func @read_global_memref
func.func @read_global_memref() {
%0 = memref.get_global @memref0 : memref<2xf32>
More information about the Mlir-commits
mailing list