[llvm] [SPIR-V] Fix bitcast legalization/instruction selection in SPIR-V Backend (PR #83139)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 27 07:37:57 PST 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/83139
>From fc958ece3e35e37eacb72407bda784cc57d68841 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 27 Feb 2024 06:31:28 -0800
Subject: [PATCH 1/2] fix bitcast legalization/instruction selection
---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 46 ++++++++++++++++---
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 17 ++++++-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 15 +++++-
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 9 ++--
4 files changed, 73 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a1cb630f1aa477..64af253e25a1e1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -834,15 +834,49 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
llvm_unreachable("Attempting to get bit width of non-integer/float type.");
}
-bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
+ const SPIRVType *Type) const {
assert(Type && "Invalid Type pointer");
+ unsigned NumElements = 1;
if (Type->getOpcode() == SPIRV::OpTypeVector) {
- auto EleTypeReg = Type->getOperand(1).getReg();
- Type = getSPIRVTypeForVReg(EleTypeReg);
+ NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
+ Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
}
- if (Type->getOpcode() == SPIRV::OpTypeInt)
- return Type->getOperand(2).getImm() != 0;
- llvm_unreachable("Attempting to get sign of non-integer type.");
+ return Type->getOpcode() == SPIRV::OpTypeInt ||
+ Type->getOpcode() == SPIRV::OpTypeFloat
+ ? NumElements * Type->getOperand(1).getImm()
+ : 0;
+}
+
+const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
+ const SPIRVType *Type) const {
+ if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
+ Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
+ return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
+}
+
+bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+ const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
+ return IntType && IntType->getOperand(2).getImm() != 0;
+}
+
+bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
+ const SPIRVType *Type2) const {
+ if (!Type1 || !Type2)
+ return false;
+ auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
+ // The SPIR-V version is ignored and the >= 1.5 rule is applied: if either
+ // Result Type or Operand is a pointer, the other must be a pointer, an
+ // integer scalar, or an integer vector (before 1.5: only another pointer).
+ if (Op1 == SPIRV::OpTypePointer &&
+ (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
+ return true;
+ if (Op2 == SPIRV::OpTypePointer &&
+ (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
+ return true;
+ unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
+ Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
+ return Bits1 > 0 && Bits1 == Bits2;
}
SPIRV::StorageClass::StorageClass
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..151af19fcef2ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,9 +197,19 @@ class SPIRVGlobalRegistry {
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
- // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
+ // For vectors or scalars of booleans, integers and floats, return the scalar
+ // type's bitwidth. Otherwise calls llvm_unreachable().
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
+ // For vectors or scalars of integers and floats, return the total bitwidth
+ // (element bitwidth times number of elements). Otherwise returns 0.
+ unsigned getNumScalarOrVectorTotalBitWidth(const SPIRVType *Type) const;
+
+ // Returns the argument itself if it is an integer scalar type, its element
+ // type if it is an integer vector type, and nullptr otherwise (including
+ // when the argument is nullptr).
+ const SPIRVType *retrieveScalarOrVectorIntType(const SPIRVType *Type) const;
+
// For integer vectors or scalars, return whether the integers are signed.
bool isScalarOrVectorSigned(const SPIRVType *Type) const;
@@ -209,6 +219,11 @@ class SPIRVGlobalRegistry {
// Return the number of bits SPIR-V pointers and size_t variables require.
unsigned getPointerSize() const { return PointerSize; }
+ // Returns true if both types are non-null and compatible as the Result Type
+ // and Operand type of an OpBitcast instruction (in either order).
+ bool isBitcastCompatible(const SPIRVType *Type1,
+ const SPIRVType *Type2) const;
+
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 9b38073ec3bcf7..2958b73ba00ab5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -95,6 +95,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
unsigned Opcode) const;
+ bool selectBitcast(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectLoad(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectStore(MachineInstr &I) const;
@@ -449,7 +452,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_INTTOPTR:
return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr);
case TargetOpcode::G_BITCAST:
- return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+ return selectBitcast(ResVReg, ResType, I);
case TargetOpcode::G_ADDRSPACE_CAST:
return selectAddrSpaceCast(ResVReg, ResType, I);
case TargetOpcode::G_PTR_ADD: {
@@ -586,6 +589,19 @@ bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
Opcode);
}
+// Selects G_BITCAST as OpBitcast. Reports a fatal error if the operand's SPIR-V
+// type is unavailable or is not bitcast-compatible with the result type.
+bool SPIRVInstructionSelector::selectBitcast(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ Register OpReg = I.getOperand(1).getReg();
+ const SPIRVType *OpType =
+ OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
+ if (!GR.isBitcastCompatible(ResType, OpType))
+ report_fatal_error("incompatible result and operand types in a bitcast");
+ return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+}
+
static SPIRV::Scope::Scope getScope(SyncScope::ID Ord) {
switch (Ord) {
case SyncScope::SingleThread:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 049ca4ac818c4e..b0f94d0ab2f7d9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -177,12 +177,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
- getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
- typeInSet(0, allPtrsScalarsAndVectors),
- typeInSet(1, allPtrsScalarsAndVectors),
- LegalityPredicate(([=](const LegalityQuery &Query) {
- return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
- }))));
+ // Operand/result type compatibility is not checked here; it is verified
+ // during instruction selection (SPIRVInstructionSelector::selectBitcast).
+ getActionDefinitionsBuilder(G_BITCAST).legalIf(
+ all(typeInSet(0, allPtrsScalarsAndVectors),
+ typeInSet(1, allPtrsScalarsAndVectors)));
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
>From f654a437818c4a357636d4f11b3c94c72c8417ed Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 27 Feb 2024 07:37:43 -0800
Subject: [PATCH 2/2] add a test case
---
llvm/test/CodeGen/SPIRV/bitcast.ll | 21 +++++++++++++++++++++
1 file changed, 21 insertions(+)
create mode 100644 llvm/test/CodeGen/SPIRV/bitcast.ll
diff --git a/llvm/test/CodeGen/SPIRV/bitcast.ll b/llvm/test/CodeGen/SPIRV/bitcast.ll
new file mode 100644
index 00000000000000..242c5a46583c22
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast.ll
@@ -0,0 +1,21 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#TyInt32:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#TyInt16:]] = OpTypeInt 16 0
+; CHECK-SPIRV-DAG: %[[#TyHalf:]] = OpTypeFloat 16
+; CHECK-SPIRV-DAG: %[[#Arg32:]] = OpFunctionParameter %[[#TyInt32]]
+; CHECK-SPIRV-DAG: %[[#Arg16:]] = OpUConvert %[[#TyInt16]] %[[#Arg32]]
+; CHECK-SPIRV-DAG: %[[#ValHalf:]] = OpBitcast %[[#TyHalf]] %[[#Arg16]]
+; CHECK-SPIRV-DAG: %[[#ValHalf2:]] = OpFMul %[[#TyHalf]] %[[#ValHalf]] %[[#ValHalf]]
+; CHECK-SPIRV-DAG: %[[#Res16:]] = OpBitcast %[[#TyInt16]] %[[#ValHalf2]]
+; CHECK-SPIRV-DAG: OpReturnValue %[[#Res16]]
+
+define i16 @foo(i32 %arg) {
+entry:
+ %op16 = trunc i32 %arg to i16
+ %val = bitcast i16 %op16 to half
+ %val2 = fmul half %val, %val
+ %res = bitcast half %val2 to i16
+ ret i16 %res
+}
More information about the llvm-commits
mailing list