[Mlir-commits] [mlir] 2dab551 - [mlir][spirv] Change numeric constant's bit-extension decision to be based on type
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jun 5 07:53:55 PDT 2023
Author: Md Abdullah Shahneous Bari
Date: 2023-06-05T10:52:09-04:00
New Revision: 2dab551d649cb57127f7b3bea183f069883d76e4
URL: https://github.com/llvm/llvm-project/commit/2dab551d649cb57127f7b3bea183f069883d76e4
DIFF: https://github.com/llvm/llvm-project/commit/2dab551d649cb57127f7b3bea183f069883d76e4.diff
LOG: [mlir][spirv] Change numeric constant's bit-extension decision to be based on type
Integer constants with bit width less than a word (e.g., i8, i16)
should be bit-extended based on their type to be SPIR-V spec-compliant.
Previously, the decision was based on the most significant bit of the
value, which ignores the signless semantics and causes problems when
interfacing with SPIR-V tools.
Dealing with numeric literals: the SPIR-V spec says, "If a numeric
type’s bit width is less than 32-bits, the value appears in the
low-order bits of the word, and the high-order bits must be 0 for
a floating-point type or integer type with Signedness of 0, or sign
extended for an integer type with a Signedness of 1 (similarly for the
remaining bits of widths larger than 32 bits but not a multiple of 32
bits)."
Therefore, signless integers (e.g., i8, i16) and unsigned integers
should be zero-extended, and signed integers (e.g., si8, si16) should be
sign-extended.
Patch By: mshahneo
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D151767
Added:
Modified:
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Target/SPIRV/constant.mlir
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index f32f6e8242063..1ef8ff043e690 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -846,8 +846,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
auto resultID = getNextID();
APInt value = intAttr.getValue();
unsigned bitwidth = value.getBitWidth();
- bool isSigned = value.isSignedIntN(bitwidth);
-
+ bool isSigned = intAttr.getType().isSignedInteger();
auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 7e3ae2a563099..f3950214a7f05 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -264,4 +264,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.Constant dense<1> : tensor<2x2x3xi32> : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32, stride=4>, stride=12>, stride=24>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32, stride=4>, stride=12>, stride=24>
}
+
+ // Round-trips a signless 16-bit constant holding -1. Under the SPIR-V
+ // literal rules, a signless (Signedness 0) 16-bit value is zero-extended
+ // to fill its 32-bit literal word.
+ // NOTE: the check below only inspects the value after deserialization; the
+ // raw serialized word is asserted by the
+ // SignlessVsSignedIntegerConstantBitExtension unit test.
+ // CHECK-LABEL: @signless_int_const_bit_extension
+ spirv.func @signless_int_const_bit_extension() -> (i16) "None" {
+ // CHECK: spirv.Constant -1 : i16
+ %signless_minus_one = spirv.Constant -1 : i16
+ spirv.ReturnValue %signless_minus_one : i16
+ }
+ // Round-trips a signed 16-bit constant holding -1. Under the SPIR-V
+ // literal rules, a signed (Signedness 1) 16-bit value is sign-extended
+ // to fill its 32-bit literal word.
+ // NOTE: the check below only inspects the value after deserialization; the
+ // raw serialized word is asserted by the
+ // SignlessVsSignedIntegerConstantBitExtension unit test.
+ // CHECK-LABEL: @signed_int_const_bit_extension
+ spirv.func @signed_int_const_bit_extension() -> (si16) "None" {
+ // CHECK: spirv.Constant -1 : si16
+ %signed_minus_one = spirv.Constant -1 : si16
+ spirv.ReturnValue %signed_minus_one : si16
+ }
}
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index f7a1db0749c57..56a98cc205ab4 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -76,6 +76,27 @@ class SerializationTest : public ::testing::Test {
builder.getStringAttr(name), nullptr);
}
+  /// Creates a `spirv.Constant` op holding `val` in the module's body region
+  /// and returns it.
+  ///
+  /// `type` must be either an integer type, or a vector type whose element
+  /// type is an integer type; in the vector case every element holds `val`.
+  /// Any other type hits `llvm_unreachable`.
+  /// NOTE(review): `val` is presumably expected to have the same bit width as
+  /// the integer (element) type; the attribute builders below receive it
+  /// unchanged.
+  spirv::ConstantOp AddConstInt(Type type, const APInt &val) {
+    OpBuilder builder(module->getRegion());
+    auto loc = UnknownLoc::get(&context);
+
+    // Scalar integer: a plain IntegerAttr of the requested type.
+    if (isa<IntegerType>(type)) {
+      return builder.create<spirv::ConstantOp>(
+          loc, type, builder.getIntegerAttr(type, val));
+    }
+    // Vector of integers: a single value replicated across all elements.
+    if (auto vectorType = dyn_cast<VectorType>(type)) {
+      Type elemType = vectorType.getElementType();
+      if (isa<IntegerType>(elemType)) {
+        return builder.create<spirv::ConstantOp>(
+            loc, type,
+            DenseElementsAttr::get(vectorType,
+                                   IntegerAttr::get(elemType, val).getValue()));
+      }
+    }
+    llvm_unreachable("unimplemented types for AddConstInt()");
+  }
+
/// Handles a SPIR-V instruction with the given `opcode` and `operand`.
/// Returns true to interrupt.
using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
@@ -149,6 +170,34 @@ TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
EXPECT_EQ(count, 1u);
}
+TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
+  // The SPIR-V spec requires literals narrower than 32 bits to be
+  // zero-extended for integer types with Signedness 0 and sign-extended for
+  // integer types with Signedness 1. MLIR signless maps to Signedness 0 here.
+  auto signlessInt16Type =
+      IntegerType::get(&context, 16, IntegerType::Signless);
+  auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
+
+  // Check the bit extension of the same value under different signedness
+  // semantics. Both APInts hold the identical 16-bit all-ones pattern (-1);
+  // only the attribute *type* differs. The third constructor argument is the
+  // `bool isSigned` flag for interpreting -1, not the MLIR signedness enum.
+  APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
+                            /*isSigned=*/true);
+  APInt signedIntConstVal(signedInt16Type.getWidth(), -1, /*isSigned=*/true);
+
+  AddConstInt(signlessInt16Type, signlessIntConstVal);
+  AddConstInt(signedInt16Type, signedIntConstVal);
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
+  // OpConstant operands are: result type <id>, result <id>, literal word.
+  // Signless i16 -1 must be zero-extended into the 32-bit word.
+  auto hasSignlessVal = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
+           operands[2] == 0x0000FFFFu;
+  };
+  EXPECT_TRUE(scanInstruction(hasSignlessVal));
+
+  // Signed si16 -1 must be sign-extended into the 32-bit word.
+  auto hasSignedVal = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
+           operands[2] == 0xFFFFFFFFu;
+  };
+  EXPECT_TRUE(scanInstruction(hasSignedVal));
+}
+
TEST_F(SerializationTest, ContainsSymbolName) {
auto structType = getFloatStructType();
addGlobalVar(structType, "var0");
More information about the Mlir-commits
mailing list