[Mlir-commits] [mlir] 022e1e9 - [mlir][memref]: Fix Bug in GlobalOp Verifier (#144900)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 25 07:14:58 PDT 2025
Author: Jack Frankland
Date: 2025-06-25T15:14:55+01:00
New Revision: 022e1e99f3b017ac1baf8b65f5a48212c5fca2ae
URL: https://github.com/llvm/llvm-project/commit/022e1e99f3b017ac1baf8b65f5a48212c5fca2ae
DIFF: https://github.com/llvm/llvm-project/commit/022e1e99f3b017ac1baf8b65f5a48212c5fca2ae.diff
LOG: [mlir][memref]: Fix Bug in GlobalOp Verifier (#144900)
When comparing the type of the initializer in a `memref::GlobalOp`
against the global's declared memref type, only consider the element
type and the shape. Other attributes, such as the memory space, should
be ignored: comparing them between tensors and memrefs doesn't make
sense, and initializing a memref in a specific memory space from a
tensor that has no such attribute should be valid.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d56b32193765e..372e83a98ee52 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1567,11 +1567,23 @@ LogicalResult GlobalOp::verify() {
// Check that the type of the initial value is compatible with the type of
// the global variable.
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
- Type initType = elementsAttr.getType();
- Type tensorType = getTensorTypeFromMemRefType(memrefType);
- if (initType != tensorType)
- return emitOpError("initial value expected to be of type ")
- << tensorType << ", but was of type " << initType;
+ // Check the element types match.
+ auto initElementType =
+ cast<TensorType>(elementsAttr.getType()).getElementType();
+ auto memrefElementType = memrefType.getElementType();
+
+ if (initElementType != memrefElementType)
+ return emitOpError("initial value element expected to be of type ")
+ << memrefElementType << ", but was of type " << initElementType;
+
+ // Check the shapes match, given that memref globals can only produce
+ // statically shaped memrefs and elements literal type must have a static
+ // shape we can assume both types are shaped.
+ auto initShape = elementsAttr.getShapedType().getShape();
+ auto memrefShape = memrefType.getShape();
+ if (initShape != memrefShape)
+ return emitOpError("initial value shape expected to be ")
+ << memrefShape << " but was " << initShape;
}
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f908efb638446..8e394b2ac04c8 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -342,6 +342,16 @@ memref.global "priate" constant @memref5 : memref<2xf32> = uninitialized
// -----
+// expected-error @+1 {{op initial value element expected to be of type 'f16', but was of type 'f32'}}
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf32>, sym_name = "memref6", sym_visibility = "private", type = memref<1xf16>}> : () -> ()
+
+// -----
+
+// expected-error @+1 {{op initial value shape expected to be 1, 2 but was 2, 2}}
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<2x2xf16>, sym_name = "memref7", sym_visibility = "private", type = memref<1x2xf16>}> : () -> ()
+
+// -----
+
func.func @nonexistent_global_memref() {
// expected-error @+1 {{'gv' does not reference a valid global memref}}
%0 = memref.get_global @gv : memref<3xf32>
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 13fdf3cf13510..e11de7bec2d0a 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -174,6 +174,9 @@ memref.global "private" @memref3 : memref<2xf32> = uninitialized
// CHECK-LABEL: memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
+// CHECK-LABEL: memref.global "private" constant @memref5 : memref<1xf16, 42 : i32> = dense<1.000000e+00>
+"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf16>, sym_name = "memref5", sym_visibility = "private", type = memref<1xf16, 42 : i32>}> : () -> ()
+
// CHECK-LABEL: func @read_global_memref
func.func @read_global_memref() {
%0 = memref.get_global @memref0 : memref<2xf32>
More information about the Mlir-commits
mailing list