[Mlir-commits] [mlir] 48ea3e4 - [mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull (#151485)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 05:21:15 PDT 2025


Author: Mohammadreza Ameri Mahabadian
Date: 2025-08-01T08:21:11-04:00
New Revision: 48ea3e425bb97b73a2c3117fdcd55fe96314ab1c

URL: https://github.com/llvm/llvm-project/commit/48ea3e425bb97b73a2c3117fdcd55fe96314ab1c
DIFF: https://github.com/llvm/llvm-project/commit/48ea3e425bb97b73a2c3117fdcd55fe96314ab1c.diff

LOG: [mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull (#151485)

This patch enables (de)serialization to/from OpConstantNull for null
TensorARM constants.

---------

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>

Added: 
    

Modified: 
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Target/SPIRV/constant.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..d0ae5132252ff 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
-                     "OpConstantNull must have type <id> and result <id>");
+                     "OpConstantNull must only have type <id> and result <id>");
   }
 
   Type resultType = getType(operands[0]);
@@ -1789,8 +1789,15 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
+  Attribute attr;
   if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
-    auto attr = opBuilder.getZeroAttr(resultType);
+    attr = opBuilder.getZeroAttr(resultType);
+  } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+    if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
+      attr = DenseElementsAttr::get(tensorType, element);
+  }
+
+  if (attr) {
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
     constantMap.try_emplace(resultID, attr, resultType);

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..59665ec1add54 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,25 @@ static Block *getPhiIncomingBlock(Block *block) {
   return block;
 }
 
+static bool isZeroValue(Attribute attr) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    return floatAttr.getValue().isZero();
+  }
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+    return !boolAttr.getValue();
+  }
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    return intAttr.getValue().isZero();
+  }
+  if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
+    return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
+  }
+  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+    return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
+  }
+  return false;
+}
+
 namespace mlir {
 namespace spirv {
 
@@ -959,6 +978,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
       return 0;
     }
   } else if (isa<spirv::TensorArmType>(constType)) {
+    if (isZeroValue(valueAttr)) {
+      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                            {typeID, resultID});
+      return resultID;
+    }
     numberOfConstituents = shapedType.getNumElements();
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1226,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
   }
 
   uint32_t resultID = getNextID();
-  uint32_t operands[] = {typeID, resultID, constandID};
-
-  encodeInstructionInto(typesGlobalValues,
-                        spirv::Opcode::OpConstantCompositeReplicateEXT,
-                        operands);
+  if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+  } else {
+    encodeInstructionInto(typesGlobalValues,
+                          spirv::Opcode::OpConstantCompositeReplicateEXT,
+                          {typeID, resultID, constandID});
+  }
 
   constCompositeReplicateIDMap[valueTypePair] = resultID;
   return resultID;

diff  --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @null_arm_tensor_of_i32
+  spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  // CHECK-LABEL: @null_arm_tensor_of_f32
+  spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
   spirv.EntryPoint "GLCompute" @bool_const
 }
 
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+  spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
+
+  // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+  spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
 }


        


More information about the Mlir-commits mailing list