[Mlir-commits] [mlir] [mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstan… (PR #151485)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 31 03:08:16 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Mohammadreza Ameri Mahabadian (mahabadm)

<details>
<summary>Changes</summary>

…tNull

This patch enables (de)serialization of null (all-zero) TensorARM constants to/from OpConstantNull.

---
Full diff: https://github.com/llvm/llvm-project/pull/151485.diff


3 Files Affected:

- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+11-2) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+29-5) 
- (modified) mlir/test/Target/SPIRV/constant.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..333046a8e5d6f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
-                     "OpConstantNull must have type <id> and result <id>");
+                     "OpConstantNull must only have type <id> and result <id>");
   }
 
   Type resultType = getType(operands[0]);
@@ -1789,8 +1789,17 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
+  Attribute attr;
   if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
-    auto attr = opBuilder.getZeroAttr(resultType);
+    attr = opBuilder.getZeroAttr(resultType);
+  } else if (isa<TensorArmType>(resultType)) {
+    auto shapedType = cast<ShapedType>(resultType);
+    auto element = opBuilder.getZeroAttr(shapedType.getElementType());
+    if (element)
+      attr = DenseElementsAttr::get(shapedType, element);
+  }
+
+  if (attr) {
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
     constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..3ef9a89a3ca62 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,22 @@ static Block *getPhiIncomingBlock(Block *block) {
   return block;
 }
 
+static bool isNull(Attribute attr) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    return floatAttr.getValue().isZero();
+  }
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+    return !boolAttr.getValue();
+  }
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    return intAttr.getValue().isZero();
+  }
+  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+    return all_of(denseElemAttr.getValues<Attribute>(), isNull);
+  }
+  return false;
+}
+
 namespace mlir {
 namespace spirv {
 
@@ -959,6 +975,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
       return 0;
     }
   } else if (isa<spirv::TensorArmType>(constType)) {
+    if (isNull(valueAttr)) {
+      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                            {typeID, resultID});
+      return resultID;
+    }
     numberOfConstituents = shapedType.getNumElements();
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1223,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
   }
 
   uint32_t resultID = getNextID();
-  uint32_t operands[] = {typeID, resultID, constandID};
-
-  encodeInstructionInto(typesGlobalValues,
-                        spirv::Opcode::OpConstantCompositeReplicateEXT,
-                        operands);
+  if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+  } else {
+    encodeInstructionInto(typesGlobalValues,
+                          spirv::Opcode::OpConstantCompositeReplicateEXT,
+                          {typeID, resultID, constandID});
+  }
 
   constCompositeReplicateIDMap[valueTypePair] = resultID;
   return resultID;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @null_arm_tensor_of_i32
+  spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  // CHECK-LABEL: @null_arm_tensor_of_f32
+  spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
   spirv.EntryPoint "GLCompute" @bool_const
 }
 
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+  spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
+
+  // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+  spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/151485


More information about the Mlir-commits mailing list