[llvm] Add support for the SPV_INTEL_usm_storage_classes extension (PR #82247)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 01:39:28 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/82247

>From 32adae1c35efbae68282afa1c4fc20e1b769460d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 19 Feb 2024 05:33:10 -0800
Subject: [PATCH 1/2] add support for the SPV_INTEL_usm_storage_classes
 extension

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 17 ++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  5 +-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  4 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 36 ++++++--
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  | 10 ++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  7 ++
 llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp   | 11 ++-
 llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp      |  6 ++
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |  3 +
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          | 19 ++++-
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  3 +-
 .../intel-usm-addrspaces.ll                   | 84 +++++++++++++++++++
 12 files changed, 180 insertions(+), 25 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cc438b2bb8d4d7..10569ef0468bda 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
 
 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                   SPIRVGlobalRegistry *GR,
-                                  MachineIRBuilder &MIRBuilder) {
+                                  MachineIRBuilder &MIRBuilder,
+                                  const SPIRVSubtarget &ST) {
   // Read argument's access qualifier from metadata or default.
   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
       getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
     if (MDTypeStr.ends_with("*"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           MDTypeStr, MIRBuilder,
-          addressSpaceToStorageClass(
-              OriginalArgType->getPointerAddressSpace()));
+          addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
+                                     ST));
     else if (MDTypeStr.ends_with("_t"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           "opencl." + MDTypeStr.str(), MIRBuilder,
@@ -206,6 +207,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
   GR->setCurrentFunc(MIRBuilder.getMF());
 
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+
   // Assign types and names to all args, and store their types for later.
   FunctionType *FTy = getOriginalFunctionType(F);
   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -216,7 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
       ArgTypeVRegs.push_back(SpirvTy);
 
@@ -318,10 +323,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
   if (F.hasName())
     buildOpName(FuncVReg, F.getName(), MIRBuilder);
 
-  // Get access to information about available extensions
-  const auto *ST =
-      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-
   // Handle entry points and function linkage.
   if (isEntryPoint(F)) {
     const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec745c3f18a..a1cb630f1aa477 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     // TODO: change the implementation once opaque pointers are supported
     // in the SPIR-V specification.
     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
-    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
+    // Get access to information about available extensions
+    const SPIRVSubtarget *ST =
+        static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+    auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
     // Null pointer means we have a loop in type definitions, make and
     // return corresponding OpTypeForwardPointer.
     if (SpvElementType == nullptr) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 86f65b6320d530..7c5252e8cb372b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
                               "$r = OpGenericCastToPtrExplicit $t $p $s">;
 def OpBitcast : UnOp<"OpBitcast", 124>;
 
+// SPV_INTEL_usm_storage_classes
+def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
+def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 53d19a1e31382d..7258d3b4d88ed3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
   }
 }
 
+static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
+  switch (SC) {
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return true;
+  default:
+    return false;
+  }
+}
+
 // In SPIR-V address space casting can only happen to and from the Generic
-// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
 // pointers to and from Generic pointers. As such, we can convert e.g. from
 // Workgroup to Function by going via a Generic pointer as an intermediary. All
 // other combinations can only be done by a bitcast, and are probably not safe.
@@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
   SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
   SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
 
-  // Casting from an eligable pointer to Generic.
+  // Don't generate a cast between identical storage classes.
+  if (SrcSC == DstSC)
+    return true;
+
+  // Casting from an eligible pointer to Generic.
   if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
-  // Casting from Generic to an eligable pointer.
+  // Casting from Generic to an eligible pointer.
   if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
     return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
-  // Casting between 2 eligable pointers using Generic as an intermediary.
+  // Casting between 2 eligible pointers using Generic as an intermediary.
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
     Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
@@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
                           .addUse(Tmp)
                           .constrainAllUses(TII, TRI, RBI);
   }
+
+  // Check if instructions from the SPV_INTEL_usm_storage_classes extension may
+  // be applied
+  if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpPtrCastToCrossWorkgroupINTEL);
+  if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
+    return selectUnOp(ResVReg, ResType, I,
+                      SPIRV::OpCrossWorkgroupCastToPtrINTEL);
+
   // TODO Should this case just be disallowed completely?
   // We're casting 2 other arbitrary address spaces, so have to bitcast.
   return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
@@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   }
   SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
       PointerBaseType, I, TII,
-      addressSpaceToStorageClass(GV->getAddressSpace()));
+      addressSpaceToStorageClass(GV->getAddressSpace(), STI));
 
   std::string GlobalIdent;
   if (!GV->hasName()) {
@@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
 
   unsigned AddrSpace = GV->getAddressSpace();
   SPIRV::StorageClass::StorageClass Storage =
-      addressSpaceToStorageClass(AddrSpace);
+      addressSpaceToStorageClass(AddrSpace, STI);
   bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
                   Storage != SPIRV::StorageClass::Function;
   SPIRV::LinkageType::LinkageType LnkType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 011a550a7b3d9b..2ed538ab93f698 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -102,11 +102,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
   const LLT p3 = LLT::pointer(3, PSize); // Workgroup
   const LLT p4 = LLT::pointer(4, PSize); // Generic
-  const LLT p5 = LLT::pointer(5, PSize); // Input
+  const LLT p5 =
+      LLT::pointer(5, PSize); // SPV_INTEL_usm_storage_classes (Device)
+  const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
-      p0,    p1,    p2,    p3,    p4,    p5,    s1,     s8,     s16,
+      p0,    p1,    p2,    p3,    p4,    p5,    p6,     s1,     s8,     s16,
       s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
       v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
       v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
@@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   auto allFloatAndIntScalars = allIntScalars;
 
-  auto allPtrs = {p0, p1, p2, p3, p4, p5};
-  auto allWritablePtrs = {p0, p1, p3, p4};
+  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+  auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
 
   for (auto Opc : TypeFoldingSupportingOpcs)
     getActionDefinitionsBuilder(Opc).custom();
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 9b9575b9879948..3be28c97d95381 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1063,6 +1063,13 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
     }
     break;
+  case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
+  case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
+      Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
+    }
+    break;
   case SPIRV::OpConstantFunctionPointerINTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index cbc16fa986614e..144216896eb68c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
 
 static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                            MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
   SmallVector<MachineInstr *, 10> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
@@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
       SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
           BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-          addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+          addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
 
       // If the bitcast would be redundant, replace all uses with the source
       // register.
@@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
 
 static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder MIB) {
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
+
   MachineRegisterInfo &MRI = MF.getRegInfo();
   SmallVector<MachineInstr *, 10> ToErase;
 
@@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
             getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
         SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
             BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
-            addressSpaceToStorageClass(MI.getOperand(3).getImm()));
+            addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
         MachineInstr *Def = MRI.getVRegDef(Reg);
         assert(Def && "Expecting an instruction that defines the register");
         insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 4694363614ef60..79f16146ccd944 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
         clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
                    "Adds OptNoneINTEL value for Function Control mask that "
                    "indicates a request to not optimize the function."),
+        clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
+                   "SPV_INTEL_usm_storage_classes",
+                   "Introduces two new storage classes that are sub classes of "
+                   "the CrossWorkgroup storage class "
+                   "that provides additional information that can enable "
+                   "optimization."),
         clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
                    "Allows work items in a subgroup to share data without the "
                    "use of local memory and work group barriers, and to "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 6c36087baa85ed..b022b97408d7d4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -463,6 +463,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
 defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
+defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
@@ -700,6 +701,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
 defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
 defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
 defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
+defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
+defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Dim enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 05f766d3ec5483..7d6e1497ee57b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
@@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
     return 3;
   case SPIRV::StorageClass::Generic:
     return 4;
+  case SPIRV::StorageClass::DeviceOnlyINTEL:
+    return 5;
+  case SPIRV::StorageClass::HostOnlyINTEL:
+    return 6;
   case SPIRV::StorageClass::Input:
     return 7;
   default:
-    llvm_unreachable("Unable to get address space id");
+    report_fatal_error("Unable to get address space id");
   }
 }
 
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace) {
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
   switch (AddrSpace) {
   case 0:
     return SPIRV::StorageClass::Function;
@@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
     return SPIRV::StorageClass::Workgroup;
   case 4:
     return SPIRV::StorageClass::Generic;
+  case 5:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::DeviceOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
+  case 6:
+    return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+               ? SPIRV::StorageClass::HostOnlyINTEL
+               : SPIRV::StorageClass::CrossWorkgroup;
   case 7:
     return SPIRV::StorageClass::Input;
   default:
-    llvm_unreachable("Unknown address space");
+    report_fatal_error("Unknown address space");
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index a33dc02f854f58..1af53dcd0c4cd1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -27,6 +27,7 @@ class MachineRegisterInfo;
 class Register;
 class StringRef;
 class SPIRVInstrInfo;
+class SPIRVSubtarget;
 
 // Add the given string as a series of integer operand, inserting null
 // terminators and padding to make sure the operands all have 32-bit
@@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
 
 // Convert an LLVM IR address space to a SPIR-V storage class.
 SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace);
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);
 
 SPIRV::MemorySemantics::MemorySemantics
 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
new file mode 100644
index 00000000000000..30c16350bf2b1f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_usm_storage_classes/intel-usm-addrspaces.ll
@@ -0,0 +1,84 @@
+; Modified from: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/test/extensions/INTEL/SPV_INTEL_usm_storage_classes/intel_usm_addrspaces.ll
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-EXT
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_usm_storage_classes %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV,CHECK-SPIRV-WITHOUT
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-EXT: Capability USMStorageClassesINTEL
+; CHECK-SPIRV-WITHOUT-NOT: Capability USMStorageClassesINTEL
+; CHECK-SPIRV-EXT-DAG: %[[DevTy:[0-9]+]] = OpTypePointer DeviceOnlyINTEL %[[#]]
+; CHECK-SPIRV-EXT-DAG: %[[HostTy:[0-9]+]] = OpTypePointer HostOnlyINTEL %[[#]]
+; CHECK-SPIRV-DAG: %[[CrsWrkTy:[0-9]+]] = OpTypePointer CrossWorkgroup %[[#]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @foo_kernel() {
+entry:
+  ret void
+}
+
+; CHECK-SPIRV: %[[Ptr1:[0-9]+]] = OpLoad %[[CrsWrkTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr1:[0-9]+]] = OpCrossWorkgroupCastToPtrINTEL %[[DevTy]] %[[Ptr1]]
+; CHECK-SPIRV-WITHOUT-NOT: OpCrossWorkgroupCastToPtrINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr1]]
+define spir_func void @test1(ptr addrspace(1) %arg_glob, ptr addrspace(5) %arg_dev) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_dev.addr = alloca ptr addrspace(5), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(5) %arg_dev, ptr %arg_dev.addr, align 4
+  %loaded_glob = load ptr addrspace(1), ptr %arg_glob.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(1) %loaded_glob to ptr addrspace(5)
+  store ptr addrspace(5) %casted_ptr, ptr %arg_dev.addr, align 4
+  ret void
+}
+
+; CHECK-SPIRV: %[[Ptr2:[0-9]+]] = OpLoad %[[CrsWrkTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr2:[0-9]+]] = OpCrossWorkgroupCastToPtrINTEL %[[HostTy]] %[[Ptr2]]
+; CHECK-SPIRV-WITHOUT-NOT: OpCrossWorkgroupCastToPtrINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr2]]
+define spir_func void @test2(ptr addrspace(1) %arg_glob, ptr addrspace(6) %arg_host) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_host.addr = alloca ptr addrspace(6), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(6) %arg_host, ptr %arg_host.addr, align 4
+  %loaded_glob = load ptr addrspace(1), ptr %arg_glob.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(1) %loaded_glob to ptr addrspace(6)
+  store ptr addrspace(6) %casted_ptr, ptr %arg_host.addr, align 4
+  ret void
+}
+
+; CHECK-SPIRV-EXT: %[[Ptr3:[0-9]+]] = OpLoad %[[DevTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr3:[0-9]+]] = OpPtrCastToCrossWorkgroupINTEL %[[CrsWrkTy]] %[[Ptr3]]
+; CHECK-SPIRV-WITHOUT-NOT: OpPtrCastToCrossWorkgroupINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr3]]
+define spir_func void @test3(ptr addrspace(1) %arg_glob, ptr addrspace(5) %arg_dev) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_dev.addr = alloca ptr addrspace(5), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(5) %arg_dev, ptr %arg_dev.addr, align 4
+  %loaded_dev = load ptr addrspace(5), ptr %arg_dev.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(5) %loaded_dev to ptr addrspace(1)
+  store ptr addrspace(1) %casted_ptr, ptr %arg_glob.addr, align 4
+  ret void
+}
+
+; CHECK-SPIRV-EXT: %[[Ptr4:[0-9]+]] = OpLoad %[[HostTy]] %[[#]]
+; CHECK-SPIRV-EXT: %[[CastedPtr4:[0-9]+]] = OpPtrCastToCrossWorkgroupINTEL %[[CrsWrkTy]] %[[Ptr4]]
+; CHECK-SPIRV-WITHOUT-NOT: OpPtrCastToCrossWorkgroupINTEL
+; CHECK-SPIRV-EXT: OpStore %[[#]] %[[CastedPtr4]]
+define spir_func void @test4(ptr addrspace(1) %arg_glob, ptr addrspace(6) %arg_host) {
+entry:
+  %arg_glob.addr = alloca ptr addrspace(1), align 4
+  %arg_host.addr = alloca ptr addrspace(6), align 4
+  store ptr addrspace(1) %arg_glob, ptr %arg_glob.addr, align 4
+  store ptr addrspace(6) %arg_host, ptr %arg_host.addr, align 4
+  %loaded_host = load ptr addrspace(6), ptr %arg_host.addr, align 4
+  %casted_ptr = addrspacecast ptr addrspace(6) %loaded_host to ptr addrspace(1)
+  store ptr addrspace(1) %casted_ptr, ptr %arg_glob.addr, align 4
+  ret void
+}

>From b5b3e68035ad4060f2d028479ce0fdecb98d174d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 19 Feb 2024 05:54:51 -0800
Subject: [PATCH 2/2] apply clang-format

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 8 ++++----
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp         | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 2ed538ab93f698..4f2e7a240fc2cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -108,10 +108,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
-      p0,    p1,    p2,    p3,    p4,    p5,    p6,     s1,     s8,     s16,
-      s32,   s64,   v2s1,  v2s8,  v2s16, v2s32, v2s64,  v3s1,   v3s8,
-      v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16, v4s32,  v4s64,  v8s1,
-      v8s8,  v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+      p0,    p1,    p2,    p3,    p4,     p5,     p6,    s1,   s8,   s16,
+      s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64, v3s1, v3s8, v3s16,
+      v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64, v8s1, v8s8, v8s16,
+      v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
 
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 7d6e1497ee57b4..169d7cc93897ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -10,11 +10,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "SPIRV.h"
 #include "SPIRVInstrInfo.h"
+#include "SPIRVSubtarget.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"



More information about the llvm-commits mailing list