[llvm] [SPIR-V] Fix bitcast legalization/instruction selection in SPIR-V Backend (PR #83139)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 27 06:32:47 PST 2024


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/83139

Fix bitcast legalization/instruction selection in SPIR-V Backend.

TODO:
* details of the issue
* a test case

>From fc958ece3e35e37eacb72407bda784cc57d68841 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 27 Feb 2024 06:31:28 -0800
Subject: [PATCH] fix bitcast legalization/instruction selection

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 46 ++++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 17 ++++++-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 15 +++++-
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  9 ++--
 4 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a1cb630f1aa477..64af253e25a1e1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -834,15 +834,49 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
 }
 
-bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
+    const SPIRVType *Type) const {
   assert(Type && "Invalid Type pointer");
+  unsigned NumElements = 1;
   if (Type->getOpcode() == SPIRV::OpTypeVector) {
-    auto EleTypeReg = Type->getOperand(1).getReg();
-    Type = getSPIRVTypeForVReg(EleTypeReg);
+    NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
+    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
   }
-  if (Type->getOpcode() == SPIRV::OpTypeInt)
-    return Type->getOperand(2).getImm() != 0;
-  llvm_unreachable("Attempting to get sign of non-integer type.");
+  return Type->getOpcode() == SPIRV::OpTypeInt ||
+                 Type->getOpcode() == SPIRV::OpTypeFloat
+             ? NumElements * Type->getOperand(1).getImm()
+             : 0;
+}
+
+const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
+    const SPIRVType *Type) const {
+  if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
+    Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
+  return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
+}
+
+bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
+  const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
+  return IntType && IntType->getOperand(2).getImm() != 0;
+}
+
+bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
+                                              const SPIRVType *Type2) const {
+  if (!Type1 || !Type2)
+    return false;
+  auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
+  // Ignore the difference between SPIR-V versions <1.5 and >=1.5: treat
+  // OpBitcast as valid if either Result Type or Operand is a pointer, and the
+  // other is a pointer, an integer scalar, or an integer vector.
+  if (Op1 == SPIRV::OpTypePointer &&
+      (Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
+    return true;
+  if (Op2 == SPIRV::OpTypePointer &&
+      (Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
+    return true;
+  unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
+           Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
+  return Bits1 > 0 && Bits1 == Bits2;
 }
 
 SPIRV::StorageClass::StorageClass
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..151af19fcef2ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,9 +197,19 @@ class SPIRVGlobalRegistry {
   // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
   bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
 
-  // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
+  // For vectors or scalars of booleans, integers and floats, return the scalar
+  // type's bitwidth. Otherwise calls llvm_unreachable().
   unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
 
+  // For vectors or scalars of integers and floats, return total bitwidth of the
+  // argument. Otherwise returns 0.
+  unsigned getNumScalarOrVectorTotalBitWidth(const SPIRVType *Type) const;
+
+  // Returns a pointer to the integer type, which is either the vector element
+  // type or the original type itself, or nullptr if the argument is neither
+  // an integer scalar nor an integer vector.
+  const SPIRVType *retrieveScalarOrVectorIntType(const SPIRVType *Type) const;
+
   // For integer vectors or scalars, return whether the integers are signed.
   bool isScalarOrVectorSigned(const SPIRVType *Type) const;
 
@@ -209,6 +219,11 @@ class SPIRVGlobalRegistry {
   // Return the number of bits SPIR-V pointers and size_t variables require.
   unsigned getPointerSize() const { return PointerSize; }
 
+  // Returns true if both types are defined and are compatible in the sense of
+  // the OpBitcast instruction.
+  bool isBitcastCompatible(const SPIRVType *Type1,
+                           const SPIRVType *Type2) const;
+
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 9b38073ec3bcf7..2958b73ba00ab5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -95,6 +95,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
                   unsigned Opcode) const;
 
+  bool selectBitcast(Register ResVReg, const SPIRVType *ResType,
+                     MachineInstr &I) const;
+
   bool selectLoad(Register ResVReg, const SPIRVType *ResType,
                   MachineInstr &I) const;
   bool selectStore(MachineInstr &I) const;
@@ -449,7 +452,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_INTTOPTR:
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr);
   case TargetOpcode::G_BITCAST:
-    return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+    return selectBitcast(ResVReg, ResType, I);
   case TargetOpcode::G_ADDRSPACE_CAST:
     return selectAddrSpaceCast(ResVReg, ResType, I);
   case TargetOpcode::G_PTR_ADD: {
@@ -586,6 +589,16 @@ bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
                            Opcode);
 }
 
+bool SPIRVInstructionSelector::selectBitcast(Register ResVReg,
+                                             const SPIRVType *ResType,
+                                             MachineInstr &I) const {
+  Register OpReg = I.getOperand(1).getReg();
+  SPIRVType *OpType = OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
+  if (!GR.isBitcastCompatible(ResType, OpType))
+    report_fatal_error("incompatible result and operand types in a bitcast");
+  return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
+}
+
 static SPIRV::Scope::Scope getScope(SyncScope::ID Ord) {
   switch (Ord) {
   case SyncScope::SingleThread:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 049ca4ac818c4e..b0f94d0ab2f7d9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -177,12 +177,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
 
-  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
-      typeInSet(0, allPtrsScalarsAndVectors),
-      typeInSet(1, allPtrsScalarsAndVectors),
-      LegalityPredicate(([=](const LegalityQuery &Query) {
-        return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
-      }))));
+  getActionDefinitionsBuilder(G_BITCAST).legalIf(
+      all(typeInSet(0, allPtrsScalarsAndVectors),
+          typeInSet(1, allPtrsScalarsAndVectors)));
 
   getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
 



More information about the llvm-commits mailing list