[Mlir-commits] [mlir] [mlir][spirv][gpu] Clean up wmma to coop matrix NV conversion. NFC. (PR #66278)

Jakub Kuderski llvmlistbot at llvm.org
Wed Sep 13 12:27:16 PDT 2023


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/66278:

This is a cleanup in preparation for adding a second conversion path using the KHR cooperative matrix extension.

Make the existing lowering explicit about emitting ops from the NV coop matrix extension. Clean up surrounding code.

From 01aa362e74aa4a3885c4a8644952147c50616706 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 13 Sep 2023 15:24:34 -0400
Subject: [PATCH] [mlir][spirv][gpu] Clean up wmma to coop matrix NV
 conversion. NFC.

This is a cleanup in preparation for adding a second conversion path
using the KHR cooperative matrix extension.

Make the existing lowering explicit about emitting ops from the NV coop
matrix extension. Clean up surrounding code.
---
 .../mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h   | 14 ++--
 .../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp  | 11 +--
 .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp  | 75 ++++++++++---------
 ... => wmma-ops-to-spirv-nv-coop-matrix.mlir} |  2 +-
 4 files changed, 55 insertions(+), 47 deletions(-)
 rename mlir/test/Conversion/GPUToSPIRV/{wmma-ops-to-spirv.mlir => wmma-ops-to-spirv-nv-coop-matrix.mlir} (98%)

diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 3c3281513f60d89..6c4643da1884900 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,15 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
-void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
-                                              RewritePatternSet &patterns);
-
-spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the NV Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
+/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixNVType
+convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index f37c70a771f5916..d0ce58597f980d4 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -39,8 +39,7 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
-public:
+struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
   explicit GPUToSPIRVPass(bool mapMemorySpace)
       : mapMemorySpace(mapMemorySpace) {}
   void runOnOperation() override;
@@ -48,7 +47,6 @@ class GPUToSPIRVPass : public impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
 private:
   bool mapMemorySpace;
 };
-} // namespace
 
 void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
@@ -89,11 +87,12 @@ void GPUToSPIRVPass::runOnOperation() {
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
     typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToSPIRVType(type);
+      return convertMMAToSPIRVCoopMatrixNVType(type);
     });
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
+    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                         patterns);
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
@@ -105,6 +104,8 @@ void GPUToSPIRVPass::runOnOperation() {
   }
 }
 
+} // namespace
+
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
   return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 3851fb728b6654b..bf3fff027fe384a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -1,4 +1,4 @@
-//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains definitions of patterns to lower GPU Subgroup MMA ops to
-// SPIRV Dialect ops.
+// SPIRV Cooperative Matrix ops.
 //
 //===----------------------------------------------------------------------===//
 
@@ -22,7 +22,8 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/TypeUtilities.h"
 
-using namespace mlir;
+namespace mlir::nv {
+namespace {
 
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
@@ -70,12 +71,10 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
-namespace {
-
-/// This class implements the conversion of GPU MMA loadOp to
-/// CooperativeMatrixLoad op in the SPIRV dialect.
-struct WmmaLoadOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -90,7 +89,7 @@ struct WmmaLoadOpToSPIRVLowering
     Value bufferPtr = spirv::getElementPtr(
         *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
         adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
-    auto coopType = convertMMAToSPIRVType(retType);
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
     int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -105,10 +104,10 @@ struct WmmaLoadOpToSPIRVLowering
   }
 };
 
-/// This class implements the conversion of GPU MMA StoreOp to
-/// CooperativeMatrixStore op in the SPIRV dialect.
-struct WmmaStoreOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+/// Converts the GPU MMA StoreOp to NVCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -136,10 +135,10 @@ struct WmmaStoreOpToSPIRVLowering
   }
 };
 
-/// This class implements the conversion of GPU MMA Compute to
-/// CooperativeMatrixMulAdd op in the SPIRV dialect.
-struct WmmaMmaOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+/// Converts GPU MMA Compute to
+/// NVCooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -153,9 +152,10 @@ struct WmmaMmaOpToSPIRVLowering
   }
 };
 
-/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
-struct WmmaConstantOpToSPIRVLowering
-    : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
+/// ops.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -163,7 +163,7 @@ struct WmmaConstantOpToSPIRVLowering
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         subgroupMmaConstantMatrixOp, coopType, cst);
@@ -173,8 +173,8 @@ struct WmmaConstantOpToSPIRVLowering
 
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering
-    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -186,7 +186,7 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
       if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
         return failure();
     }
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
                                        adaptor.getOperands()));
@@ -195,8 +195,8 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
 
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering
-    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -238,7 +238,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
     assert(cc.getConstituents().size() == 1);
     scalar = cc.getConstituents().front();
 
-    auto coopType = convertMMAToSPIRVType(
+    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
         cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
         elementwiseOp, coopType, ValueRange{matrix, scalar});
@@ -247,23 +247,26 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
 };
 
 } // namespace
+} // namespace mlir::nv
 
-/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
 mlir::spirv::CooperativeMatrixNVType
-mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
   ArrayRef<int64_t> retTypeShape = type.getShape();
   Type elementType = type.getElementType();
   return spirv::CooperativeMatrixNVType::get(
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
 
-void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  using namespace mlir;
   MLIRContext *context = patterns.getContext();
-  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
-               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
-               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  patterns
+      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
+           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
-                                                          /*benefit=*/2);
+  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
+                                                              context,
+                                                              /*benefit=*/2);
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
similarity index 98%
rename from mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
rename to mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
index a53eca65fc98699..5811c791f308d1e 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
   gpu.container_module,



More information about the Mlir-commits mailing list