[Mlir-commits] [mlir] [mlir][spirv] Fix verification and serialization replicated constant … (PR #151168)
Mohammadreza Ameri Mahabadian
llvmlistbot at llvm.org
Tue Jul 29 08:04:19 PDT 2025
https://github.com/mahabadm created https://github.com/llvm/llvm-project/pull/151168
…composites of multi-dimensional array
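This fixes a bug in the verification and serialization of replicated constant composite ops (`spirv.EXT.ConstantCompositeReplicate`) whose splat value is a multi-dimensional (nested) array attribute. Previously, the value type was derived from only a single level of array nesting, so a splat value whose elements are themselves arrays was not handled; the value type is now computed recursively.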
>From c246a1a4ab86aec86585356aede89b31ac4309ea Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Fri, 18 Jul 2025 14:04:15 +0100
Subject: [PATCH] [mlir][spirv] Fix verification and serialization replicated
constant composites of multi-dimensional array
This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array.
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 25 +++++++++++--------
.../Target/SPIRV/Serialization/Serializer.cpp | 25 +++++++++++--------
mlir/test/Target/SPIRV/constant.mlir | 14 +++++++++++
3 files changed, 42 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 52c672a05fa43..c8b87fad8ccad 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -767,19 +767,22 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
// spirv.EXTConstantCompositeReplicate
//===----------------------------------------------------------------------===//
+static Type getValueType(Attribute attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ return typedAttr.getType();
+ }
+
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+ }
+
+ return nullptr;
+}
+
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
- Type valueType;
- if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
- valueType = typedAttr.getType();
- } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
- auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
- if (!typedElemAttr)
- return emitError("value attribute is not typed");
- valueType =
- spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
- } else {
+ Type valueType = getValueType(getValue());
+ if (!valueType)
return emitError("unknown value attribute type");
- }
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
if (!compositeType)
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index a8a2b2e7cf38c..9e81b6ca505a2 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1124,6 +1124,18 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
return resultID;
}
+static Type getValueType(Attribute attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ return typedAttr.getType();
+ }
+
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+ }
+
+ return nullptr;
+}
+
uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
Type resultType,
Attribute valueAttr) {
@@ -1137,18 +1149,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
return 0;
}
- Type valueType;
- if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
- valueType = typedAttr.getType();
- } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
- auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
- if (!typedElemAttr)
- return 0;
- valueType =
- spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
- } else {
+ Type valueType = getValueType(valueAttr);
+ if (!valueAttr)
return 0;
- }
auto compositeType = dyn_cast<CompositeType>(resultType);
if (!compositeType)
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6aca11ec5e6e6..e06c0146d4ad2 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -363,6 +363,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_i32
+ spirv.func @array_of_splat_array_of_non_splat_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+ }
+
// CHECK-LABEL: @splat_vector_f32
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -411,4 +418,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+
+ // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_f32
+ spirv.func @array_of_splat_array_of_non_splat_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+ }
}
More information about the Mlir-commits
mailing list