[llvm] [SPIR-V] Implement OpSpecConstantOp with ptr-cast operation (PR #109979)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 25 06:35:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR reworks the implementation of OpSpecConstantOp with ptr-cast operations (PtrCastToGeneric, GenericCastToPtr). The previous implementation did not account for many use cases, including multiple uses of the same pointer, a reference to a pointer from OpName, etc. A reproducer is attached as a new test case.

This PR also fixes incorrect type inference for IR patterns that generate new virtual registers without a SPIR-V type. The previous implementation always assumed that the result has the same address space as the source, which is not necessarily true. In the reproducer, for example, this made it impossible to emit a ptr-cast operation: the wrong type inference gave the source and destination the same address space, which eliminated the translation of G_ADDRSPACE_CAST.


---
Full diff: https://github.com/llvm/llvm-project/pull/109979.diff


7 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+76-27) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+14-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+14-13) 
- (added) llvm/test/CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll (+63) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 3e1873e899680c..ceca0a180c95b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1128,6 +1128,11 @@ SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
          Type->getOperand(1).isImm() && "Pointer type is expected");
+  return getPointerStorageClass(Type);
+}
+
+SPIRV::StorageClass::StorageClass
+SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
   return static_cast<SPIRV::StorageClass::StorageClass>(
       Type->getOperand(1).getImm());
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index cad2bf96adf33e..2e32ddd9405d4f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -388,6 +388,8 @@ class SPIRVGlobalRegistry {
 
   // Gets the storage class of the pointer type assigned to this vreg.
   SPIRV::StorageClass::StorageClass getPointerStorageClass(Register VReg) const;
+  SPIRV::StorageClass::StorageClass
+  getPointerStorageClass(const SPIRVType *Type) const;
 
   // Return the number of bits SPIR-V pointers and size_t variables require.
   unsigned getPointerSize() const { return PointerSize; }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e475810f92f717..13bd89c18aa256 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1156,36 +1156,87 @@ static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
 bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
                                                    const SPIRVType *ResType,
                                                    MachineInstr &I) const {
-  // If the AddrSpaceCast user is single and in OpConstantComposite or
-  // OpVariable, we should select OpSpecConstantOp.
-  auto UIs = MRI->use_instructions(ResVReg);
-  if (!UIs.empty() && ++UIs.begin() == UIs.end() &&
-      (UIs.begin()->getOpcode() == SPIRV::OpConstantComposite ||
-       UIs.begin()->getOpcode() == SPIRV::OpVariable ||
-       isSpvIntrinsic(*UIs.begin(), Intrinsic::spv_init_global))) {
-    Register NewReg = I.getOperand(1).getReg();
-    MachineBasicBlock &BB = *I.getParent();
-    SPIRVType *SpvBaseTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
-    ResType = GR.getOrCreateSPIRVPointerType(SpvBaseTy, I, TII,
-                                             SPIRV::StorageClass::Generic);
-    bool Result =
-        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
-            .addDef(ResVReg)
-            .addUse(GR.getSPIRVTypeID(ResType))
-            .addImm(static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric))
-            .addUse(NewReg)
-            .constrainAllUses(TII, TRI, RBI);
-    return Result;
-  }
+  MachineBasicBlock &BB = *I.getParent();
+  const DebugLoc &DL = I.getDebugLoc();
+
   Register SrcPtr = I.getOperand(1).getReg();
   SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr);
-  SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
-  SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
+  // don't generate a cast for a null that is represented by OpTypeInt
+  if (SrcPtrTy->getOpcode() != SPIRV::OpTypePointer ||
+      ResType->getOpcode() != SPIRV::OpTypePointer)
+    return BuildMI(BB, I, DL, TII.get(TargetOpcode::COPY))
+        .addDef(ResVReg)
+        .addUse(SrcPtr)
+        .constrainAllUses(TII, TRI, RBI);
+
+  SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtrTy);
+  SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResType);
+
+  // AddrSpaceCast uses within OpVariable and OpConstantComposite instructions
+  // are expressed by OpSpecConstantOp with an Opcode.
+  bool IsGRef = false;
+  bool IsAllowedRefs =
+      std::all_of(MRI->use_instr_begin(ResVReg), MRI->use_instr_end(),
+                  [&IsGRef](auto const &It) {
+                    unsigned Opcode = It.getOpcode();
+                    if (Opcode == SPIRV::OpConstantComposite ||
+                        Opcode == SPIRV::OpVariable ||
+                        isSpvIntrinsic(It, Intrinsic::spv_init_global))
+                      return IsGRef = true;
+                    return Opcode == SPIRV::OpName;
+                  });
+  if (IsAllowedRefs && IsGRef) {
+    // TODO: insert a check whether the Kernel capability was declared.
+    unsigned SpecOpcode =
+        DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC)
+            ? static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric)
+            : (SrcSC == SPIRV::StorageClass::Generic &&
+                       isGenericCastablePtr(DstSC)
+                   ? static_cast<uint32_t>(SPIRV::Opcode::GenericCastToPtr)
+                   : 0);
+    if (SpecOpcode) {
+      // TODO: OpConstantComposite expects i8*, so we are forced to forget a
+      // correct value of ResType and use general i8* instead. Maybe this should
+      // be addressed in the emit-intrinsic step to infer a correct
+      // OpConstantComposite type.
+      SPIRVType *NewResType = GR.getOrCreateSPIRVPointerType(
+          GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, DstSC);
+      bool Result = BuildMI(BB, I, DL, TII.get(SPIRV::OpSpecConstantOp))
+                        .addDef(ResVReg)
+                        .addUse(GR.getSPIRVTypeID(NewResType))
+                        .addImm(SpecOpcode)
+                        .addUse(SrcPtr)
+                        .constrainAllUses(TII, TRI, RBI);
+      return Result;
+    } else if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
+      SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
+          GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
+      Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
+      MRI->setType(Tmp, LLT::pointer(0, 64));
+      GR.assignSPIRVTypeToVReg(GenericPtrTy, Tmp, *BB.getParent());
+      MachineInstrBuilder MIB =
+          BuildMI(BB, I, DL, TII.get(SPIRV::OpSpecConstantOp))
+              .addDef(Tmp)
+              .addUse(GR.getSPIRVTypeID(GenericPtrTy))
+              .addImm(static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric))
+              .addUse(SrcPtr);
+      GR.add(MIB.getInstr(), BB.getParent(), Tmp);
+      bool Result = MIB.constrainAllUses(TII, TRI, RBI);
+      SPIRVType *NewResType = GR.getOrCreateSPIRVPointerType(
+          GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, DstSC);
+      return Result &&
+             BuildMI(BB, I, DL, TII.get(SPIRV::OpSpecConstantOp))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(NewResType))
+                 .addImm(static_cast<uint32_t>(SPIRV::Opcode::GenericCastToPtr))
+                 .addUse(Tmp)
+                 .constrainAllUses(TII, TRI, RBI);
+    }
+  }
 
   // don't generate a cast between identical storage classes
   if (SrcSC == DstSC)
-    return BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                   TII.get(TargetOpcode::COPY))
+    return BuildMI(BB, I, DL, TII.get(TargetOpcode::COPY))
         .addDef(ResVReg)
         .addUse(SrcPtr)
         .constrainAllUses(TII, TRI, RBI);
@@ -1201,8 +1252,6 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
     Register Tmp = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
         GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
-    MachineBasicBlock &BB = *I.getParent();
-    const DebugLoc &DL = I.getDebugLoc();
     bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
                        .addDef(Tmp)
                        .addUse(GR.getSPIRVTypeID(GenericPtrTy))
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index cd0aff1a518439..e88f2f5ec5418d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -294,8 +294,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
       default:
         break;
       }
-      if (SpvType)
+      if (SpvType) {
+        // check if the address space needs correction
+        LLT RegType = MRI.getType(Reg);
+        if (SpvType->getOpcode() == SPIRV::OpTypePointer &&
+            RegType.isPointer() &&
+            storageClassToAddressSpace(GR->getPointerStorageClass(SpvType)) !=
+                RegType.getAddressSpace()) {
+          const SPIRVSubtarget &ST =
+              MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
+          SpvType = GR->getOrCreateSPIRVPointerType(
+              GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
+              addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
+        }
         GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
+      }
       if (!MRI.getRegClassOrNull(Reg))
         MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
                                      : &SPIRV::iIDRegClass);
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 23cd32eff45d5b..a74b2cc8615eff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -1631,6 +1631,7 @@ multiclass OpcodeOperand<bits<32> value> {
 defm InBoundsAccessChain : OpcodeOperand<66>;
 defm InBoundsPtrAccessChain : OpcodeOperand<70>;
 defm PtrCastToGeneric : OpcodeOperand<121>;
+defm GenericCastToPtr : OpcodeOperand<122>;
 defm Bitcast : OpcodeOperand<124>;
 defm ConvertPtrToU : OpcodeOperand<117>;
 defm ConvertUToPtr : OpcodeOperand<120>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index aec144f6f05861..a8016d42b0154f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -254,26 +254,27 @@ SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
 }
 
 SPIRV::Scope::Scope getMemScope(LLVMContext &Ctx, SyncScope::ID Id) {
-  static const struct {
-    // Named by
-    // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
-    // We don't need aliases for Invocation and CrossDevice, as we already have
-    // them covered by "singlethread" and "" strings respectively (see
-    // implementation of LLVMContext::LLVMContext()).
-    llvm::SyncScope::ID SubGroup = Ctx.getOrInsertSyncScopeID("subgroup");
-    llvm::SyncScope::ID WorkGroup = Ctx.getOrInsertSyncScopeID("workgroup");
-    llvm::SyncScope::ID Device = Ctx.getOrInsertSyncScopeID("device");
-  } SSIDs{};
+  // Named by
+  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id.
+  // We don't need aliases for Invocation and CrossDevice, as we already have
+  // them covered by "singlethread" and "" strings respectively (see
+  // implementation of LLVMContext::LLVMContext()).
+  static const llvm::SyncScope::ID SubGroup =
+      Ctx.getOrInsertSyncScopeID("subgroup");
+  static const llvm::SyncScope::ID WorkGroup =
+      Ctx.getOrInsertSyncScopeID("workgroup");
+  static const llvm::SyncScope::ID Device =
+      Ctx.getOrInsertSyncScopeID("device");
 
   if (Id == llvm::SyncScope::SingleThread)
     return SPIRV::Scope::Invocation;
   else if (Id == llvm::SyncScope::System)
     return SPIRV::Scope::CrossDevice;
-  else if (Id == SSIDs.SubGroup)
+  else if (Id == SubGroup)
     return SPIRV::Scope::Subgroup;
-  else if (Id == SSIDs.WorkGroup)
+  else if (Id == WorkGroup)
     return SPIRV::Scope::Workgroup;
-  else if (Id == SSIDs.Device)
+  else if (Id == Device)
     return SPIRV::Scope::Device;
   return SPIRV::Scope::CrossDevice;
 }
diff --git a/llvm/test/CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll b/llvm/test/CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll
new file mode 100644
index 00000000000000..cd1a1b0080c621
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/PtrCast-in-OpSpecConstantOp.ll
@@ -0,0 +1,63 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[F:.*]] "F"
+; CHECK-DAG: OpName %[[B:.*]] "B"
+; CHECK-DAG: OpName %[[G1:.*]] "G1"
+; CHECK-DAG: OpName %[[G2:.*]] "G2"
+; CHECK-DAG: OpName %[[X:.*]] "X"
+; CHECK-DAG: OpName %[[Y:.*]] "Y"
+; CHECK-DAG: OpName %[[G3:.*]] "G3"
+; CHECK-DAG: OpName %[[G4:.*]] "G4"
+
+; CHECK-DAG: %[[Int:.*]] = OpTypeInt 32 0
+; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0
+; CHECK-DAG: %[[GenPtrChar:.*]] = OpTypePointer Generic %[[Char]]
+; CHECK-DAG: %[[CWPtrChar:.*]] = OpTypePointer CrossWorkgroup %[[Char]]
+; CHECK-DAG: %[[Arr1:.*]] = OpTypeArray %[[CWPtrChar]] %[[#]]
+; CHECK-DAG: %[[Struct1:.*]] = OpTypeStruct %8
+; CHECK-DAG: %[[Arr2:.*]] = OpTypeArray %[[GenPtrChar]] %[[#]]
+; CHECK-DAG: %[[Struct2:.*]] = OpTypeStruct %[[Arr2]]
+; CHECK-DAG: %[[GenPtr:.*]] = OpTypePointer Generic %[[Int]]
+; CHECK-DAG: %[[CWPtr:.*]] = OpTypePointer CrossWorkgroup %[[Int]]
+; CHECK-DAG: %[[WPtr:.*]] = OpTypePointer Workgroup %[[Int]]
+
+; CHECK-DAG: %[[F]] = OpVariable %[[CWPtr]] CrossWorkgroup %[[#]]
+; CHECK-DAG: %[[GenF:.*]] = OpSpecConstantOp %[[GenPtrChar]] 121 %[[F]]
+; CHECK-DAG: %[[B]] = OpVariable %[[CWPtr]] CrossWorkgroup %[[#]]
+; CHECK-DAG: %[[GenB:.*]] = OpSpecConstantOp %[[GenPtrChar]] 121 %[[B]]
+; CHECK-DAG: %[[GenFB:.*]] = OpConstantComposite %[[Arr2]] %[[GenF]] %[[GenB]]
+; CHECK-DAG: %[[GenBF:.*]] = OpConstantComposite %[[Arr2]] %[[GenB]] %[[GenF]]
+; CHECK-DAG: %[[CG1:.*]] = OpConstantComposite %[[Struct2]] %[[GenFB]]
+; CHECK-DAG: %[[CG2:.*]] = OpConstantComposite %[[Struct2]] %[[GenBF]]
+
+; CHECK-DAG: %[[X]] = OpVariable %[[WPtr]] Workgroup %[[#]]
+; CHECK-DAG: %[[GenX:.*]] = OpSpecConstantOp %[[GenPtr]] 121 %[[X]]
+; CHECK-DAG: %[[CWX:.*]] = OpSpecConstantOp %[[CWPtrChar]] 122 %[[GenX]]
+; CHECK-DAG: %[[Y]] = OpVariable %[[WPtr]] Workgroup %[[#]]
+; CHECK-DAG: %[[GenY:.*]] = OpSpecConstantOp %[[GenPtr]] 121 %[[Y]]
+; CHECK-DAG: %[[CWY:.*]] = OpSpecConstantOp %[[CWPtrChar]] 122 %[[GenY]]
+; CHECK-DAG: %[[CWXY:.*]] = OpConstantComposite %[[Arr1]] %[[CWX]] %[[CWY]]
+; CHECK-DAG: %[[CWYX:.*]] = OpConstantComposite %[[Arr1]] %[[CWY]] %[[CWX]]
+; CHECK-DAG: %[[CG3:.*]] = OpConstantComposite %[[Struct1]] %[[CWXY]]
+; CHECK-DAG: %[[CG4:.*]] = OpConstantComposite %[[Struct1]] %[[CWYX]]
+
+; CHECK-DAG: %[[G4]] = OpVariable %[[#]] CrossWorkgroup %[[CG4]]
+; CHECK-DAG: %[[G3]] = OpVariable %[[#]] CrossWorkgroup %[[CG3]]
+; CHECK-DAG: %[[G2]] = OpVariable %[[#]] CrossWorkgroup %[[CG2]]
+; CHECK-DAG: %[[G1]] = OpVariable %[[#]] CrossWorkgroup %[[CG1]]
+
+@F = addrspace(1) constant i32 0
+@B = addrspace(1) constant i32 1
+@G1 = addrspace(1) constant { [2 x ptr addrspace(4)] } { [2 x ptr addrspace(4)] [ptr addrspace(4) addrspacecast (ptr addrspace(1) @F to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @B to ptr addrspace(4))] }
+@G2 = addrspace(1) constant { [2 x ptr addrspace(4)] } { [2 x ptr addrspace(4)] [ptr addrspace(4) addrspacecast (ptr addrspace(1) @B to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @F to ptr addrspace(4))] }
+
+@X = addrspace(3) constant i32 0
+@Y = addrspace(3) constant i32 1
+@G3 = addrspace(1) constant { [2 x ptr addrspace(1)] } { [2 x ptr addrspace(1)] [ptr addrspace(1) addrspacecast (ptr addrspace(3) @X to ptr addrspace(1)), ptr addrspace(1) addrspacecast (ptr addrspace(3) @Y to ptr addrspace(1))] }
+@G4 = addrspace(1) constant { [2 x ptr addrspace(1)] } { [2 x ptr addrspace(1)] [ptr addrspace(1) addrspacecast (ptr addrspace(3) @Y to ptr addrspace(1)), ptr addrspace(1) addrspacecast (ptr addrspace(3) @X to ptr addrspace(1))] }
+
+define void @foo() {
+entry:
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/109979


More information about the llvm-commits mailing list