[Mlir-commits] [mlir] 688551f - [mlir][spirv] Fix serialization of TensorARM with rank higher than one (#152391)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 8 10:20:30 PDT 2025


Author: Mohammadreza Ameri Mahabadian
Date: 2025-08-08T13:20:28-04:00
New Revision: 688551f05cf5f6c90e0f5edc890ee13bb563fa95

URL: https://github.com/llvm/llvm-project/commit/688551f05cf5f6c90e0f5edc890ee13bb563fa95
DIFF: https://github.com/llvm/llvm-project/commit/688551f05cf5f6c90e0f5edc890ee13bb563fa95.diff

LOG: [mlir][spirv] Fix serialization of TensorARM with rank higher than one (#152391)

This PR fixes #152012, where serializing TensorARM values into
OpConstantComposite produced an invalid binary.

---------

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>

Added: 
    

Modified: 
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Target/SPIRV/arm-tensor-constant.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e863554fc..d8c54ec5f88c3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+  if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+    SmallVector<Attribute> flattenedElems;
+    for (Attribute element : elements) {
+      if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+        for (auto value : denseElemAttr.getValues<Attribute>())
+          flattenedElems.push_back(value);
+      } else {
+        flattenedElems.push_back(element);
+      }
+    }
+    auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+    constantMap.try_emplace(resultID, attr, tensorType);
+  } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
     auto attr = DenseElementsAttr::get(shapedType, elements);
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574fbc9e3..7c007de315589 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
   uint32_t resultID = getNextID();
   SmallVector<uint32_t, 4> operands = {typeID, resultID};
   auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+  if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+    ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+    if (!innerShape.empty())
+      elementType = spirv::TensorArmType::get(innerShape, elementType);
+  }
 
   // "If the Result Type is a cooperative matrix type, then there must be only
   // one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
     } else {
       return 0;
     }
-  } else if (isa<spirv::TensorArmType>(constType)) {
-    if (isZeroValue(valueAttr)) {
-      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
-                            {typeID, resultID});
-      return resultID;
-    }
-    numberOfConstituents = shapedType.getNumElements();
-    operands.reserve(numberOfConstituents + 2);
-    for (int i = 0; i < numberOfConstituents; ++i) {
-      uint32_t elementID = 0;
-      if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
-        elementID =
-            elementType.isInteger(1)
-                ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
-                : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
-      }
-      if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
-        elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
-      }
-      if (!elementID) {
-        return 0;
-      }
-      operands.push_back(elementID);
-    }
+  } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+    return resultID;
   } else {
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {

diff  --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
index 275e586f70634..7fb8af1904388 100644
--- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
+++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
@@ -1,17 +1,36 @@
 // RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
-// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-
-// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
 
 spirv.module Logical Vulkan requires #spirv.vce<v1.3,
              [VulkanMemoryModel, Shader, TensorsARM, Linkage], [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
-  // CHECK-LABEL: @arm_tensor_of_i32
-  spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+  // CHECK-LABEL: @rank_1_arm_tensor_of_i32
+  spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+    %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
+  }
+
+  // CHECK-LABEL: @rank_2_arm_tensor_of_i32
+  spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
     %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @rank_3_arm_tensor_of_i32
+  spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+    %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
+  }
+
+  // CHECK-LABEL: @rank_4_arm_tensor_of_i32
+  spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+    %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
+  }
+
   // CHECK-LABEL: @splat_arm_tensor_of_i32
   spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
@@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
-  // CHECK-LABEL: @arm_tensor_of_f32
-  spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+  // CHECK-LABEL: @rank_1_arm_tensor_of_f32
+  spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
+    %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+  }
+
+  // CHECK-LABEL: @rank_2_arm_tensor_of_f32
+  spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
-    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @rank_3_arm_tensor_of_f32
+  spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
+    %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
+  }
+
+  // CHECK-LABEL: @rank_4_arm_tensor_of_f32
+  spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
+    %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
+  }
+
   // CHECK-LABEL: @splat_arm_tensor_of_f32
   spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>


        


More information about the Mlir-commits mailing list