[Mlir-commits] [mlir] 688551f - [mlir][spirv] Fix serialization of TensorARM with rank higher than one (#152391)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 8 10:20:30 PDT 2025
Author: Mohammadreza Ameri Mahabadian
Date: 2025-08-08T13:20:28-04:00
New Revision: 688551f05cf5f6c90e0f5edc890ee13bb563fa95
URL: https://github.com/llvm/llvm-project/commit/688551f05cf5f6c90e0f5edc890ee13bb563fa95
DIFF: https://github.com/llvm/llvm-project/commit/688551f05cf5f6c90e0f5edc890ee13bb563fa95.diff
LOG: [mlir][spirv] Fix serialization of TensorARM with rank higher than one (#152391)
This PR fixes #152012, where serializing TensorARM values of rank higher
than one into OpConstantComposite produced an invalid SPIR-V binary.
---------
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Added:
Modified:
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Target/SPIRV/arm-tensor-constant.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e863554fc..d8c54ec5f88c3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+ if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+ SmallVector<Attribute> flattenedElems;
+ for (Attribute element : elements) {
+ if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+ for (auto value : denseElemAttr.getValues<Attribute>())
+ flattenedElems.push_back(value);
+ } else {
+ flattenedElems.push_back(element);
+ }
+ }
+ auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+ constantMap.try_emplace(resultID, attr, tensorType);
+ } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
auto attr = DenseElementsAttr::get(shapedType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574fbc9e3..7c007de315589 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+ if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+ ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+ if (!innerShape.empty())
+ elementType = spirv::TensorArmType::get(innerShape, elementType);
+ }
// "If the Result Type is a cooperative matrix type, then there must be only
// one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
- } else if (isa<spirv::TensorArmType>(constType)) {
- if (isZeroValue(valueAttr)) {
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
- {typeID, resultID});
- return resultID;
- }
- numberOfConstituents = shapedType.getNumElements();
- operands.reserve(numberOfConstituents + 2);
- for (int i = 0; i < numberOfConstituents; ++i) {
- uint32_t elementID = 0;
- if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
- elementID =
- elementType.isInteger(1)
- ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
- : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
- }
- if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
- elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
- }
- if (!elementID) {
- return 0;
- }
- operands.push_back(elementID);
- }
+ } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
index 275e586f70634..7fb8af1904388 100644
--- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
+++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
@@ -1,17 +1,36 @@
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
-// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-
-// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
spirv.module Logical Vulkan requires #spirv.vce<v1.3,
[VulkanMemoryModel, Shader, TensorsARM, Linkage], [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
- // CHECK-LABEL: @arm_tensor_of_i32
- spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK-LABEL: @rank_1_arm_tensor_of_i32
+ spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+ %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
+ }
+
+ // CHECK-LABEL: @rank_2_arm_tensor_of_i32
+ spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
%0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // CHECK-LABEL: @rank_3_arm_tensor_of_i32
+ spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+ %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
+ }
+
+ // CHECK-LABEL: @rank_4_arm_tensor_of_i32
+ spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+ %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
+ }
+
// CHECK-LABEL: @splat_arm_tensor_of_i32
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
@@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
- // CHECK-LABEL: @arm_tensor_of_f32
- spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK-LABEL: @rank_1_arm_tensor_of_f32
+ spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
+ %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+ }
+
+ // CHECK-LABEL: @rank_2_arm_tensor_of_f32
+ spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
- %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+ // CHECK-LABEL: @rank_3_arm_tensor_of_f32
+ spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
+ %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
+ }
+
+ // CHECK-LABEL: @rank_4_arm_tensor_of_f32
+ spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
+ %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
+ }
+
// CHECK-LABEL: @splat_arm_tensor_of_f32
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
More information about the Mlir-commits mailing list