[llvm] 864a83d - [SPIR-V] Add support for inline SPIR-V types (#125316)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 20 12:49:48 PDT 2025


Author: Cassandra Beckley
Date: 2025-03-20T15:49:44-04:00
New Revision: 864a83deb0b613a2e957b6c106a4412b54949131

URL: https://github.com/llvm/llvm-project/commit/864a83deb0b613a2e957b6c106a4412b54949131
DIFF: https://github.com/llvm/llvm-project/commit/864a83deb0b613a2e957b6c106a4412b54949131.diff

LOG: [SPIR-V] Add support for inline SPIR-V types (#125316)

Using HLSL's [Inline
SPIR-V](https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html)
features, users have the ability to use
[`SpirvType`](https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html#types)
to have fine-grained control over the SPIR-V representation of a type.
As explained in the spec, this is useful because it enables vendors to
author headers with types for their own extensions.

As discussed in [Target Extension Types for Inline SPIR-V and Decorated
Types](https://github.com/llvm/wg-hlsl/pull/105), we would like to
represent the HLSL SpirvType type using a 'spirv.Type' target extension
type in LLVM IR. This pull request lowers that type to SPIR-V.

Added: 
    llvm/test/CodeGen/SPIRV/inline/type.coop-matrix.ll
    llvm/test/CodeGen/SPIRV/inline/type.ll

Modified: 
    llvm/docs/SPIRVUsage.rst
    llvm/lib/IR/Type.cpp
    llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
    llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
    llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
    llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
    llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
    llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
    llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index ec842db586f77..d3d9fbb3ba294 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -266,6 +266,54 @@ parameters of its underlying image type, so that a sampled image for the
 previous type has the representation
 ``target("spirv.SampledImage, void, 1, 1, 0, 0, 0, 0, 0)``.
 
+.. _inline-spirv-types:
+
+Inline SPIR-V Types
+-------------------
+
+HLSL allows users to create types representing specific SPIR-V types, using ``vk::SpirvType`` and
+``vk::SpirvOpaqueType``. These are specified in the `Inline SPIR-V`_ proposal. They may be
+represented using target extension types:
+
+.. _Inline SPIR-V: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html#types
+
+  .. table:: Inline SPIR-V Types
+
+    ========================== =================== =========================
+    LLVM type name             LLVM type arguments LLVM integer arguments
+    ========================== =================== =========================
+    ``spirv.Type``             SPIR-V operands     opcode, size, alignment
+    ``spirv.IntegralConstant`` integral type       value
+    ``spirv.Literal``          (none)              value
+    ========================== =================== =========================
+
+Each operand argument to ``spirv.Type`` may be a ``spirv.IntegralConstant`` type, representing
+an ``OpConstant`` id operand; a ``spirv.Literal`` type, representing an immediate literal
+operand; or any other type, representing the id of that type as an operand.
+``spirv.IntegralConstant`` and ``spirv.Literal`` may not be used outside of this context.
+
+For example, ``OpTypeArray`` (opcode 28) takes an id for the element type and an id for the array
+length, so an array of 16 integers could be declared as:
+
+``target("spirv.Type", i32, target("spirv.IntegralConstant", i32, 16), 28, 64, 32)``
+
+This will be lowered to:
+
+``OpTypeArray %int %int_16``
+
+``OpTypeVector`` takes an id for the component type and a literal for the component count, so a
+4-integer vector could be declared as:
+
+``target("spirv.Type", i32, target("spirv.Literal", 4), 23, 16, 32)``
+
+This will be lowered to:
+
+``OpTypeVector %int 4``
+
+See `Target Extension Types for Inline SPIR-V and Decorated Types`_ for further details.
+
+.. _Target Extension Types for Inline SPIR-V and Decorated Types: https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md
+
 .. _spirv-intrinsics:
 
 Target Intrinsics

diff  --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 277985b6b00a3..65b56fec78c2e 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -968,6 +968,29 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
   if (Name == "spirv.Image")
     return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::CanBeGlobal,
                           TargetExtType::CanBeLocal);
+  if (Name == "spirv.Type") {
+    assert(Ty->getNumIntParameters() == 3 &&
+           "Wrong number of parameters for spirv.Type");
+
+    auto Size = Ty->getIntParameter(1);
+    auto Alignment = Ty->getIntParameter(2);
+
+    llvm::Type *LayoutType = nullptr;
+    if (Size > 0 && Alignment > 0) {
+      LayoutType =
+          ArrayType::get(Type::getIntNTy(C, Alignment), Size * 8 / Alignment);
+    } else {
+      // LLVM expects variables that can be allocated to have an alignment and
+      // size. Default to using a 32-bit int as the layout type if none are
+      // present.
+      LayoutType = Type::getInt32Ty(C);
+    }
+
+    return TargetTypeInfo(LayoutType, TargetExtType::CanBeGlobal,
+                          TargetExtType::CanBeLocal);
+  }
+  if (Name == "spirv.IntegralConstant" || Name == "spirv.Literal")
+    return TargetTypeInfo(Type::getVoidTy(C));
   if (Name.starts_with("spirv."))
     return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::HasZeroInit,
                           TargetExtType::CanBeGlobal,

diff  --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
index 63dcf0067b515..cd65985a4229c 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
@@ -114,6 +114,8 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
     recordOpExtInstImport(MI);
   } else if (OpCode == SPIRV::OpExtInst) {
     printOpExtInst(MI, OS);
+  } else if (OpCode == SPIRV::UNKNOWN_type) {
+    printUnknownType(MI, OS);
   } else {
     // Print any extra operands for variadic instructions.
     const MCInstrDesc &MCDesc = MII.get(OpCode);
@@ -312,6 +314,31 @@ void SPIRVInstPrinter::printOpDecorate(const MCInst *MI, raw_ostream &O) {
   }
 }
 
+// Print an UNKNOWN_type instruction as
+// "OpUnknown(<opcode>, <operand count>) <result id> <remaining operands>".
+void SPIRVInstPrinter::printUnknownType(const MCInst *MI, raw_ostream &O) {
+  // Operand 1 is the immediate holding the opcode of the emitted instruction.
+  const MCOperand &OpcodeOperand = MI->getOperand(1);
+  assert(OpcodeOperand.isImm() &&
+         "second operand of UNKNOWN_type must be opcode!");
+
+  const unsigned TotalOps = MI->getNumOperands();
+
+  // Use the unknown-opcode syntax of spirv-as for the opcode itself.
+  O << "OpUnknown(" << OpcodeOperand.getImm() << ", " << TotalOps << ") ";
+
+  // With this syntax the result ID comes after the opcode, not before it.
+  printOperand(MI, 0, O);
+
+  O << " ";
+
+  // Operands beyond those declared in the instruction description are the
+  // variadic ones; print them, if any, after the result ID.
+  const unsigned DeclaredOps = MII.get(MI->getOpcode()).getNumOperands();
+  if (TotalOps != DeclaredOps)
+    printRemainingVariableOps(MI, DeclaredOps, O, true);
+}
+
 static void printExpr(const MCExpr *Expr, raw_ostream &O) {
 #ifndef NDEBUG
   const MCSymbolRefExpr *SRE;

diff  --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
index 9b02524f50b81..a7b38a6951c51 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.h
@@ -35,6 +35,7 @@ class SPIRVInstPrinter : public MCInstPrinter {
 
   void printOpDecorate(const MCInst *MI, raw_ostream &O);
   void printOpExtInst(const MCInst *MI, raw_ostream &O);
+  void printUnknownType(const MCInst *MI, raw_ostream &O);
   void printRemainingVariableOps(const MCInst *MI, unsigned StartIndex,
                                  raw_ostream &O, bool SkipFirstSpace = false,
                                  bool SkipImmediates = false);

diff  --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
index c1810cdf20537..bd104858b90c6 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
@@ -45,6 +45,9 @@ class SPIRVMCCodeEmitter : public MCCodeEmitter {
   void encodeInstruction(const MCInst &MI, SmallVectorImpl<char> &CB,
                          SmallVectorImpl<MCFixup> &Fixups,
                          const MCSubtargetInfo &STI) const override;
+  void encodeUnknownType(const MCInst &MI, SmallVectorImpl<char> &CB,
+                         SmallVectorImpl<MCFixup> &Fixups,
+                         const MCSubtargetInfo &STI) const;
 };
 
 } // end anonymous namespace
@@ -104,10 +107,32 @@ static void emitUntypedInstrOperands(const MCInst &MI,
     emitOperand(Op, CB);
 }
 
+void SPIRVMCCodeEmitter::encodeUnknownType(const MCInst &MI,
+                                           SmallVectorImpl<char> &CB,
+                                           SmallVectorImpl<MCFixup> &Fixups,
+                                           const MCSubtargetInfo &STI) const {
+  // First word: operand count (as word count) in the high half, opcode low.
+  const uint32_t WordCount = MI.getNumOperands();
+  const uint64_t RawOpcode = MI.getOperand(1).getImm();
+  const uint32_t Header = (0xFFFF & WordCount) << 16 | (0xFFFF & RawOpcode);
+  support::endian::write(CB, Header, llvm::endianness::little);
+
+  // Layout: <header> <result id> [<operand0> <operand1> ...]; operand 1 is
+  // skipped below because it already went into the header.
+  emitOperand(MI.getOperand(0), CB);
+  for (unsigned Idx = 2; Idx < WordCount; ++Idx)
+    emitOperand(MI.getOperand(Idx), CB);
+}
+
 void SPIRVMCCodeEmitter::encodeInstruction(const MCInst &MI,
                                            SmallVectorImpl<char> &CB,
                                            SmallVectorImpl<MCFixup> &Fixups,
                                            const MCSubtargetInfo &STI) const {
+  if (MI.getOpcode() == SPIRV::UNKNOWN_type) {
+    encodeUnknownType(MI, CB, Fixups, STI);
+    return;
+  }
+
   // Encode the first 32 SPIR-V bytes with the number of args and the opcode.
   const uint64_t OpCode = getBinaryCodeForInstr(MI, Fixups, STI);
   const uint32_t NumWords = MI.getNumOperands() + 1;

diff  --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 579e37f68d5d8..058e9166d90c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -3041,6 +3041,57 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
   return GR->getOrCreateOpTypeSampledImage(OpaqueImageType, MIRBuilder);
 }
 
+static SPIRVType *getInlineSpirvType(const TargetExtType *ExtensionType,
+                                     MachineIRBuilder &MIRBuilder,
+                                     SPIRVGlobalRegistry *GR) {
+  assert(ExtensionType->getNumIntParameters() == 3 &&
+         "Inline SPIR-V type builtin takes an opcode, size, and alignment "
+         "parameter");
+  const unsigned RawOpcode = ExtensionType->getIntParameter(0);
+
+  // Every type parameter becomes one operand: the id of an integer constant,
+  // an immediate literal, or the id of a lowered type.
+  SmallVector<MCOperand> Operands;
+  for (llvm::Type *Param : ExtensionType->type_params()) {
+    const auto *Wrapper = dyn_cast<TargetExtType>(Param);
+    const StringRef WrapperName = Wrapper ? Wrapper->getName() : StringRef();
+    if (WrapperName == "spirv.IntegralConstant") {
+      assert(Wrapper->getNumTypeParameters() == 1 &&
+             "Inline SPIR-V integral constant builtin must have a type "
+             "parameter");
+      assert(Wrapper->getNumIntParameters() == 1 &&
+             "Inline SPIR-V integral constant builtin must have a "
+             "value parameter");
+      const unsigned ConstValue = Wrapper->getIntParameter(0);
+      const SPIRVType *ConstType = GR->getOrCreateSPIRVType(
+          Wrapper->getTypeParameter(0), MIRBuilder,
+          SPIRV::AccessQualifier::ReadWrite, true);
+      Operands.push_back(MCOperand::createReg(
+          GR->buildConstantInt(ConstValue, MIRBuilder, ConstType, true)));
+      continue;
+    }
+
+    if (WrapperName == "spirv.Literal") {
+      assert(Wrapper->getNumTypeParameters() == 0 &&
+             "Inline SPIR-V literal builtin does not take type "
+             "parameters");
+      assert(Wrapper->getNumIntParameters() == 1 &&
+             "Inline SPIR-V literal builtin must have an integer "
+             "parameter");
+      Operands.push_back(MCOperand::createImm(Wrapper->getIntParameter(0)));
+      continue;
+    }
+
+    // Any other parameter is an ordinary type, referenced by its type id.
+    const SPIRVType *TypeOperand = GR->getOrCreateSPIRVType(
+        Param, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
+    Operands.push_back(MCOperand::createReg(GR->getSPIRVTypeID(TypeOperand)));
+  }
+
+  return GR->getOrCreateUnknownType(ExtensionType, MIRBuilder, RawOpcode,
+                                    Operands);
+}
+
 namespace SPIRV {
 TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
                                                    LLVMContext &Context) {
@@ -3113,39 +3164,45 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
   const StringRef Name = BuiltinType->getName();
   LLVM_DEBUG(dbgs() << "Lowering builtin type: " << Name << "\n");
 
-  // Lookup the demangled builtin type in the TableGen records.
-  const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
-  if (!TypeRecord)
-    report_fatal_error("Missing TableGen record for builtin type: " + Name);
-
-  // "Lower" the BuiltinType into TargetType. The following get<...>Type methods
-  // use the implementation details from TableGen records or TargetExtType
-  // parameters to either create a new OpType<...> machine instruction or get an
-  // existing equivalent SPIRVType from GlobalRegistry.
   SPIRVType *TargetType;
-  switch (TypeRecord->Opcode) {
-  case SPIRV::OpTypeImage:
-    TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
-    break;
-  case SPIRV::OpTypePipe:
-    TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
-    break;
-  case SPIRV::OpTypeDeviceEvent:
-    TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
-    break;
-  case SPIRV::OpTypeSampler:
-    TargetType = getSamplerType(MIRBuilder, GR);
-    break;
-  case SPIRV::OpTypeSampledImage:
-    TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
-    break;
-  case SPIRV::OpTypeCooperativeMatrixKHR:
-    TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
-    break;
-  default:
-    TargetType =
-        getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
-    break;
+  if (Name == "spirv.Type") {
+    TargetType = getInlineSpirvType(BuiltinType, MIRBuilder, GR);
+  } else {
+    // Lookup the demangled builtin type in the TableGen records.
+    const SPIRV::BuiltinType *TypeRecord = SPIRV::lookupBuiltinType(Name);
+    if (!TypeRecord)
+      report_fatal_error("Missing TableGen record for builtin type: " + Name);
+
+    // "Lower" the BuiltinType into TargetType. The following get<...>Type
+    // methods use the implementation details from TableGen records or
+    // TargetExtType parameters to either create a new OpType<...> machine
+    // instruction or get an existing equivalent SPIRVType from
+    // GlobalRegistry.
+
+    switch (TypeRecord->Opcode) {
+    case SPIRV::OpTypeImage:
+      TargetType = getImageType(BuiltinType, AccessQual, MIRBuilder, GR);
+      break;
+    case SPIRV::OpTypePipe:
+      TargetType = getPipeType(BuiltinType, MIRBuilder, GR);
+      break;
+    case SPIRV::OpTypeDeviceEvent:
+      TargetType = GR->getOrCreateOpTypeDeviceEvent(MIRBuilder);
+      break;
+    case SPIRV::OpTypeSampler:
+      TargetType = getSamplerType(MIRBuilder, GR);
+      break;
+    case SPIRV::OpTypeSampledImage:
+      TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
+      break;
+    case SPIRV::OpTypeCooperativeMatrixKHR:
+      TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
+      break;
+    default:
+      TargetType =
+          getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
+      break;
+    }
   }
 
   // Emit OpName instruction if a new OpType<...> instruction was added

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cbec1c95eadc3..f7482b4686848 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1456,6 +1456,32 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
   return SpirvTy;
 }
 
+SPIRVType *SPIRVGlobalRegistry::getOrCreateUnknownType(
+    const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode,
+    const ArrayRef<MCOperand> Operands) {
+  MachineFunction &MF = MIRBuilder.getMF();
+  // If Ty is already registered for this function, return its definition.
+  Register TypeReg = DT.find(Ty, &MF);
+  if (TypeReg.isValid())
+    return MF.getRegInfo().getUniqueVRegDef(TypeReg);
+  TypeReg = createTypeVReg(MIRBuilder);
+  DT.add(Ty, &MF, TypeReg);
+
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &Builder) {
+    // Operand order: result id, opcode, then the caller-provided operands.
+    MachineInstrBuilder MIB =
+        Builder.buildInstr(SPIRV::UNKNOWN_type).addDef(TypeReg).addImm(Opcode);
+    for (const MCOperand &Operand : Operands) {
+      // Operands that are neither registers nor immediates are skipped.
+      if (Operand.isReg())
+        MIB.addUse(Operand.getReg());
+      else if (Operand.isImm())
+        MIB.addImm(Operand.getImm());
+    }
+    return MIB;
+  });
+}
+
 const MachineInstr *
 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
                                        MachineIRBuilder &MIRBuilder) {

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 89599f17ef737..459a7b3eb1400 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -621,6 +621,11 @@ class SPIRVGlobalRegistry {
                                        MachineIRBuilder &MIRBuilder,
                                        unsigned Opcode);
 
+  SPIRVType *getOrCreateUnknownType(const Type *Ty,
+                                    MachineIRBuilder &MIRBuilder,
+                                    unsigned Opcode,
+                                    const ArrayRef<MCOperand> Operands);
+
   const TargetRegisterClass *getRegClass(SPIRVType *SpvType) const;
   LLT getRegType(SPIRVType *SpvType) const;
 

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
index 9451583a5fa85..2fde2b0bc0b1f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
@@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
   let Pattern = pattern;
 }
 
+class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
+    : Op<0, outs, ins, asmstr, pattern> { // Passes 0 as the opcode of Op.
+  let isPseudo = 1;
+}
+
 // Pseudo instructions
 class Pseudo<dag outs, dag ins> : Op<0, outs, ins, ""> {
   let isPseudo = 1;

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a8f862271dbab..7f369ea303a12 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -25,6 +25,9 @@ let isCodeGenOnly=1 in {
   def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins vpID:$src)>;
 }
 
+def UNKNOWN_type // Defines a TYPE; takes an immediate $opcode plus variadic operands.
+    : UnknownOp<(outs TYPE:$type), (ins i32imm:$opcode, variable_ops), " ">;
+
 def SPVTypeBin : SDTypeProfile<1, 2, []>;
 
 def assigntype : SDNode<"SPIRVISD::AssignType", SPVTypeBin>;

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3445e8fc3990c..2aba950037ec3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -496,7 +496,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
   bool HasDefs = I.getNumDefs() > 0;
   Register ResVReg = HasDefs ? I.getOperand(0).getReg() : Register(0);
   SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
-  assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
+  assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
+         I.getOpcode() == TargetOpcode::G_IMPLICIT_DEF);
   if (spvSelect(ResVReg, ResType, I)) {
     if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs).
       for (unsigned i = 0; i < I.getNumDefs(); ++i)

diff  --git a/llvm/test/CodeGen/SPIRV/inline/type.coop-matrix.ll b/llvm/test/CodeGen/SPIRV/inline/type.coop-matrix.ll
new file mode 100644
index 0000000000000..12effdfd464f1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/inline/type.coop-matrix.ll
@@ -0,0 +1,30 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+
+; TODO: enable spirv-val once we can add cooperative matrix capability and extension
+
+; CHECK: [[float_t:%[0-9]+]] = OpTypeFloat 32
+; CHECK: [[uint32_t:%[0-9]+]] = OpTypeInt 32 0
+
+; CHECK: [[uint32_2:%[0-9]+]] = OpConstant [[uint32_t]] 2
+%scope = type target("spirv.IntegralConstant", i32, 2) ; Workgroup
+; CHECK: [[uint32_4:%[0-9]+]] = OpConstant [[uint32_t]] 4
+%cols = type target("spirv.IntegralConstant", i32, 4)
+%rows = type target("spirv.IntegralConstant", i32, 4)
+; CHECK: [[uint32_0:%[0-9]+]] = OpConstant [[uint32_t]] 0
+%use = type target("spirv.IntegralConstant", i32, 0) ; MatrixAKHR
+
+; CHECK: OpUnknown(4456, 7) [[coop_t:%[0-9]+]] [[float_t]] [[uint32_2]] [[uint32_4]] [[uint32_4]] [[uint32_0]]
+%coop_t = type target("spirv.Type", float, %scope, %rows, %cols, %use, 4456, 0, 0)
+
+; CHECK: [[getCooperativeMatrix_t:%[0-9]+]] = OpTypeFunction [[coop_t]]
+
+; CHECK: [[getCooperativeMatrix:%[0-9]+]] = OpFunction [[coop_t]] None [[getCooperativeMatrix_t]]
+declare %coop_t @getCooperativeMatrix()
+
+define void @main() #1 {
+entry:
+; CHECK: {{%[0-9]+}} = OpFunctionCall [[coop_t]] [[getCooperativeMatrix]]
+  %coop = call %coop_t @getCooperativeMatrix()
+
+  ret void
+}

diff  --git a/llvm/test/CodeGen/SPIRV/inline/type.ll b/llvm/test/CodeGen/SPIRV/inline/type.ll
new file mode 100644
index 0000000000000..c9f1946d97048
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/inline/type.ll
@@ -0,0 +1,40 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - | spirv-as - -o - | spirv-val %}
+
+; CHECK: [[uint32_t:%[0-9]+]] = OpTypeInt 32 0
+
+; CHECK: [[image_t:%[0-9]+]] = OpTypeImage %3 2D 2 0 0 1 Unknown
+%type_2d_image = type target("spirv.Image", float, 1, 2, 0, 0, 1, 0)
+
+%literal_true = type target("spirv.Literal", 1)
+%literal_32 = type target("spirv.Literal", 32)
+
+; CHECK: [[uint32_4:%[0-9]+]] = OpConstant [[uint32_t]] 4
+%integral_constant_4 = type target("spirv.IntegralConstant", i32, 4)
+
+; CHECK: OpUnknown(28, 4) [[array_t:%[0-9]+]] [[image_t]] [[uint32_4]]
+%ArrayTex2D = type target("spirv.Type", %type_2d_image, %integral_constant_4, 28, 0, 0)
+
+; CHECK: OpUnknown(21, 4) [[int_t:%[0-9]+]] 32 1
+%int_t = type target("spirv.Type", %literal_32, %literal_true, 21, 0, 0)
+
+; CHECK: [[getTexArray_t:%[0-9]+]] = OpTypeFunction [[array_t]]
+; CHECK: [[getInt_t:%[0-9]+]] = OpTypeFunction [[int_t]]
+
+; CHECK: [[getTexArray:%[0-9]+]] = OpFunction [[array_t]] None [[getTexArray_t]]
+declare %ArrayTex2D @getTexArray()
+
+; CHECK: [[getInt:%[0-9]+]] = OpFunction [[int_t]] None [[getInt_t]]
+declare %int_t @getInt()
+
+define void @main() #1 {
+entry:
+; CHECK: {{%[0-9]+}} = OpFunctionCall [[array_t]] [[getTexArray]]
+  %retTex = call %ArrayTex2D @getTexArray()
+
+; CHECK: {{%[0-9]+}} = OpFunctionCall [[int_t]] [[getInt]]
+  %i = call %int_t @getInt()
+
+  ret void
+}


        


More information about the llvm-commits mailing list