[llvm] [SPIR-V]: Add SPIR-V extension: SPV_KHR_cooperative_matrix (PR #96091)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 20 05:04:06 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/96091

>From d3140520ec58798eb21f72676235f1f3ffcbdc30 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 19 Jun 2024 06:50:02 -0700
Subject: [PATCH 1/6] initial support for SPV_KHR_cooperative_matrix

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  40 ++++++
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        |  16 ++-
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |   4 +
 .../lib/Target/SPIRV/SPIRVDuplicatesTracker.h |  28 ++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  19 +++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   6 +-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  16 +++
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |   9 ++
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |   3 +
 .../cooperative_matrix.ll                     | 124 ++++++++++++++++++
 10 files changed, 263 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c14e5098be711..750d371a48d64 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1824,6 +1824,23 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
+static bool generateConstructInst(const SPIRV::IncomingCall *Call,
+                                  MachineIRBuilder &MIRBuilder,
+                                  SPIRVGlobalRegistry *GR) {
+  return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call,
+                            GR->getSPIRVTypeID(Call->ReturnType));
+}
+
+static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
+                                 MachineIRBuilder &MIRBuilder,
+                                 SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+  return buildOpFromWrapper(MIRBuilder, Opcode, Call,
+                            GR->getSPIRVTypeID(Call->ReturnType));
+}
+
 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
                                      MachineIRBuilder &MIRBuilder,
                                      SPIRVGlobalRegistry *GR) {
@@ -2382,6 +2399,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
   case SPIRV::Select:
     return generateSelectInst(Call.get(), MIRBuilder);
+  case SPIRV::Construct:
+    return generateConstructInst(Call.get(), MIRBuilder, GR);
   case SPIRV::SpecConstant:
     return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Enqueue:
@@ -2400,6 +2419,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
   case SPIRV::KernelClock:
     return generateKernelClockInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::CoopMatr:
+    return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
   }
   return false;
 }
@@ -2524,6 +2545,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
                                        ExtensionType->getIntParameter(0)));
 }
 
+static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
+                                  MachineIRBuilder &MIRBuilder,
+                                  SPIRVGlobalRegistry *GR) {
+  assert(ExtensionType->getNumIntParameters() == 4 &&
+         "Invalid number of parameters for SPIR-V coop matrices builtin!");
+  assert(ExtensionType->getNumTypeParameters() == 1 &&
+         "SPIR-V coop matrices builtin type must have a type parameter!");
+  const SPIRVType *ElemType =
+      GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
+  // Create or get an existing type from GlobalRegistry.
+  return GR->getOrCreateOpTypeCoopMatr(
+      MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(1),
+      ExtensionType->getIntParameter(2), ExtensionType->getIntParameter(3),
+      ExtensionType->getIntParameter(4));
+}
+
 static SPIRVType *
 getImageType(const TargetExtType *ExtensionType,
              const SPIRV::AccessQualifier::AccessQualifier Qualifier,
@@ -2654,6 +2691,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
   case SPIRV::OpTypeSampledImage:
     TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
     break;
+  case SPIRV::OpTypeCooperativeMatrixKHR:
+    TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
+    break;
   default:
     TargetType =
         getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 2b8e6d856686a..5595d4cde120c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -60,6 +60,8 @@ def AtomicFloating : BuiltinGroup;
 def GroupUniform : BuiltinGroup;
 def KernelClock : BuiltinGroup;
 def CastToPtr : BuiltinGroup;
+def Construct : BuiltinGroup;
+def CoopMatr : BuiltinGroup;
 
 //===----------------------------------------------------------------------===//
 // Class defining a demangled builtin record. The information in the record
@@ -114,6 +116,12 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage
 // Select builtin record:
 def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>;
 
+// Composite Construct builtin record:
+def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>;
+
+// Dot builtin record:
+def : DemangledBuiltin<"dot", OpenCL_std, Dot, 2, 2>;
+
 //===----------------------------------------------------------------------===//
 // Class defining an extended builtin record used for lowering into an
 // OpExtInst instruction.
@@ -608,6 +616,12 @@ defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToGlobal", Ope
 defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
 defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
 
+// Cooperative Matrix builtin records:
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;
+
 //===----------------------------------------------------------------------===//
 // Class defining a work/sub group builtin that should be translated into a
 // SPIR-V instruction using the defined properties.
@@ -1436,7 +1450,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>;
 def : BuiltinType<"spirv.Image", OpTypeImage>;
 def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>;
 def : BuiltinType<"spirv.Pipe", OpTypePipe>;
-
+def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>;
 
 //===----------------------------------------------------------------------===//
 // Class matching an OpenCL builtin type name to an equivalent SPIR-V
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 75aa1823b11f2..c2a5b234d6a2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -66,6 +66,10 @@ static const std::map<std::string, SPIRV::Extension::Extension>
          SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
         {"SPV_KHR_shader_clock",
          SPIRV::Extension::Extension::SPV_KHR_shader_clock},
+        {"SPV_KHR_cooperative_matrix",
+         SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix},
+        {"SPV_INTEL_joint_matrix",
+         SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
 };
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index a37e65a47eda0..cb8576ddee719 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -60,6 +60,8 @@ enum SpecialTypeKind {
   STK_Pipe,
   STK_DeviceEvent,
   STK_Pointer,
+  STK_CoopMatr,
+  STK_JointMatr,
   STK_Last = -1
 };
 
@@ -113,7 +115,33 @@ make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
           .Val,
       SpecialTypeKind::STK_SampledImage);
 }
+/*
+union MatrAttrs {
+  struct BitFlags {
+    unsigned Layout : 2;
 
+    unsigned Depth : 2;
+    unsigned Arrayed : 1;
+    unsigned MS : 1;
+    unsigned Sampled : 2;
+    unsigned ImageFormat : 6;
+    unsigned AQ : 2;
+  } Flags;
+  unsigned Val;
+
+  MatrAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
+             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
+    Val = 0;
+    Flags.Dim = Dim;
+    Flags.Depth = Depth;
+    Flags.Arrayed = Arrayed;
+    Flags.MS = MS;
+    Flags.Sampled = Sampled;
+    Flags.ImageFormat = ImageFormat;
+    Flags.AQ = AQ;
+  }
+};
+*/
 inline SpecialTypeDescriptor make_descr_sampler() {
   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index b22d2a04f75b1..18ae2679b216a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1189,6 +1189,25 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
       .addUse(getSPIRVTypeID(ImageType));
 }
 
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
+    MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
+    const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
+    uint32_t Use) {
+  Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
+  if (ResVReg.isValid())
+    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
+  ResVReg = createTypeVReg(MIRBuilder);
+  SPIRVType *SpirvTy = MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
+                           .addDef(ResVReg)
+                           .addUse(getSPIRVTypeID(ElemType))
+                           .addImm(Scope)
+                           .addImm(Rows)
+                           .addImm(Columns)
+                           .addImm(Use);
+  DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
+  return SpirvTy;
+}
+
 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index db01f68f48de9..0ca9f58695dff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -514,7 +514,11 @@ class SPIRVGlobalRegistry {
 
   SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
                                            MachineIRBuilder &MIRBuilder);
-
+  SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder,
+                                       const TargetExtType *ExtensionType,
+                                       const SPIRVType *ElemType,
+                                       uint32_t Scope, uint32_t Rows,
+                                       uint32_t Columns, uint32_t Use);
   SPIRVType *
   getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
                         SPIRV::AccessQualifier::AccessQualifier AccQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index dedfd5e6e32db..bbc24dfbc9b10 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
 def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
                   (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
                   "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
+                  (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+                  "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
 
 // 3.42.7 Constant-Creation Instructions
 
@@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta
                   "$res = OpAsmINTEL $type $asm_type $target $asm">;
 def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops),
                   "$res = OpAsmCallINTEL $type $asm">;
+
+// SPV_KHR_cooperative_matrix
+def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res),
+                  (ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops),
+                  "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">;
+def OpCooperativeMatrixStoreKHR: Op<4458, (outs),
+                  (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops),
+                  "OpCooperativeMatrixStoreKHR $pointer $objectToStore  $memory_layout">;
+def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
+                  (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops),
+                  "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
+def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
+                  "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 30a6c474f467a..ac0aa682ea4be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
     }
     break;
+  case SPIRV::OpTypeCooperativeMatrixKHR:
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+      report_fatal_error(
+          "OpTypeCooperativeMatrixKHR type requires the "
+          "following SPIR-V extension: SPV_KHR_cooperative_matrix",
+          false);
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+    break;
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 318c5cebb7a43..f7e482449e0ca 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -302,6 +302,8 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>;
 defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
 defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
 defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
+defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
+defm SPV_INTEL_joint_matrix : ExtensionOperand<112>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -478,6 +480,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
 defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
+defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
new file mode 100644
index 0000000000000..b87358bdc88ee
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -0,0 +1,124 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: OpTypeCooperativeMatrixKHR type requires the following SPIR-V extension: SPV_KHR_cooperative_matrix
+
+; CHECK: OpCapability CooperativeMatrixKHR
+; CHECK: OpExtension "SPV_KHR_cooperative_matrix"
+
+; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
+; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
+; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const12:]] 12
+; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const48:]] 48
+; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0
+; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3
+; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2
+; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1
+; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy1:]] [[#Int32Ty]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const2]]
+; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int8Ty]] [[#Const0]] [[#Const12]] [[#Const48]] [[#Const0]]
+; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
+; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
+; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]]
+; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
+; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
+; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
+; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
+; CHECK-SPIRV: CooperativeMatrixStoreKHR
+
+%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
+%"class.sycl::_V1::detail::array" = type { [2 x i64] }
+%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" }
+
+ at __spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
+ at __spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
+
+define weak_odr dso_local spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+entry:
+  %sub_c.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
+  %ref.tmp29.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
+  %agg.tmp15.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::range", ptr %_arg_accB5, i64 0, i32 0, i32 0, i64 1
+  %agg.tmp15.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp15.sroa.0.sroa.2.0..sroa_idx, align 8
+  %agg.tmp16.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accB6, align 8
+  %agg.tmp16.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::id", ptr %_arg_accB6, i64 0, i32 0, i32 0, i64 1
+  %agg.tmp16.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp16.sroa.0.sroa.2.0..sroa_idx, align 8
+  %mul.i4.i.i.i.i45 = mul i64 %agg.tmp16.sroa.0.sroa.0.0.copyload, %agg.tmp15.sroa.0.sroa.2.0.copyload
+  %add.i6.i.i.i.i46 = add i64 %mul.i4.i.i.i.i45, %agg.tmp16.sroa.0.sroa.2.0.copyload
+  %add.ptr.i47 = getelementptr inbounds i8, ptr addrspace(1) %_arg_accB, i64 %add.i6.i.i.i.i46
+  %0 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
+  %1 = extractelement <3 x i64> %0, i64 1
+  %2 = extractelement <3 x i64> %0, i64 0
+  %3 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32
+  %4 = extractelement <3 x i64> %3, i64 1
+  %5 = extractelement <3 x i64> %3, i64 0
+  %cmp.i.i = icmp ult i64 %1, 2147483648
+  %cmp.i54.i = icmp ult i64 %2, 2147483648
+  %cmp.i56.i = icmp ult i64 %4, 2147483648
+  %sub.i = sub nsw i64 %1, %4
+  %cmp.i58.i = icmp ult i64 %5, 2147483648
+  %sub5.i = sub nsw i64 %2, %5
+  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
+  %call.i.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 0)
+  store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8
+  %mul.i = mul nsw i64 %sub.i, 12
+  %div2452.i = lshr i64 %sub5.i, 4
+  %mul26.i = mul i64 %div2452.i, 48
+  %div.i = udiv i64 %_arg_K, 48
+  %mul11.i = mul i64 %mul.i, %_arg_K
+  %add.ptr.i93.i = getelementptr inbounds i8, ptr addrspace(1) %_arg_accA, i64 %mul11.i
+  %idx.neg.i.i104.i = sub i64 0, %add.i6.i.i.i.i46
+  %add.ptr.i.i105141.i = getelementptr i8, ptr addrspace(1) %add.ptr.i47, i64 %mul26.i
+  %mul22.i = shl i64 %_arg_N, 2
+  %add.ptr.i108140.i = getelementptr i8, ptr addrspace(1) %add.ptr.i.i105141.i, i64 %idx.neg.i.i104.i
+  br label %for.cond.i
+
+for.cond.i:
+  %k.0.i = phi i32 [ 0, %entry ], [ %add.i, %for.body.i ]
+  %conv.i = zext i32 %k.0.i to i64
+  %cmp.i = icmp ugt i64 %div.i, %conv.i
+  br i1 %cmp.i, label %for.body.i, label %exit
+
+for.body.i:
+  %mul12.i = mul nsw i32 %k.0.i, 48
+  %conv13.i = zext i32 %mul12.i to i64
+  %add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
+  %call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4)
+  %call1.i.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) %call.ascast.i66.i, i32 0, i64 %_arg_K, i32 1)
+  %len = tail call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) %call1.i.i)
+  %div20.i = mul nsw i32 %k.0.i, 12
+  %conv21.i = zext i32 %div20.i to i64
+  %mul23.i = mul i64 %mul22.i, %conv21.i
+  %add.ptr.i111.i = getelementptr i8, ptr addrspace(1) %add.ptr.i108140.i, i64 %mul23.i
+  %call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(4)
+  %call1.i73.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(4) %call.ascast.i72.i, i32 0, i64 %mul22.i)
+  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i)
+  %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
+  %call.i77.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) %call1.i.i, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) %call1.i73.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i, i32 12)
+  store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i77.i, ptr %ref.tmp29.sroa.0.i, align 8
+  %ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i = load i64, ptr %ref.tmp29.sroa.0.i, align 8
+  store i64 %ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i, ptr %sub_c.sroa.0.i, align 8
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i)
+  %add.i = add nuw nsw i32 %k.0.i, 1
+  br label %for.cond.i
+
+exit:
+  %mul37.i = mul i64 %mul.i, %_arg_N
+  %add.ptr.i.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_accC, i64 %mul37.i
+  %mul39.i = mul nuw i64 %div2452.i, 12
+  %add.ptr.i81.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i, i64 %mul39.i
+  %call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(4)
+  %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
+  tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(4) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i32 0, i64 %_arg_N, i32 1)
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
+  ret void
+}
+
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32)
+declare dso_local spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0))
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4), i32, i64, i32)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(4), i32, i64)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0), target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32)
+declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(4), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32)
+
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)

>From 83cc404d6c1a8f012ffff82213367cf5ed08eb56 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 19 Jun 2024 06:55:28 -0700
Subject: [PATCH 2/6] update the test case

---
 .../cooperative_matrix.ll                     | 35 +++++++++----------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
index b87358bdc88ee..7b964b54fd73e 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -7,24 +7,23 @@
 ; CHECK: OpCapability CooperativeMatrixKHR
 ; CHECK: OpExtension "SPV_KHR_cooperative_matrix"
 
-; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
-; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
-; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const12:]] 12
-; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const48:]] 48
-; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0
-; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3
-; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2
-; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1
-; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy1:]] [[#Int32Ty]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const2]]
-; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int8Ty]] [[#Const0]] [[#Const12]] [[#Const48]] [[#Const0]]
-; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
-; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
-; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]]
-; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
-; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
-; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
-; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
-; CHECK-SPIRV: CooperativeMatrixStoreKHR
+; CHECK-DAG: %[[#Int8Ty:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#Const12:]] = OpConstant %[[#Int32Ty]] 12
+; CHECK-DAG: %[[#Const48:]] = OpConstant %[[#Int32Ty]] 48
+; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int32Ty]] 0
+; CHECK-DAG: %[[#Const3:]] = OpConstant %[[#Int32Ty]] 3
+; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int32Ty]] 2
+; CHECK-DAG: %[[#Const1:]] = OpConstant %[[#Int32Ty]] 1
+; CHECK-DAG: %[[#MatTy1:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const12]] %[[#Const2]]
+; CHECK-DAG: %[[#MatTy2:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const0]] %[[#Const12]] %[[#Const48]] %[[#Const0]]
+; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const2]] %[[#Const48]] %[[#Const12]] %[[#Const1]]
+; CHECK: OpCompositeConstruct %[[#MatTy1]]
+; CHECK: %[[#Load1:]] = OpCooperativeMatrixLoadKHR %[[#MatTy2]]
+; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#]] %[[#Load1]]
+; CHECK: OpCooperativeMatrixLoadKHR %[[#MatTy3]]
+; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]]
+; CHECK: OpCooperativeMatrixStoreKHR
 
 %"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
 %"class.sycl::_V1::detail::array" = type { [2 x i64] }

>From 83256855ad8f2314185fc9167003a10339fb85ec Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 19 Jun 2024 09:08:18 -0700
Subject: [PATCH 3/6] fixes in emission and test case

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp           |  6 +++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp     | 15 ++++++++-------
 .../cooperative_matrix.ll                         |  2 +-
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 750d371a48d64..092ebc1c22541 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2556,9 +2556,9 @@ static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
       GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
   // Create or get an existing type from GlobalRegistry.
   return GR->getOrCreateOpTypeCoopMatr(
-      MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(1),
-      ExtensionType->getIntParameter(2), ExtensionType->getIntParameter(3),
-      ExtensionType->getIntParameter(4));
+      MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
+      ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
+      ExtensionType->getIntParameter(3));
 }
 
 static SPIRVType *
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 18ae2679b216a..6c6a4b3a14ddb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1197,13 +1197,14 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
   if (ResVReg.isValid())
     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
   ResVReg = createTypeVReg(MIRBuilder);
-  SPIRVType *SpirvTy = MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
-                           .addDef(ResVReg)
-                           .addUse(getSPIRVTypeID(ElemType))
-                           .addImm(Scope)
-                           .addImm(Rows)
-                           .addImm(Columns)
-                           .addImm(Use);
+  SPIRVType *SpirvTy =
+      MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
+          .addDef(ResVReg)
+          .addUse(getSPIRVTypeID(ElemType))
+          .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
+          .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
+          .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
+          .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
   DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
   return SpirvTy;
 }
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
index 7b964b54fd73e..3daf070b356de 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -20,7 +20,7 @@
 ; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const2]] %[[#Const48]] %[[#Const12]] %[[#Const1]]
 ; CHECK: OpCompositeConstruct %[[#MatTy1]]
 ; CHECK: %[[#Load1:]] = OpCooperativeMatrixLoadKHR %[[#MatTy2]]
-; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#]] %[[#Load1]]
+; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#Load1]]
 ; CHECK: OpCooperativeMatrixLoadKHR %[[#MatTy3]]
 ; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]]
 ; CHECK: OpCooperativeMatrixStoreKHR

>From 7a6a1475fe32a4aece79fcd4db2a0656058f015d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 19 Jun 2024 09:55:20 -0700
Subject: [PATCH 4/6] fix and update test

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 29 +++++++++++++++----
 .../cooperative_matrix.ll                     |  2 ++
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 092ebc1c22541..35b97cf949b75 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -558,16 +558,21 @@ static Register buildMemSemanticsReg(Register SemanticsRegister,
 
 static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
                                const SPIRV::IncomingCall *Call,
-                               Register TypeReg = Register(0)) {
+                               Register TypeReg,
+                               ArrayRef<uint32_t> ImmArgs = {}) {
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   auto MIB = MIRBuilder.buildInstr(Opcode);
   if (TypeReg.isValid())
     MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
-  for (Register ArgReg : Call->Arguments) {
+  unsigned Sz = Call->Arguments.size() - ImmArgs.size();
+  for (unsigned i = 0; i < Sz; ++i) {
+    Register ArgReg = Call->Arguments[i];
     if (!MRI->getRegClassOrNull(ArgReg))
       MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
     MIB.addUse(ArgReg);
   }
+  for (uint32_t ImmArg : ImmArgs)
+    MIB.addImm(ImmArg);
   return true;
 }
 
@@ -575,7 +580,7 @@ static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
 static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call);
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
 
   assert(Call->Arguments.size() == 2 &&
          "Need 2 arguments for atomic init translation");
@@ -633,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call);
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
 
   Register ScopeRegister =
       buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -870,7 +875,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, Opcode, Call);
+    return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
 
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
@@ -1837,8 +1842,20 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
   const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
   unsigned Opcode =
       SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+  bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
+  unsigned ArgSz = Call->Arguments.size();
+  unsigned LiteralIdx = 0;
+  if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
+    LiteralIdx = 3;
+  else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
+    LiteralIdx = 4;
+  SmallVector<uint32_t, 1> ImmArgs;
+  if (LiteralIdx > 0)
+    ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx],
+                                            MIRBuilder.getMRI()));
+  Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
   return buildOpFromWrapper(MIRBuilder, Opcode, Call,
-                            GR->getSPIRVTypeID(Call->ReturnType));
+                            IsSet ? TypeReg : Register(0), ImmArgs);
 }
 
 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
index 3daf070b356de..77e07d5b83abc 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -25,6 +25,8 @@
 ; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]]
 ; CHECK: OpCooperativeMatrixStoreKHR
 
+target triple = "spir64-unknown-unknown"
+
 %"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
 %"class.sycl::_V1::detail::array" = type { [2 x i64] }
 %"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" }

>From 1dc762712329aac47aa41551428a2011fce416b3 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 19 Jun 2024 10:00:53 -0700
Subject: [PATCH 5/6] update test

---
 .../cooperative_matrix.ll                      | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
index 77e07d5b83abc..092a1a3df9228 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -83,15 +83,15 @@ for.body.i:
   %mul12.i = mul nsw i32 %k.0.i, 48
   %conv13.i = zext i32 %mul12.i to i64
   %add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
-  %call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4)
-  %call1.i.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4) %call.ascast.i66.i, i32 0, i64 %_arg_K, i32 1)
+  %call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(3)
+  %call1.i.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3) %call.ascast.i66.i, i32 0, i64 %_arg_K, i32 1)
   %len = tail call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) %call1.i.i)
   %div20.i = mul nsw i32 %k.0.i, 12
   %conv21.i = zext i32 %div20.i to i64
   %mul23.i = mul i64 %mul22.i, %conv21.i
   %add.ptr.i111.i = getelementptr i8, ptr addrspace(1) %add.ptr.i108140.i, i64 %mul23.i
-  %call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(4)
-  %call1.i73.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(4) %call.ascast.i72.i, i32 0, i64 %mul22.i)
+  %call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(3)
+  %call1.i73.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3) %call.ascast.i72.i, i32 0, i64 %mul22.i)
   call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i)
   %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
   %call.i77.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) %call1.i.i, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) %call1.i73.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i, i32 12)
@@ -107,19 +107,19 @@ exit:
   %add.ptr.i.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_accC, i64 %mul37.i
   %mul39.i = mul nuw i64 %div2452.i, 12
   %add.ptr.i81.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i, i64 %mul39.i
-  %call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(4)
+  %call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(3)
   %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
-  tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(4) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i32 0, i64 %_arg_N, i32 1)
+  tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i32 0, i64 %_arg_N, i32 1)
   call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
   ret void
 }
 
 declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32)
 declare dso_local spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0))
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(4), i32, i64, i32)
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(4), i32, i64)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64)
 declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0), target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32)
-declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(4), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32)
+declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32)
 
 declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
 declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)

>From 85436a933af5ac0a05ba9a3113bf1b0f828e4409 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 20 Jun 2024 05:03:52 -0700
Subject: [PATCH 6/6] fix instruction emission and update test

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  14 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  12 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |   2 +
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |   2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |   2 +-
 .../cooperative_matrix.ll                     | 123 ++++--------------
 .../SPIRV/transcoding/OpPtrCastToGeneric.ll   |  30 +++++
 7 files changed, 79 insertions(+), 106 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 35b97cf949b75..f5f36075d4a31 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1850,10 +1850,20 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
   else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
     LiteralIdx = 4;
   SmallVector<uint32_t, 1> ImmArgs;
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   if (LiteralIdx > 0)
-    ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx],
-                                            MIRBuilder.getMRI()));
+    ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
   Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+  if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
+    SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+    if (!CoopMatrType)
+      report_fatal_error("Can't find a register's type definition");
+    MIRBuilder.buildInstr(Opcode)
+        .addDef(Call->ReturnRegister)
+        .addUse(TypeReg)
+        .addUse(CoopMatrType->getOperand(0).getReg());
+    return true;
+  }
   return buildOpFromWrapper(MIRBuilder, Opcode, Call,
                             IsSet ? TypeReg : Register(0), ImmArgs);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6c6a4b3a14ddb..b8710d24bff94 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1080,12 +1080,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
   return IntType && IntType->getOperand(2).getImm() != 0;
 }
 
+SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
+  return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
+             ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
+             : nullptr;
+}
+
 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
-  SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg);
-  SPIRVType *ElemType =
-      PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
-          ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
-          : nullptr;
+  SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
   return ElemType ? ElemType->getOpcode() : 0;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 0ca9f58695dff..cc4e20b8247cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -292,6 +292,8 @@ class SPIRVGlobalRegistry {
     return Res->second;
   }
 
+  // Return a pointee's type, or nullptr otherwise.
+  SPIRVType *getPointeeType(SPIRVType *PtrType);
   // Return a pointee's type op code, or 0 otherwise.
   unsigned getPointeeTypeOp(Register PtrReg);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index bbc24dfbc9b10..63549b06e9670 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -874,7 +874,7 @@ def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res),
                   "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">;
 def OpCooperativeMatrixStoreKHR: Op<4458, (outs),
                   (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops),
-                  "OpCooperativeMatrixStoreKHR $pointer $objectToStore  $memory_layout">;
+                  "OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">;
 def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
                   (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops),
                   "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b9e5569029cfd..3134a9108e5e2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1105,7 +1105,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
     Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
-        SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
+        GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
     MachineBasicBlock &BB = *I.getParent();
     const DebugLoc &DL = I.getDebugLoc();
     bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
index 092a1a3df9228..1c41c7331cda8 100644
--- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll
@@ -7,119 +7,48 @@
 ; CHECK: OpCapability CooperativeMatrixKHR
 ; CHECK: OpExtension "SPV_KHR_cooperative_matrix"
 
-; CHECK-DAG: %[[#Int8Ty:]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
 ; CHECK-DAG: %[[#Const12:]] = OpConstant %[[#Int32Ty]] 12
 ; CHECK-DAG: %[[#Const48:]] = OpConstant %[[#Int32Ty]] 48
-; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int32Ty]] 0
 ; CHECK-DAG: %[[#Const3:]] = OpConstant %[[#Int32Ty]] 3
 ; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int32Ty]] 2
 ; CHECK-DAG: %[[#Const1:]] = OpConstant %[[#Int32Ty]] 1
+; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int32Ty]] 0
 ; CHECK-DAG: %[[#MatTy1:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const12]] %[[#Const2]]
-; CHECK-DAG: %[[#MatTy2:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const0]] %[[#Const12]] %[[#Const48]] %[[#Const0]]
-; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const2]] %[[#Const48]] %[[#Const12]] %[[#Const1]]
+; CHECK-DAG: %[[#MatTy2:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const48]] %[[#Const0]]
+; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const48]] %[[#Const12]] %[[#Const1]]
 ; CHECK: OpCompositeConstruct %[[#MatTy1]]
 ; CHECK: %[[#Load1:]] = OpCooperativeMatrixLoadKHR %[[#MatTy2]]
-; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#Load1]]
+; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#MatTy2]]
 ; CHECK: OpCooperativeMatrixLoadKHR %[[#MatTy3]]
 ; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]]
 ; CHECK: OpCooperativeMatrixStoreKHR
 
-target triple = "spir64-unknown-unknown"
-
-%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" }
-%"class.sycl::_V1::detail::array" = type { [2 x i64] }
-%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" }
-
-@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
-@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
-
-define weak_odr dso_local spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
+define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {
 entry:
-  %sub_c.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
-  %ref.tmp29.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
-  %agg.tmp15.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::range", ptr %_arg_accB5, i64 0, i32 0, i32 0, i64 1
-  %agg.tmp15.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp15.sroa.0.sroa.2.0..sroa_idx, align 8
-  %agg.tmp16.sroa.0.sroa.0.0.copyload = load i64, ptr %_arg_accB6, align 8
-  %agg.tmp16.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::id", ptr %_arg_accB6, i64 0, i32 0, i32 0, i64 1
-  %agg.tmp16.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp16.sroa.0.sroa.2.0..sroa_idx, align 8
-  %mul.i4.i.i.i.i45 = mul i64 %agg.tmp16.sroa.0.sroa.0.0.copyload, %agg.tmp15.sroa.0.sroa.2.0.copyload
-  %add.i6.i.i.i.i46 = add i64 %mul.i4.i.i.i.i45, %agg.tmp16.sroa.0.sroa.2.0.copyload
-  %add.ptr.i47 = getelementptr inbounds i8, ptr addrspace(1) %_arg_accB, i64 %add.i6.i.i.i.i46
-  %0 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
-  %1 = extractelement <3 x i64> %0, i64 1
-  %2 = extractelement <3 x i64> %0, i64 0
-  %3 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32
-  %4 = extractelement <3 x i64> %3, i64 1
-  %5 = extractelement <3 x i64> %3, i64 0
-  %cmp.i.i = icmp ult i64 %1, 2147483648
-  %cmp.i54.i = icmp ult i64 %2, 2147483648
-  %cmp.i56.i = icmp ult i64 %4, 2147483648
-  %sub.i = sub nsw i64 %1, %4
-  %cmp.i58.i = icmp ult i64 %5, 2147483648
-  %sub5.i = sub nsw i64 %2, %5
-  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
-  %call.i.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 0)
-  store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8
-  %mul.i = mul nsw i64 %sub.i, 12
-  %div2452.i = lshr i64 %sub5.i, 4
-  %mul26.i = mul i64 %div2452.i, 48
-  %div.i = udiv i64 %_arg_K, 48
-  %mul11.i = mul i64 %mul.i, %_arg_K
-  %add.ptr.i93.i = getelementptr inbounds i8, ptr addrspace(1) %_arg_accA, i64 %mul11.i
-  %idx.neg.i.i104.i = sub i64 0, %add.i6.i.i.i.i46
-  %add.ptr.i.i105141.i = getelementptr i8, ptr addrspace(1) %add.ptr.i47, i64 %mul26.i
-  %mul22.i = shl i64 %_arg_N, 2
-  %add.ptr.i108140.i = getelementptr i8, ptr addrspace(1) %add.ptr.i.i105141.i, i64 %idx.neg.i.i104.i
-  br label %for.cond.i
-
-for.cond.i:
-  %k.0.i = phi i32 [ 0, %entry ], [ %add.i, %for.body.i ]
-  %conv.i = zext i32 %k.0.i to i64
-  %cmp.i = icmp ugt i64 %div.i, %conv.i
-  br i1 %cmp.i, label %for.body.i, label %exit
-
-for.body.i:
-  %mul12.i = mul nsw i32 %k.0.i, 48
-  %conv13.i = zext i32 %mul12.i to i64
-  %add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
-  %call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(3)
-  %call1.i.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3) %call.ascast.i66.i, i32 0, i64 %_arg_K, i32 1)
-  %len = tail call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) %call1.i.i)
-  %div20.i = mul nsw i32 %k.0.i, 12
-  %conv21.i = zext i32 %div20.i to i64
-  %mul23.i = mul i64 %mul22.i, %conv21.i
-  %add.ptr.i111.i = getelementptr i8, ptr addrspace(1) %add.ptr.i108140.i, i64 %mul23.i
-  %call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(3)
-  %call1.i73.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3) %call.ascast.i72.i, i32 0, i64 %mul22.i)
-  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i)
-  %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
-  %call.i77.i = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) %call1.i.i, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) %call1.i73.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i, i32 12)
-  store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i77.i, ptr %ref.tmp29.sroa.0.i, align 8
-  %ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i = load i64, ptr %ref.tmp29.sroa.0.i, align 8
-  store i64 %ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i, ptr %sub_c.sroa.0.i, align 8
-  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i)
-  %add.i = add nuw nsw i32 %k.0.i, 1
-  br label %for.cond.i
-
-exit:
-  %mul37.i = mul i64 %mul.i, %_arg_N
-  %add.ptr.i.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_accC, i64 %mul37.i
-  %mul39.i = mul nuw i64 %div2452.i, 12
-  %add.ptr.i81.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i, i64 %mul39.i
-  %call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(3)
-  %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
-  tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3) %call.ascast.i.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i32 0, i64 %_arg_N, i32 1)
-  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
+  %addr1 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
+  %res = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
+  %m1 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 0)
+  store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m1, ptr %addr1, align 8
+  %accA3 = addrspacecast ptr addrspace(1) %_arg_accA to ptr addrspace(3)
+  %m2 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3) %accA3, i32 0, i64 %_arg_K, i32 1)
+  %len = tail call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) %m2)
+  %accB3 = addrspacecast ptr addrspace(1) %_arg_accB to ptr addrspace(3)
+  %m3 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3) %accB3, i32 0, i64 0)
+  %m4 = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %addr1, align 8
+  %m5 = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) %m2, target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) %m3, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m4, i32 12)
+  store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m5, ptr %res, align 8
+  %r = load i64, ptr %res, align 8
+  store i64 %r, ptr %addr1, align 8
+  %accC3 = addrspacecast ptr addrspace(1) %_arg_accC to ptr addrspace(3)
+  %m6 = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %addr1, align 8
+  tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3) %accC3, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %m6, i32 0, i64 %_arg_N, i32 1)
   ret void
 }
 
 declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32)
-declare dso_local spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0))
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32)
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64)
-declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 0, 12, 48, 0), target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32)
+declare dso_local spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0))
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64)
+declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0), target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32)
 declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32)
-
-declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
-declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll
new file mode 100644
index 0000000000000..818243ab19e41
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll
@@ -0,0 +1,30 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
+; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
+; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Arg1:]] = OpFunctionParameter %[[#GlobalCharPtr]]
+; CHECK-SPIRV: %[[#Ptr1:]] = OpPtrCastToGeneric %[[#GenericCharPtr]] %[[#Arg1]]
+; CHECK-SPIRV: OpGenericCastToPtr %[[#LocalCharPtr]] %[[#Ptr1]]
+; CHECK-SPIRV: OpFunctionEnd
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Arg2:]] = OpFunctionParameter %[[#GlobalCharPtr]]
+; CHECK-SPIRV: %[[#Ptr2:]] = OpPtrCastToGeneric %[[#GenericCharPtr]] %[[#Arg2]]
+; CHECK-SPIRV: OpGenericCastToPtr %[[#LocalCharPtr]] %[[#Ptr2]]
+; CHECK-SPIRV: OpFunctionEnd
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) {
+entry:
+  %p = addrspacecast ptr addrspace(1) %arg to ptr addrspace(3)
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %arg) {
+entry:
+  %p1 = addrspacecast ptr addrspace(1) %arg to ptr addrspace(4)
+  %p2 = addrspacecast ptr addrspace(4) %p1 to ptr addrspace(3)
+  ret void
+}



More information about the llvm-commits mailing list