[Mlir-commits] [mlir] [mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix (PR #66455)

Jakub Kuderski llvmlistbot at llvm.org
Thu Sep 14 18:29:39 PDT 2023


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/66455:

These do not produce extension-specific ops and are handled via common
patterns for both the KHR and the NV coop matrix extensions.

Also improve match failure reporting and error handling in type conversion.

>From 80f437c0f6bd139ea02fb4cb32fd5922cebd4914 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 13 Sep 2023 16:30:46 -0400
Subject: [PATCH 1/2] [mlir][spirv][gpu] Add conversion for load/store/mad coop
 matrix ops

This is plugged in as an alternative lowering path in the GPU-to-SPIR-V
dialect conversion.

The remaining lowering patterns will be added in a future patch.
---
 .../mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h   |  10 ++
 mlir/include/mlir/Conversion/Passes.td        |   6 +-
 .../SPIRV/IR/SPIRVCooperativeMatrixOps.td     |  25 +++
 .../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp  |  20 ++-
 .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp  | 145 +++++++++++++++++-
 .../wmma-ops-to-spirv-khr-coop-matrix.mlir    |  80 ++++++++++
 .../wmma-ops-to-spirv-nv-coop-matrix.mlir     |   3 +-
 7 files changed, 278 insertions(+), 11 deletions(-)
 create mode 100644 mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 6c4643da1884900..c258513ed4878ea 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,21 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the KHR Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
 /// using the NV Cooperative Matrix extension.
 void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Returns a subgroup-scoped KHR cooperative matrix type corresponding to the
+/// MMAMatrixType `type`, with the matrix use derived from its operand kind.
+spirv::CooperativeMatrixType
+convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
+
 /// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
 /// `type`.
 spirv::CooperativeMatrixNVType
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3218760931b8cb0..5e0f976b18f7da5 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
   let options = [
     Option<"use64bitIndex", "use-64bit-index",
            "bool", /*default=*/"false",
-           "Use 64-bit integers to convert index types">
+           "Use 64-bit integers to convert index types">,
+    Option<"useCoopMatrixNV", "use-coop-matrix-nv",
+           "bool", /*default=*/"false",
+           "Use the NV cooperative matrix extension instead of the KHR"
+           " extension to lower GPU WMMA ops">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index b5ea0774f589d16..34c76c5e9382302 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Type":$result, "Value":$pointer,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, result, pointer, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   );
 
   let results = (outs);
+
+  let builders = [
+    OpBuilder<(ins "Value":$pointer, "Value":$object,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, pointer, object, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
+      build($_builder, $_state, a, b, c,
+            spirv::CooperativeMatrixOperandsKHRAttr{});
+    }]>
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index d0ce58597f980d4..5b05c45bf602509 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
     SPIRVConversionOptions options;
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
-    typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToSPIRVCoopMatrixNVType(type);
+
+    typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
+                                    gpu::MMAMatrixType type) -> Type {
+      if (useNV)
+        return convertMMAToSPIRVCoopMatrixNVType(type);
+
+      return convertMMAToSPIRVCoopMatrixType(type);
     });
+
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
-                                                         patterns);
+    if (this->useCoopMatrixNV) {
+      populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                           patterns);
+    } else {
+      populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+                                                            patterns);
+    }
+
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index bf3fff027fe384a..d73cd5686d66e92 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -18,22 +18,28 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/StringSwitch.h"
 
-namespace mlir::nv {
-namespace {
+#include <cassert>
 
+namespace mlir {
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
 /// Returns false if cannot.
 ///
 /// See SPV_NV_cooperative_matrix for supported elementwise ops.
 static bool createElementwiseOp(ConversionPatternRewriter &builder,
-                                gpu::SubgroupMmaElementwiseOp op,
-                                spirv::CooperativeMatrixNVType coopType,
+                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                 ValueRange operands) {
+  assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      coopType)));
+
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace khr {
+namespace {
+
+/// Converts the GPU MMA load op to `spirv.KHR.CooperativeMatrixLoad`. A set
+/// `transpose` attribute selects the column-major layout.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    // Convert the result type before building any IR.
+    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
+    auto coopType =
+        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    MemRefType memrefType = op.getSrcMemref().getType();
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!bufferPtr)
+      return rewriter.notifyMatchFailure(op, "failed to compute element ptr");
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
+        op, coopType, bufferPtr, strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts the GPU MMA store op to `spirv.KHR.CooperativeMatrixStore`. A set
+/// `transpose` attribute selects the column-major layout.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!bufferPtr)
+      return rewriter.notifyMatchFailure(op, "failed to compute element ptr");
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
+        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts the GPU MMA compute op to `spirv.KHR.CooperativeMatrixMulAdd`.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
+        op, adaptor.getOpA(), adaptor.getOpB(), adaptor.getOpC());
+    return success();
+  }
+};
+
+} // namespace
+} // namespace khr
+
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace nv {
+namespace {
+
 /// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
 /// dialect.
 struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
 };
 
 } // namespace
-} // namespace mlir::nv
+} // namespace nv
+} // namespace mlir
 
 mlir::spirv::CooperativeMatrixNVType
 mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
 
+mlir::spirv::CooperativeMatrixType
+mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+
+  // "AOp" -> MatrixA, "BOp" -> MatrixB, anything else ("COp") -> MatrixAcc.
+  spirv::CooperativeMatrixUseKHR use =
+      llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
+          .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
+          .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
+          .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
+
+  return spirv::CooperativeMatrixType::get(type.getElementType(), shape[0],
+                                           shape[1], spirv::Scope::Subgroup,
+                                           use);
+}
+
+void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  using namespace mlir;
+  MLIRContext *context = patterns.getContext();
+  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
+               khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+}
+
 void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
new file mode 100644
index 000000000000000..0818791b98471da
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
+// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+    [Shader, CooperativeMatrixKHR, Float16],
+    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
+    #spirv.resource_limits<>>} {
+
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK:      %[[STRIDE:.+]] = spirv.Constant 32 : i32
+      // CHECK:      spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
+      // CHECK-SAME:   !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
+        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      // CHECK:      spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
+      // CHECK-SAME:   !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
+        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
+                                %arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK:      %[[STRIDE:.+]] = spirv.Constant 32 : i32
+      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
+      // CHECK-SAME:  !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+
+      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
+      // CHECK-SAME:  !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+       // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
+                              %B: !gpu.mma_matrix<16x16xf16, "BOp">,
+                              %C: !gpu.mma_matrix<16x16xf16, "COp">,
+                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK:      %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
+      // CHECK-SAME:   -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
+                                                 !gpu.mma_matrix<16x16xf16, "BOp">
+                                                 -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+  }
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
index 5811c791f308d1e..ec7da92704c07c2 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
+// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
   gpu.container_module,

>From 98a33240cf5094354d9438416a33a351859755f0 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 14 Sep 2023 21:06:43 -0400
Subject: [PATCH 2/2] [mlir][spirv][gpu] Convert remaining wmma ops to KHR coop
 matrix

These do not produce extension-specific ops and are handled via common
patterns for both the KHR and the NV coop matrix extensions.

Also improve match failure reporting and error handling in type
conversion.
---
 .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp  | 231 ++++++++++--------
 .../wmma-ops-to-spirv-khr-coop-matrix.mlir    |  96 +++++++-
 2 files changed, 224 insertions(+), 103 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index d73cd5686d66e92..eb7fcb63d920d8f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,11 +24,17 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
 #include <cassert>
 
 namespace mlir {
+//===----------------------------------------------------------------------===//
+// Patterns and helpers used by both the KHR and the NV lowering paths.
+//===----------------------------------------------------------------------===//
+
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
 /// Returns false if cannot.
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+/// Returns true if all operands share one KHR or NV cooperative matrix type.
+static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
+  assert(!operands.empty());
+  if (!llvm::all_equal(
+          llvm::map_range(operands, [](Value v) { return v.getType(); })))
+    return false;
+  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      operands.front().getType());
+}
+
+namespace {
+/// Converts GPU MMA ConstantMatrixOp to a `spirv.CompositeConstruct` of its
+/// scalar operand, yielding a KHR or NV cooperative matrix.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    Value cst = adaptor.getOperands().front();
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
+    return success();
+  }
+};
+
+/// Converts elementwise ops on operands of one shared cooperative matrix type
+/// to the corresponding SPIR-V op; see `createElementwiseOp`.
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    if (!createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()))
+      return rewriter.notifyMatchFailure(op, "unsupported elementwise op");
+    return success();
+  }
+};
+
+/// Converts a MULF elementwise op with one splat constant matrix operand to
+/// `spirv.MatrixTimesScalar`.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check the op kind and arity first; they are the cheapest filters.
+    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
+      return rewriter.notifyMatchFailure(op, "not a MULF op");
+    if (adaptor.getOperands().size() != 2)
+      return rewriter.notifyMatchFailure(op, "expected two operands");
+
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    // Use the original operands to check whether one of the operands is a splat
+    // scalar value.
+    Value lhs = op.getOperands().front();
+    Value rhs = op.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return rewriter.notifyMatchFailure(op, "no splat operand");
+
+    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc) {
+      return rewriter.notifyMatchFailure(op,
+                                         "splat is not a composite construct");
+    }
+
+    assert(cc.getConstituents().size() == 1);
+    Value scalar = cc.getConstituents().front();
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        op, coopType, ValueRange{matrix, scalar});
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // SPV_KHR_cooperative_matrix
 //===----------------------------------------------------------------------===//
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
   }
 };
 
-/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
-/// ops.
-struct WmmaConstantOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        subgroupMmaConstantMatrixOp, coopType, cst);
-    return success();
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
-                                       adaptor.getOperands()));
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (adaptor.getOperands().size() != 2)
-      return failure();
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-
-    if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
-      return failure();
-
-    // Use the original operands to check whether one of the operands is a splat
-    // scalar value.
-    Value lhs = elementwiseOp.getOperands().front();
-    Value rhs = elementwiseOp.getOperands().back();
-    Value splat = nullptr;
-    Value matrix = nullptr;
-    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
-      splat = adaptor.getOperands().front();
-      matrix = adaptor.getOperands().back();
-    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
-      matrix = adaptor.getOperands().front();
-      splat = adaptor.getOperands().back();
-    }
-    if (!splat || !matrix)
-      return failure();
-
-    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
-    Value scalar = nullptr;
-    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
-    if (!cc)
-      return failure();
-    assert(cc.getConstituents().size() == 1);
-    scalar = cc.getConstituents().front();
-
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
-        elementwiseOp, coopType, ValueRange{matrix, scalar});
-    return success();
-  }
-};
-
 } // namespace
 } // namespace nv
 } // namespace mlir
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   using namespace mlir;
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
-               khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  // Give the following patterns higher benefit to prevail over the default one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
 
 void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
   MLIRContext *context = patterns.getContext();
-  patterns
-      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
-           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
-           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+               nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
-                                                              context,
-                                                              /*benefit=*/2);
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 0818791b98471da..f129cc8ce84ec39 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -69,12 +69,106 @@ module attributes {
                                                  -> !gpu.mma_matrix<16x16xf16, "COp">
 
       %i = arith.constant 0 : index
-      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
+      // CHECK:      spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
       gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
         !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
     }
 
+    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+    gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK:       %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
+      %cst = arith.constant 1.0 : f16
+      // CHECK:       %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
+      // CHECK-SAME:   (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK:      spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
+                                              %B: !gpu.mma_matrix<16x16xf16, "COp">,
+                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK:  {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_elementwise negatef %C :
+        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %E = gpu.subgroup_mma_elementwise divf %D, %A :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK:  {{%.*}} = spirv.FConvert {{%.*}} :
+      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+      %F = gpu.subgroup_mma_elementwise extf %E :
+        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME:    %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
+      %A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
+      %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 0 : index
+
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+
+      // CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
+      %D = gpu.subgroup_mma_elementwise mulf %B, %C :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
+    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME:    %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
+      %A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
+      %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 0 : index
+
+      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
   }
 }



More information about the Mlir-commits mailing list