[Mlir-commits] [mlir] [mlir][spirv][gpu] Add conversion for load/store/mad coop matrix ops (PR #66311)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 13 18:42:02 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core
            
<details>
<summary>Changes</summary>
The new KHR cooperative matrix conversion is plugged in as an alternative lowering path in the GPU to SPIR-V dialect conversion. This patch also adds custom op builders for the cooperative matrix ops to make the `create` functions nicer to work with and less error-prone. The latter is accomplished by having the builder arguments follow the op syntax and by requiring the stride to be a constant op, which avoids confusion around the order of arguments.

The remaining lowering patterns will be added in a future patch.
--
Full diff: https://github.com/llvm/llvm-project/pull/66311.diff

7 Files Affected:

- (modified) mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h (+10) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+5-1) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+25) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+16-4) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+140-5) 
- (added) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (+80) 
- (modified) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir (+2-1) 


<pre>
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 6c4643da1884900..c258513ed4878ea 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,21 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the KHR Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
 /// using the NV Cooperative Matrix extension.
 void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixType
+convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
+
 /// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
 /// `type`.
 spirv::CooperativeMatrixNVType
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3218760931b8cb0..5e0f976b18f7da5 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
   let options = [
     Option<"use64bitIndex", "use-64bit-index",
            "bool", /*default=*/"false",
-           "Use 64-bit integers to convert index types">
+           "Use 64-bit integers to convert index types">,
+    Option<"useCoopMatrixNV", "use-coop-matrix-nv",
+           "bool", /*default=*/"false",
+           "Use the NV cooperative matrix extension instead of the KHR extension"
+           " to lower GPU WMMA ops">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index b5ea0774f589d16..34c76c5e9382302 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Type":$result, "Value":$pointer,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, result, pointer, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   );
 
   let results = (outs);
+
+  let builders = [
+    OpBuilder<(ins "Value":$pointer, "Value":$object,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, pointer, object, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
+      build($_builder, $_state, a, b, c,
+            spirv::CooperativeMatrixOperandsKHRAttr{});
+    }]>
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index d0ce58597f980d4..5b05c45bf602509 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
     SPIRVConversionOptions options;
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
-    typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToSPIRVCoopMatrixNVType(type);
+
+    typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
+                                    gpu::MMAMatrixType type) -> Type {
+      if (useNV)
+        return convertMMAToSPIRVCoopMatrixNVType(type);
+
+      return convertMMAToSPIRVCoopMatrixType(type);
     });
+
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
-                                                         patterns);
+    if (this->useCoopMatrixNV) {
+      populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                           patterns);
+    } else {
+      populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+                                                            patterns);
+    }
+
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index bf3fff027fe384a..d73cd5686d66e92 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -18,22 +18,28 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/StringSwitch.h"
 
-namespace mlir::nv {
-namespace {
+#include <cassert>
 
+namespace mlir {
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
 /// Returns false if cannot.
 ///
 /// See SPV_NV_cooperative_matrix for supported elementwise ops.
 static bool createElementwiseOp(ConversionPatternRewriter &builder,
-                                gpu::SubgroupMmaElementwiseOp op,
-                                spirv::CooperativeMatrixNVType coopType,
+                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                 ValueRange operands) {
+  assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      coopType)));
+
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace khr {
+namespace {
+
+/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
+    MemRefType memrefType = op.getSrcMemref().getType();
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    auto coopType =
+        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
+        op, coopType, bufferPtr, strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
+        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
+/// dialect.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
+        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
+        adaptor.getOpC());
+    return success();
+  }
+};
+
+} // namespace
+} // namespace khr
+
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace nv {
+namespace {
+
 /// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
 /// dialect.
 struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
 };
 
 } // namespace
-} // namespace mlir::nv
+} // namespace nv
+} // namespace mlir
 
 mlir::spirv::CooperativeMatrixNVType
 mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
 
+mlir::spirv::CooperativeMatrixType
+mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
+  ArrayRef<int64_t> retTypeShape = type.getShape();
+  Type elementType = type.getElementType();
+
+  auto use =
+      llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
+          .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
+          .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
+          .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
+
+  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
+                                           retTypeShape[1],
+                                           spirv::Scope::Subgroup, use);
+}
+
+void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  using namespace mlir;
+  MLIRContext *context = patterns.getContext();
+  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
+               khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+}
+
 void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
new file mode 100644
index 000000000000000..0818791b98471da
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
+// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+    [Shader, CooperativeMatrixKHR, Float16],
+    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
+    #spirv.resource_limits<>>} {
+
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK:      %[[STRIDE:.+]] = spirv.Constant 32 : i32
+      // CHECK:      spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
+      // CHECK-SAME:   !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
+        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      // CHECK:      spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
+      // CHECK-SAME:   !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
+        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
+                                %arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK:      %[[STRIDE:.+]] = spirv.Constant 32 : i32
+      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
+      // CHECK-SAME:  !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+
+      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
+      // CHECK-SAME:  !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+       // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
+                              %B: !gpu.mma_matrix<16x16xf16, "BOp">,
+                              %C: !gpu.mma_matrix<16x16xf16, "COp">,
+                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK:      %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
+      // CHECK-SAME:   -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
+                                                 !gpu.mma_matrix<16x16xf16, "BOp">
+                                                 -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+  }
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
index 5811c791f308d1e..ec7da92704c07c2 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
+// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
   gpu.container_module,
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66311


More information about the Mlir-commits mailing list