[llvm] ecc3bda - [SPIR-V] Fix bitcast legalization/instruction selection in SPIR-V Backend (#83139)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 03:15:34 PST 2024


Author: Vyacheslav Levytskyy
Date: 2024-03-04T12:15:30+01:00
New Revision: ecc3bdaae14a02acc879c018e21d58a83329dc6e

URL: https://github.com/llvm/llvm-project/commit/ecc3bdaae14a02acc879c018e21d58a83329dc6e
DIFF: https://github.com/llvm/llvm-project/commit/ecc3bdaae14a02acc879c018e21d58a83329dc6e.diff

LOG: [SPIR-V] Fix bitcast legalization/instruction selection in SPIR-V Backend (#83139)

This PR fixes how the SPIR-V Backend describes the legality of the
OpBitcast instruction and how it is validated during instruction
selection. Instead of comparing the sizes of virtual registers (which
makes no sense, because there is no guaranteed direct relation between
the size of a virtual register and the bit width of its associated type),
this PR legalizes OpBitcast without a size check and postpones validation
to the instruction selection step, where the SPIR-V types of the result
and the operand are compared.

As an example, consider the following snippet, copied as-is from a
larger test suite:

```
  %355:id(s16) = G_BITCAST %301:id(s32)
  %303:id(s16) = ASSIGN_TYPE %355:id(s16), %349:type(s32)
  %644:fid(s32) = G_FMUL %645:fid, %646:fid
  %301:id(s32) = ASSIGN_TYPE %644:fid(s32), %40:type(s32)
```

Without this PR, this leads to a crash with a complaint about an illegal
bitcast, because %355 is s16 and %301 is s32. However, in this case we
must check not the virtual registers but the types of %355 and %301,
i.e., %349:type(s32) and %40:type(s32), which are perfectly compatible
in the sense of OpBitcast.

In the test case added by this PR, OpBitcast is legal, being applied to
`OpTypeInt 16` and `OpTypeFloat 16`, but it would not be legalized
without this PR, because the virtual registers involved are defined as
having sizes 16 and 32.

Added: 
    llvm/test/CodeGen/SPIRV/bitcast.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
    llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
    llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a797eca1e32ea..cc79aeda0cb85a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -834,15 +834,49 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
 }
 
-bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
+    const SPIRVType *Type) const {
   assert(Type && "Invalid Type pointer");
+  unsigned NumElements = 1;
   if (Type->getOpcode() == SPIRV::OpTypeVector) {
-    auto EleTypeReg = Type->getOperand(1).getReg();
-    Type = getSPIRVTypeForVReg(EleTypeReg);
+    NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
+    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
   }
-  if (Type->getOpcode() == SPIRV::OpTypeInt)
-    return Type->getOperand(2).getImm() != 0;
-  llvm_unreachable("Attempting to get sign of non-integer type.");
+  return Type->getOpcode() == SPIRV::OpTypeInt ||
+                 Type->getOpcode() == SPIRV::OpTypeFloat
+             ? NumElements * Type->getOperand(1).getImm()
+             : 0;
+}
+
+const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
+    const SPIRVType *Type) const {
+  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
+    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
+  return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
+}
+
+bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+  const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
+  return IntType && IntType->getOperand(2).getImm() != 0;
+}
+
+bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
+                                              const SPIRVType *Type2) const {
+  if (!Type1 || !Type2)
+    return false;
+  auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
+  // Ignore difference between <1.5 and >=1.5 protocol versions:
+  // it's valid if either Result Type or Operand is a pointer, and the other
+  // is a pointer, an integer scalar, or an integer vector.
+  if (Op1 == SPIRV::OpTypePointer &&
+      (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
+    return true;
+  if (Op2 == SPIRV::OpTypePointer &&
+      (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
+    return true;
+  unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
+           Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
+  return Bits1 > 0 && Bits1 == Bits2;
 }
 
 SPIRV::StorageClass::StorageClass

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..151af19fcef2ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,9 +197,19 @@ class SPIRVGlobalRegistry {
   // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
   bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
 
-  // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
+  // For vectors or scalars of booleans, integers and floats, return the scalar
+  // type's bitwidth. Otherwise calls llvm_unreachable().
   unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
 
+  // For vectors or scalars of integers and floats, return total bitwidth of the
+  // argument. Otherwise returns 0.
+  unsigned getNumScalarOrVectorTotalBitWidth(const SPIRVType *Type) const;
+
+  // Returns either pointer to integer type, that may be a type of vector
+  // elements or an original type, or nullptr if the argument is niether
+  // an integer scalar, nor an integer vector
+  const SPIRVType *retrieveScalarOrVectorIntType(const SPIRVType *Type) const;
+
   // For integer vectors or scalars, return whether the integers are signed.
   bool isScalarOrVectorSigned(const SPIRVType *Type) const;
 
@@ -209,6 +219,11 @@ class SPIRVGlobalRegistry {
   // Return the number of bits SPIR-V pointers and size_t variables require.
   unsigned getPointerSize() const { return PointerSize; }
 
+  // Returns true if two types are defined and are compatible in a sense of
+  // OpBitcast instruction
+  bool isBitcastCompatible(const SPIRVType *Type1,
+                           const SPIRVType *Type2) const;
+
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e050f806f29163..3e01a6ac71f63c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -95,6 +95,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
                   unsigned Opcode) const;
 
+  bool selectBitcast(Register ResVReg, const SPIRVType *ResType,
+                     MachineInstr &I) const;
+
   bool selectLoad(Register ResVReg, const SPIRVType *ResType,
                   MachineInstr &I) const;
   bool selectStore(MachineInstr &I) const;
@@ -452,7 +455,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_INTTOPTR:
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr);
   case TargetOpcode::G_BITCAST:
-    return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+    return selectBitcast(ResVReg, ResType, I);
   case TargetOpcode::G_ADDRSPACE_CAST:
     return selectAddrSpaceCast(ResVReg, ResType, I);
   case TargetOpcode::G_PTR_ADD: {
@@ -592,6 +595,16 @@ bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
                            Opcode);
 }
 
+bool SPIRVInstructionSelector::selectBitcast(Register ResVReg,
+                                             const SPIRVType *ResType,
+                                             MachineInstr &I) const {
+  Register OpReg = I.getOperand(1).getReg();
+  SPIRVType *OpType = OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
+  if (!GR.isBitcastCompatible(ResType, OpType))
+    report_fatal_error("incompatible result and operand types in a bitcast");
+  return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+}
+
 static SPIRV::Scope::Scope getScope(SyncScope::ID Ord) {
   switch (Ord) {
   case SyncScope::SingleThread:

diff  --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index d4c15b2ba73d9f..f81548742a11e2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -200,12 +200,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
 
-  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
-      typeInSet(0, allPtrsScalarsAndVectors),
-      typeInSet(1, allPtrsScalarsAndVectors),
-      LegalityPredicate(([=](const LegalityQuery &Query) {
-        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
-      }))));
+  getActionDefinitionsBuilder(G_BITCAST).legalIf(
+      all(typeInSet(0, allPtrsScalarsAndVectors),
+          typeInSet(1, allPtrsScalarsAndVectors)));
 
   getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
 

diff  --git a/llvm/test/CodeGen/SPIRV/bitcast.ll b/llvm/test/CodeGen/SPIRV/bitcast.ll
new file mode 100644
index 00000000000000..242c5a46583c22
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast.ll
@@ -0,0 +1,21 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#TyInt32:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#TyInt16:]] = OpTypeInt 16 0
+; CHECK-SPIRV-DAG: %[[#TyHalf:]] = OpTypeFloat 16
+; CHECK-SPIRV-DAG: %[[#Arg32:]] = OpFunctionParameter %[[#TyInt32]]
+; CHECK-SPIRV-DAG: %[[#Arg16:]] = OpUConvert %[[#TyInt16]] %[[#Arg32]]
+; CHECK-SPIRV-DAG: %[[#ValHalf:]] = OpBitcast %[[#TyHalf]] %8
+; CHECK-SPIRV-DAG: %[[#ValHalf2:]] = OpFMul %[[#TyHalf]] %[[#ValHalf]] %[[#ValHalf]]
+; CHECK-SPIRV-DAG: %[[#Res16:]] = OpBitcast %[[#TyInt16]] %[[#ValHalf2]]
+; CHECK-SPIRV-DAG: OpReturnValue %[[#Res16]]
+
+define i16 @foo(i32 %arg) {
+entry:
+  %op16 = trunc i32 %arg to i16
+  %val = bitcast i16 %op16 to half
+  %val2 = fmul half %val, %val
+  %res = bitcast half %val2 to i16
+  ret i16 %res
+}


        


More information about the llvm-commits mailing list