[Mlir-commits] [mlir] 48ea3e4 - [mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull (#151485)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 1 05:21:15 PDT 2025
Author: Mohammadreza Ameri Mahabadian
Date: 2025-08-01T08:21:11-04:00
New Revision: 48ea3e425bb97b73a2c3117fdcd55fe96314ab1c
URL: https://github.com/llvm/llvm-project/commit/48ea3e425bb97b73a2c3117fdcd55fe96314ab1c
DIFF: https://github.com/llvm/llvm-project/commit/48ea3e425bb97b73a2c3117fdcd55fe96314ab1c.diff
LOG: [mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstantNull (#151485)
This patch enables (de)serialization of null (zero-filled) TensorARM
constants to/from OpConstantNull.
---------
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Added:
Modified:
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Target/SPIRV/constant.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 88931b53a6889..d0ae5132252ff 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1779,7 +1779,7 @@ LogicalResult
spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc,
- "OpConstantNull must have type <id> and result <id>");
+ "OpConstantNull must only have type <id> and result <id>");
}
Type resultType = getType(operands[0]);
@@ -1789,8 +1789,15 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
+ Attribute attr;
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
- auto attr = opBuilder.getZeroAttr(resultType);
+ attr = opBuilder.getZeroAttr(resultType);
+ } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+ if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
+ attr = DenseElementsAttr::get(tensorType, element);
+ }
+
+ if (attr) {
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 737f29662f64b..59665ec1add54 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,25 @@ static Block *getPhiIncomingBlock(Block *block) {
return block;
}
+/// Returns true if `attr` holds exactly the SPIR-V null value of its type, so
+/// that it may be emitted as OpConstantNull: integer zero, boolean false,
+/// positive floating-point zero, or an elements attribute whose every element
+/// satisfies one of these.
+///
+/// Negative zero is deliberately rejected: the null value of a floating-point
+/// type is +0.0, so folding -0.0 into OpConstantNull would flip the sign bit.
+static bool isZeroValue(Attribute attr) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+    // APFloat::isZero() also accepts -0.0; only +0.0 is the null value.
+    return floatAttr.getValue().isPosZero();
+  }
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+    return !boolAttr.getValue();
+  }
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+    return intAttr.getValue().isZero();
+  }
+  // A splat is decided from its single repeated element.
+  if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
+    return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
+  }
+  // Non-splat dense data: every element must itself be a null value.
+  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+    return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
+  }
+  return false;
+}
+
namespace mlir {
namespace spirv {
@@ -959,6 +978,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
return 0;
}
} else if (isa<spirv::TensorArmType>(constType)) {
+ if (isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
+ }
numberOfConstituents = shapedType.getNumElements();
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1226,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
}
uint32_t resultID = getNextID();
- uint32_t operands[] = {typeID, resultID, constandID};
-
- encodeInstructionInto(typesGlobalValues,
- spirv::Opcode::OpConstantCompositeReplicateEXT,
- operands);
+ if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ } else {
+ encodeInstructionInto(typesGlobalValues,
+ spirv::Opcode::OpConstantCompositeReplicateEXT,
+ {typeID, resultID, constandID});
+ }
constCompositeReplicateIDMap[valueTypePair] = resultID;
return resultID;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 1695d2a6a2eb4..3be49eefcaebf 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+ // Zero-filled TensorARM constants: the serializer in this patch emits these
+ // as OpConstantNull, and the deserializer rebuilds a dense zero constant.
+ // NOTE(review): the expectations below only check the round-tripped op, not
+ // that OpConstantNull was actually emitted in the binary.
+ // CHECK-LABEL: @null_arm_tensor_of_i32
+ spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @null_arm_tensor_of_f32
+ spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
spirv.EntryPoint "GLCompute" @bool_const
}
@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // A zero-valued replicated TensorARM constant is serialized as
+ // OpConstantNull, so it deserializes as a plain spirv.Constant rather than
+ // spirv.EXT.ConstantCompositeReplicate.
+ // CHECK-LABEL: @null_cc_arm_tensor_of_i32
+ spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
// CHECK-LABEL: @splat_vector_f32
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+
+ // Floating-point variant: a replicated +0.0 tensor goes through
+ // OpConstantNull and comes back as a dense spirv.Constant.
+ // CHECK-LABEL: @null_cc_arm_tensor_of_f32
+ spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
}
More information about the Mlir-commits
mailing list