[llvm] [SPIR-V] Fix bitcast legalization/instruction selection in SPIR-V Backend (PR #83139)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 27 07:42:18 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR fixes how the SPIR-V Backend describes the legality of the OpBitcast instruction and how it is validated during instruction selection. The current code checks the sizes of virtual registers, which makes no sense: there is no guaranteed direct relation between the size of a virtual register and the bit width of its associated SPIR-V type. Instead, this PR legalizes OpBitcast without a size check and postpones validation to the instruction selection step.
As an example, consider the following snippet, copied as is from a larger test suite:
```
%355:id(s16) = G_BITCAST %301:id(s32)
%303:id(s16) = ASSIGN_TYPE %355:id(s16), %349:type(s32)
...
%644:fid(s32) = G_FMUL %645:fid, %646:fid
%645:fid(s32) = GET_fID %297:id(s16)
%646:fid(s32) = GET_fID %287:id(s16)
%301:id(s32) = ASSIGN_TYPE %644:fid(s32), %40:type(s32)
```
Without this PR, this leads to a crash complaining about an illegal bitcast, because %355 is s16 and %301 is s32. However, what must be checked here is not the virtual registers but the SPIR-V types of %355 and %301, i.e., %349:type(s32) and %40:type(s32). These types are perfectly compatible in the sense of OpBitcast.
In the test case added by this PR, the corresponding code looks like this:
```
# *** IR Dump After Legalizer (legalizer) ***:
# Machine code for function foo: IsSSA, TracksLiveness, Legalized
bb.1.entry:
%1:type(s32) = OpTypeInt 32, 0
OpName %0:id(s32), 6779489
%3:type(s32) = OpTypeInt 16, 0
%4:type(s32) = OpTypeFunction %3:type(s32), %1:type(s32)
%2:id(s32) = OpFunction %3:type(s32), 0, %4:type(s32)
%0:id(s32) = OpFunctionParameter %1:type(s32)
OpName %2:id(s32), 7303014
OpDecorate %2:id(s32), 41, 7303014, 0
%18:id(s16) = G_TRUNC %0:id(s32)
%5:id(s16) = ASSIGN_TYPE %18:id(s16), %3:type(s32)
G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@<!-- -->llvm.spv.assign.name), %5:id(s16), 909209711, 0
%17:id(s16) = G_BITCAST %5:id(s16)
%8:id(s16) = ASSIGN_TYPE %17:id(s16), %16:type(s32)
G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@<!-- -->llvm.spv.assign.name), %8:id(s16), 7102838
%19:fid(s32) = G_FMUL %20:fid, %21:fid
%20:fid(s32) = GET_fID %8:id(s16)
%21:fid(s32) = GET_fID %8:id(s16)
%16:type(s32) = OpTypeFloat 16
%10:id(s32) = ASSIGN_TYPE %19:fid(s32), %16:type(s32)
G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@<!-- -->llvm.spv.assign.name), %10:id(s32), 845963638, 0
%14:anyid(s16) = G_BITCAST %10:id(s32)
%12:anyid(s16) = ASSIGN_TYPE %14:anyid(s16), %3:type(s32)
G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@<!-- -->llvm.spv.assign.name), %12:anyid(s16), 7562610
OpReturnValue %12:anyid(s16)
```
Here OpBitcast is legal, since it is applied to `OpTypeInt 16` and `OpTypeFloat 16`. Without this PR it would not be legalized, because its virtual registers are defined as `%14:anyid(s16) = G_BITCAST %10:id(s32)`.
---
Full diff: https://github.com/llvm/llvm-project/pull/83139.diff
5 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+40-6)
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+16-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+14-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+3-6)
- (added) llvm/test/CodeGen/SPIRV/bitcast.ll (+21)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a1cb630f1aa477..64af253e25a1e1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -834,15 +834,49 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
llvm_unreachable("Attempting to get bit width of non-integer/float type.");
}
-bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
+ const SPIRVType *Type) const {
assert(Type && "Invalid Type pointer");
+ unsigned NumElements = 1;
if (Type->getOpcode() == SPIRV::OpTypeVector) {
- auto EleTypeReg = Type->getOperand(1).getReg();
- Type = getSPIRVTypeForVReg(EleTypeReg);
+ NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
+ Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
}
- if (Type->getOpcode() == SPIRV::OpTypeInt)
- return Type->getOperand(2).getImm() != 0;
- llvm_unreachable("Attempting to get sign of non-integer type.");
+ return Type->getOpcode() == SPIRV::OpTypeInt ||
+ Type->getOpcode() == SPIRV::OpTypeFloat
+ ? NumElements * Type->getOperand(1).getImm()
+ : 0;
+}
+
+const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
+ const SPIRVType *Type) const {
+ if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
+ Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
+ return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
+}
+
+bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+ const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
+ return IntType && IntType->getOperand(2).getImm() != 0;
+}
+
+bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
+ const SPIRVType *Type2) const {
+ if (!Type1 || !Type2)
+ return false;
+ auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
+ // Ignore the difference between SPIR-V versions <1.5 and >=1.5 and apply
+ // the >=1.5 rule: if either Result Type or Operand is a pointer, the other
+ // must be a pointer, an integer scalar, or an integer vector.
+ if (Op1 == SPIRV::OpTypePointer &&
+ (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
+ return true;
+ if (Op2 == SPIRV::OpTypePointer &&
+ (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
+ return true;
+ unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
+ Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
+ return Bits1 > 0 && Bits1 == Bits2;
}
SPIRV::StorageClass::StorageClass
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..151af19fcef2ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,9 +197,19 @@ class SPIRVGlobalRegistry {
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
- // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
+ // For vectors or scalars of booleans, integers and floats, return the scalar
+ // type's bitwidth. Otherwise calls llvm_unreachable().
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
+ // For vectors or scalars of integers and floats, return total bitwidth of the
+ // argument. Otherwise returns 0.
+ unsigned getNumScalarOrVectorTotalBitWidth(const SPIRVType *Type) const;
+
+ // Returns the argument itself if it is an integer type, or its element type
+ // if it is a vector of integers; returns nullptr if the argument is neither
+ // an integer scalar nor an integer vector (including a null argument).
+ const SPIRVType *retrieveScalarOrVectorIntType(const SPIRVType *Type) const;
+
// For integer vectors or scalars, return whether the integers are signed.
bool isScalarOrVectorSigned(const SPIRVType *Type) const;
@@ -209,6 +219,11 @@ class SPIRVGlobalRegistry {
// Return the number of bits SPIR-V pointers and size_t variables require.
unsigned getPointerSize() const { return PointerSize; }
+ // Returns true if both types are non-null and compatible in the sense of
+ // the OpBitcast instruction.
+ bool isBitcastCompatible(const SPIRVType *Type1,
+ const SPIRVType *Type2) const;
+
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 9b38073ec3bcf7..2958b73ba00ab5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -95,6 +95,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
unsigned Opcode) const;
+ bool selectBitcast(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectLoad(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectStore(MachineInstr &I) const;
@@ -449,7 +452,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_INTTOPTR:
return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr);
case TargetOpcode::G_BITCAST:
- return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+ return selectBitcast(ResVReg, ResType, I);
case TargetOpcode::G_ADDRSPACE_CAST:
return selectAddrSpaceCast(ResVReg, ResType, I);
case TargetOpcode::G_PTR_ADD: {
@@ -586,6 +589,16 @@ bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
Opcode);
}
+bool SPIRVInstructionSelector::selectBitcast(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ Register OpReg = I.getOperand(1).getReg();
+ SPIRVType *OpType = OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
+ if (!GR.isBitcastCompatible(ResType, OpType))
+ report_fatal_error("incompatible result and operand types in a bitcast");
+ return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+}
+
static SPIRV::Scope::Scope getScope(SyncScope::ID Ord) {
switch (Ord) {
case SyncScope::SingleThread:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 049ca4ac818c4e..b0f94d0ab2f7d9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -177,12 +177,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
- getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
- typeInSet(0, allPtrsScalarsAndVectors),
- typeInSet(1, allPtrsScalarsAndVectors),
- LegalityPredicate(([=](const LegalityQuery &Query) {
- return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
- }))));
+ getActionDefinitionsBuilder(G_BITCAST).legalIf(
+ all(typeInSet(0, allPtrsScalarsAndVectors),
+ typeInSet(1, allPtrsScalarsAndVectors)));
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
diff --git a/llvm/test/CodeGen/SPIRV/bitcast.ll b/llvm/test/CodeGen/SPIRV/bitcast.ll
new file mode 100644
index 00000000000000..242c5a46583c22
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast.ll
@@ -0,0 +1,21 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#TyInt32:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#TyInt16:]] = OpTypeInt 16 0
+; CHECK-SPIRV-DAG: %[[#TyHalf:]] = OpTypeFloat 16
+; CHECK-SPIRV-DAG: %[[#Arg32:]] = OpFunctionParameter %[[#TyInt32]]
+; CHECK-SPIRV-DAG: %[[#Arg16:]] = OpUConvert %[[#TyInt16]] %[[#Arg32]]
+; CHECK-SPIRV-DAG: %[[#ValHalf:]] = OpBitcast %[[#TyHalf]] %[[#Arg16]]
+; CHECK-SPIRV-DAG: %[[#ValHalf2:]] = OpFMul %[[#TyHalf]] %[[#ValHalf]] %[[#ValHalf]]
+; CHECK-SPIRV-DAG: %[[#Res16:]] = OpBitcast %[[#TyInt16]] %[[#ValHalf2]]
+; CHECK-SPIRV-DAG: OpReturnValue %[[#Res16]]
+
+define i16 @foo(i32 %arg) {
+entry:
+ %op16 = trunc i32 %arg to i16
+ %val = bitcast i16 %op16 to half
+ %val2 = fmul half %val, %val
+ %res = bitcast half %val2 to i16
+ ret i16 %res
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/83139
More information about the llvm-commits
mailing list