[llvm] Add support for the SPV_INTEL_usm_storage_classes extension (PR #82247)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 19 05:58:46 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

Add support for the SPV_INTEL_usm_storage_classes extension:
* https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_usm_storage_classes.asciidoc

---

Patch is 23.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82247.diff


12 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+9-8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+4-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+30-6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+9-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+9-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp (+6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+16-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+2-1) 
- (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll (+84) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index baeed2ad895a4b..abce2ea1c77820 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
 
 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                   SPIRVGlobalRegistry *GR,
-                                  MachineIRBuilder &MIRBuilder) {
+                                  MachineIRBuilder &MIRBuilder,
+                                  const SPIRVSubtarget &ST) {
   // Read argument's access qualifier from metadata or default.
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
     if (MDTypeStr.ends_with("*"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           MDTypeStr, MIRBuilder,
-          addressSpaceToStorageClass(
-              OriginalArgType->getPointerAddressSpace()));
+          addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
+                                     ST));
     else if (MDTypeStr.ends_with("_t"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           "opencl." + MDTypeStr.str(), MIRBuilder,
@@ -220,6 +221,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
   GR->setCurrentFunc(MIRBuilder.getMF());
 
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+
   // Assign types and names to all args, and store their types for later.
   FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -230,7 +235,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
       ArgTypeVRegs.push_back(SpirvTy);
 
@@ -332,10 +337,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (F.hasName())
     buildOpName(FuncVReg, F.getName(), MIRBuilder);
 
-  // Get access to information about available extensions
-  const auto *ST =
-      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-
   // Handle entry points and function linkage.
   if (isEntryPoint(F)) {
     const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec745c3f18a..a1cb630f1aa477 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     // TODO: change the implementation once opaque pointers are supported
     // in the SPIR-V specification.
     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
-    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
+    // Get access to information about available extensions
+    const SPIRVSubtarget *ST =
+        static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+    auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
     // Null pointer means we have a loop in type definitions, make and
     // return corresponding OpTypeForwardPointer.
     if (SpvElementType == nullptr) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 7965dd969e5cfa..b16e73466ebe59 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
                               "$r = OpGenericCastToPtrExplicit $t $p $s">;
 def OpBitcast : UnOp<"OpBitcast", 124>;
 
+// SPV_INTEL_usm_storage_classes
+def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
+def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 53d19a1e31382d..7258d3b4d88ed3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
   }
 }
 
+static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return true;
+  default:
+    return false;
+  }
+}
+
 // In SPIR-V address space casting can only happen to and from the Generic
-// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
 // pointers to and from Generic pointers. As such, we can convert e.g. from
 // Workgroup to Function by going via a Generic pointer as an intermediary. All
 // other combinations can only be done by a bitcast, and are probably not safe.
@@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
   SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
   SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
 
-  // Casting from an eligable pointer to Generic.
+  // Don't generate a cast between identical storage classes.
+  if (SrcSC == DstSC)
+    return true;
+
+  // Casting from an eligible pointer to Generic.
   if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
-  // Casting from Generic to an eligable pointer.
+  // Casting from Generic to an eligible pointer.
   if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
-  // Casting between 2 eligable pointers using Generic as an intermediary.
+  // Casting between 2 eligible pointers using Generic as an intermediary.
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
     Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
@@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
                           .addUse(Tmp)
                           .constrainAllUses(TII, TRI, RBI);
   }
+
+  // Check if instructions from the SPV_INTEL_usm_storage_classes extension may
+  // be applied
+  if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpPtrCastToCrossWorkgroupINTEL);
+  if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpCrossWorkgroupCastToPtrINTEL);
+
   // TODO Should this case just be disallowed completely?
   // We're casting 2 other arbitrary address spaces, so have to bitcast.
   return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
@@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   }
   SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
       PointerBaseType, I, TII,
-      addressSpaceToStorageClass(GV->getAddressSpace()));
+      addressSpaceToStorageClass(GV->getAddressSpace(), STI));
 
   std::string GlobalIdent;
   if (!GV->hasName()) {
@@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
 
   unsigned AddrSpace = GV->getAddressSpace();
   SPIRV::StorageClass::StorageClass Storage =
-      addressSpaceToStorageClass(AddrSpace);
+      addressSpaceToStorageClass(AddrSpace, STI);
   bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
                   Storage != SPIRV::StorageClass::Function;
   SPIRV::LinkageType::LinkageType LnkType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 011a550a7b3d9b..4f2e7a240fc2cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -102,14 +102,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
   const LLT p4 = LLT::pointer(4, PSize); // Generic
-  const LLT p5 = LLT::pointer(5, PSize); // Input
+  const LLT p5 =
+      LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
+  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
-      p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
-      s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
-      v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
-      v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+      p0,    p1,    p2,    p3,    p4,     p5,     p6,    s1,   s8,   s16,
+      s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64, v3s1, v3s8, v3s16,
+      v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64, v8s1, v8s8, v8s16,
+      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
@@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allFloatAndIntScalars = allIntScalars;
 
-  auto allPtrs = {p0, p1, p2, p3, p4, p5};
-  auto allWritablePtrs = {p0, p1, p3, p4};
+  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+  auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
 
   for (auto Opc : TypeFoldingSupportingOpcs)
     getActionDefinitionsBuilder(Opc).custom();
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b3244cfbec7014..83c407a2a89f31 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1063,6 +1063,13 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
     }
     break;
+  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
+  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
+      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
+    }
+    break;
   case SPIRV::OpConstantFunctionPointerINTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index cbc16fa986614e..144216896eb68c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
 
 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                            MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
   SmallVector<MachineInstr *, 10> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
@@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-          addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
 
       // If the bitcast would be redundant, replace all uses with the source
       // register.
@@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
+
   MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<MachineInstr *, 10> ToErase;
 
@@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-            addressSpaceToStorageClass(MI.getOperand(3).getImm()));
+            addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");
         insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index e0493321041e38..f62c350c2eb7df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
         clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
                    "Adds OptNoneINTEL value for Function Control mask that "
                    "indicates a request to not optimize the function."),
+        clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
+                   "SPV_INTEL_usm_storage_classes",
+                   "Introduces two new storage classes that are sub classes of "
+                   "the CrossWorkgroup storage class "
+                   "that provides additional information that can enable "
+                   "optimization."),
         clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
                    "Allows work items in a subgroup to share data without the "
                    "use of local memory and work group barriers, and to "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d133a2575284c0..4da3afd418a4f8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -461,6 +461,7 @@ defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_
 defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
+defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -698,6 +699,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
 defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
 defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
 defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
+defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
+defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Dim enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 89daa19e666f63..f998b3490ca1c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -14,6 +14,7 @@
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
 #include "SPIRVInstrInfo.h"
+#include "SPIRVSubtarget.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
     return 3;
   case SPIRV::StorageClass::Generic:
     return 4;
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+    return 5;
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return 6;
   case SPIRV::StorageClass::Input:
     return 7;
   default:
-    llvm_unreachable("Unable to get address space id");
+    report_fatal_error("Unable to get address space id");
   }
 }
 
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace) {
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
   switch (AddrSpace) {
   case 0:
     return SPIRV::StorageClass::Function;
@@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
     return SPIRV::StorageClass::Workgroup;
   case 4:
     return SPIRV::StorageClass::Generic;
+  case 5:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::DeviceOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
+  case 6:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::HostOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
   case 7:
     return SPIRV::StorageClass::Input;
   default:
-    llvm_unreachable("Unknown address space");
+    report_fatal_error("Unknown address space");
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 60742e2f272808..42d2d313615770 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -27,6 +27,7 @@ class MachineRegisterInfo;
 class Register;
 class StringRef;
 class SPIRVInstrInfo;
+class SPIRVSubtarget;
 
 // Add the given string as a series of integer operand, inserting null
 // terminators and padding to make sure the operands all have 32-bit
@@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
 
 // Convert an LLVM IR address space to a SPIR-V storage class.
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace);
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);
 
 SPIRV::MemorySemantics::MemorySemantics
 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
new file mode 100644
index 00000000000000..30c16350bf2b1f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
@@ -0,0 +1,84 @@
+; Modified from: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/test/extensions/INTEL/SPV_INTEL_usm_storage_classes/intel_usm_addrspaces.ll
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-EXT
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82247


More information about the llvm-commits mailing list