[Mlir-commits] [mlir] [mlir][spirv][gpu] Default to KHR coop matrix. Clean up type conversion. (PR #67485)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 13:34:05 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

<details>
<summary>Changes</summary>

- Now that the KHR cooperative matrix implementation is robust, switch the GPU-to-SPIR-V conversion pass to default to it.
- Use a single populate function for the MMA-to-cooperative-matrix type conversions instead of exposing per-type conversion helpers, which shrinks the API surface area.

---
Full diff: https://github.com/llvm/llvm-project/pull/67485.diff


4 Files Affected:

- (modified) mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h (+5-13) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+1-1) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+2-8) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+39-28) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index c258513ed4878ea..cd650345f1daa2d 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -20,10 +20,6 @@
 namespace mlir {
 class SPIRVTypeConverter;
 
-namespace gpu {
-class MMAMatrixType;
-} // namespace gpu
-
 /// Appends to a pattern list additional patterns for translating GPU Ops to
 /// SPIR-V ops. For a gpu.func to be converted, it should have a
 /// spirv.entry_point_abi attribute.
@@ -40,15 +36,11 @@ void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
 void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
 
-/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
-/// `type`.
-spirv::CooperativeMatrixType
-convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
-
-/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
-/// `type`.
-spirv::CooperativeMatrixNVType
-convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type);
+/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type
+/// conversion to the type converter. Defaults to KHR cooperative matrix types.
+/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
+void populateMMAToSPIRVCoopMatrixTypeConversion(
+    SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 11008baa0160efe..5e0f976b18f7da5 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -569,7 +569,7 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
            "bool", /*default=*/"false",
            "Use 64-bit integers to convert index types">,
     Option<"useCoopMatrixNV", "use-coop-matrix-nv",
-           "bool", /*default=*/"true",
+           "bool", /*default=*/"false",
            "Use the NV cooperative matrix extension insted of the KHR extension"
            " to lower GPU WMMA ops">,
   ];
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 5b05c45bf602509..272e3de8723aeb6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,14 +86,8 @@ void GPUToSPIRVPass::runOnOperation() {
     SPIRVConversionOptions options;
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
-
-    typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
-                                    gpu::MMAMatrixType type) -> Type {
-      if (useNV)
-        return convertMMAToSPIRVCoopMatrixNVType(type);
-
-      return convertMMAToSPIRVCoopMatrixType(type);
-    });
+    populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter,
+                                               this->useCoopMatrixNV);
 
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index eb7fcb63d920d8f..4a4281aaaf0dbc4 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -311,14 +311,21 @@ struct WmmaLoadOpToSPIRVLowering final
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = subgroupMmaLoadMatrixOp->getLoc();
+    auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+
     gpu::MMAMatrixType retType =
         cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
     auto memrefType =
         cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
-    Value bufferPtr = spirv::getElementPtr(
-        *getTypeConverter<const SPIRVTypeConverter>(), memrefType,
-        adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(retType);
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+    auto coopType =
+        typeConverter.convertType<spirv::CooperativeMatrixNVType>(retType);
+    if (!coopType)
+      return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
+                                         "type conversion failed");
+
     int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
@@ -385,30 +392,6 @@ struct WmmaMmaOpToSPIRVLowering final
 } // namespace nv
 } // namespace mlir
 
-mlir::spirv::CooperativeMatrixNVType
-mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
-  ArrayRef<int64_t> retTypeShape = type.getShape();
-  Type elementType = type.getElementType();
-  return spirv::CooperativeMatrixNVType::get(
-      elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
-}
-
-mlir::spirv::CooperativeMatrixType
-mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
-  ArrayRef<int64_t> retTypeShape = type.getShape();
-  Type elementType = type.getElementType();
-
-  auto use =
-      llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
-          .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
-          .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
-          .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
-
-  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
-                                           retTypeShape[1],
-                                           spirv::Scope::Subgroup, use);
-}
-
 void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
@@ -432,3 +415,31 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
   patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
                                                           /*benefit=*/2);
 }
+
+void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
+    mlir::SPIRVTypeConverter &typeConverter, bool useNVTypes) {
+  if (useNVTypes) {
+    typeConverter.addConversion([](gpu::MMAMatrixType type) {
+      ArrayRef<int64_t> retTypeShape = type.getShape();
+      Type elementType = type.getElementType();
+      return spirv::CooperativeMatrixNVType::get(
+          elementType, spirv::Scope::Subgroup, retTypeShape[0],
+          retTypeShape[1]);
+    });
+    return;
+  }
+
+  typeConverter.addConversion([](gpu::MMAMatrixType type) {
+    ArrayRef<int64_t> retTypeShape = type.getShape();
+    Type elementType = type.getElementType();
+    auto use =
+        llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
+            .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
+            .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
+            .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
+
+    return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
+                                             retTypeShape[1],
+                                             spirv::Scope::Subgroup, use);
+  });
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/67485


More information about the Mlir-commits mailing list