[Mlir-commits] [mlir] [mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull (PR #151485)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 31 03:08:18 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Mohammadreza Ameri Mahabadian (mahabadm)
<details>
<summary>Changes</summary>
…tNull
This patch enables (de)serialization of null TensorARM constants to/from OpConstantNull.
---
Full diff: https://github.com/llvm/llvm-project/pull/151485.diff
3 Files Affected:
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+11-2)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+29-5)
- (modified) mlir/test/Target/SPIRV/constant.mlir (+28)
``````````diff
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..333046a8e5d6f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc,
- "OpConstantNull must have type <id> and result <id>");
+ "OpConstantNull must only have type <id> and result <id>");
}
Type resultType = getType(operands[0]);
@@ -1789,8 +1789,17 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
+ Attribute attr;
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
- auto attr = opBuilder.getZeroAttr(resultType);
+ attr = opBuilder.getZeroAttr(resultType);
+ } else if (isa<TensorArmType>(resultType)) {
+ auto shapedType = cast<ShapedType>(resultType);
+ auto element = opBuilder.getZeroAttr(shapedType.getElementType());
+ if (element)
+ attr = DenseElementsAttr::get(shapedType, element);
+ }
+
+ if (attr) {
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..3ef9a89a3ca62 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,22 @@ static Block *getPhiIncomingBlock(Block *block) {
return block;
}
+/// Returns true if `attr` holds exactly the value that SPIR-V OpConstantNull
+/// produces for its type: integer 0, boolean false, floating-point +0.0, or a
+/// DenseElementsAttr whose every element satisfies this predicate. Only such
+/// constants may be emitted as OpConstantNull without changing their value
+/// when they are deserialized again.
+static bool isNull(Attribute attr) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    // OpConstantNull materializes +0.0. APFloat::isZero() is also true for
+    // -0.0, whose sign must be preserved, so only positive zero is "null".
+    return floatAttr.getValue().isPosZero();
+  }
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+    return !boolAttr.getValue();
+  }
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    return intAttr.getValue().isZero();
+  }
+  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+    // A splat stores a single value: check it once instead of visiting every
+    // (identical) element of a potentially large tensor.
+    if (denseElemAttr.isSplat())
+      return isNull(denseElemAttr.getSplatValue<Attribute>());
+    return all_of(denseElemAttr.getValues<Attribute>(), isNull);
+  }
+  return false;
+}
+
namespace mlir {
namespace spirv {
@@ -959,6 +975,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
return 0;
}
} else if (isa<spirv::TensorArmType>(constType)) {
+ if (isNull(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
+ }
numberOfConstituents = shapedType.getNumElements();
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1223,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
}
uint32_t resultID = getNextID();
- uint32_t operands[] = {typeID, resultID, constandID};
-
- encodeInstructionInto(typesGlobalValues,
- spirv::Opcode::OpConstantCompositeReplicateEXT,
- operands);
+ if (dyn_cast<spirv::TensorArmType>(resultType) && isNull(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ } else {
+ encodeInstructionInto(typesGlobalValues,
+ spirv::Opcode::OpConstantCompositeReplicateEXT,
+ {typeID, resultID, constandID});
+ }
constCompositeReplicateIDMap[valueTypePair] = resultID;
return resultID;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+ // All-zero i32 TensorARM constant: the serializer encodes it as
+ // OpConstantNull, and deserialization must rebuild the same dense<0> value.
+ // CHECK-LABEL: @null_arm_tensor_of_i32
+ spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // All-zero (+0.0) f32 TensorARM constant: the serializer encodes it as
+ // OpConstantNull, and deserialization must rebuild the same dense zero value.
+ // CHECK-LABEL: @null_arm_tensor_of_f32
+ spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
spirv.EntryPoint "GLCompute" @bool_const
}
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // A replicated i32 zero over a TensorARM is serialized as OpConstantNull
+ // rather than OpConstantCompositeReplicateEXT, so it round-trips as a plain
+ // spirv.Constant with a dense zero value.
+ // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+ spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
// CHECK-LABEL: @splat_vector_f32
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+
+ // A replicated f32 zero over a TensorARM is serialized as OpConstantNull
+ // rather than OpConstantCompositeReplicateEXT, so it round-trips as a plain
+ // spirv.Constant with a dense zero value.
+ // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+ spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/151485
More information about the Mlir-commits
mailing list