[Mlir-commits] [mlir] [mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix (PR #66455)
llvmlistbot at llvm.org
Thu Sep 14 18:30:50 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-core
Changes:
The remaining WMMA ops (the constant-matrix and elementwise ops) do not lower to
extension-specific SPIR-V ops, so they are handled by patterns shared between the
KHR and the NV cooperative matrix extensions.
The patch also improves match failure reporting and error handling in type conversion.
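For context, a rough sketch of the intended usage: because the new option defaults to false, the KHR lowering is now the default for the convert-gpu-to-spirv pass and the NV path is opt-in. The RUN lines and type mapping below are illustrative assumptions (the usual pass-option syntax, an accumulator "COp" matrix); the added and updated test files in the patch are authoritative.
<pre>
// Hypothetical RUN lines: the KHR path is the default, the NV path is opt-in.
// RUN: mlir-opt --convert-gpu-to-spirv %s
// RUN: mlir-opt --convert-gpu-to-spirv=use-coop-matrix-nv=true %s

// Illustrative type conversion applied to an accumulator ("COp") matrix:
//   KHR: !gpu.mma_matrix<16x16xf16, "COp"> -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
//   NV:  !gpu.mma_matrix<16x16xf16, "COp"> -> !spirv.NV.coopmatrix<16x16xf16, Subgroup>
</pre>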
--
Patch is 36.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66455.diff
7 Files Affected:
- (modified) mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h (+10)
- (modified) mlir/include/mlir/Conversion/Passes.td (+5-1)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+25)
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+16-4)
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+268-106)
- (added) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (+174)
- (modified) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir (+2-1)
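Before the diff, here is a rough before/after sketch of the lowering that the new wmma-ops-to-spirv-khr-coop-matrix.mlir test exercises. The input below is a minimal illustrative function (the real test wraps kernels in a gpu.module with a SPIR-V target environment), and the expected output in the trailing comments only approximates the printed SPIR-V assembly:
<pre>
// Input (illustrative): load A/B/C tiles, multiply-accumulate, store the result.
func.func @wmma_khr_sketch(%buf: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
                           %i: index, %j: index) {
  %a = gpu.subgroup_mma_load_matrix %buf[%i, %j] {leadDimension = 32 : index}
      : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "AOp">
  %b = gpu.subgroup_mma_load_matrix %buf[%i, %j] {leadDimension = 32 : index}
      : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "BOp">
  %c = gpu.subgroup_mma_load_matrix %buf[%i, %j] {leadDimension = 32 : index}
      : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
  %d = gpu.subgroup_mma_compute %a, %b, %c
      : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
      -> !gpu.mma_matrix<16x16xf16, "COp">
  gpu.subgroup_mma_store_matrix %d, %buf[%i, %j] {leadDimension = 32 : index}
      : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
  return
}

// Roughly expected output on the KHR path: the lead dimension becomes an i32
// constant stride, and the optional `transpose` attribute selects the layout.
//   %ptr = spirv.AccessChain ...                 // element pointer into %buf
//   %c32 = spirv.Constant 32 : i32               // stride
//   %A   = spirv.KHR.CooperativeMatrixLoad %ptr, %c32, <RowMajor>
//            : !spirv.ptr<f16, StorageBuffer>, i32
//            -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
//   %D   = spirv.KHR.CooperativeMatrixMulAdd %A, %B, %C
//            : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
//              !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
//            -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
//   spirv.KHR.CooperativeMatrixStore %ptr, %D, %c32, <RowMajor>
//            : !spirv.ptr<f16, StorageBuffer>, i32
</pre>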
<pre>
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 6c4643da1884900..c258513ed4878ea 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,21 @@ class MMAMatrixType;
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the KHR Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+ SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
/// using the NV Cooperative Matrix extension.
void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixType
+convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
+
/// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
/// `type`.
spirv::CooperativeMatrixNVType
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3218760931b8cb0..5e0f976b18f7da5 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
let options = [
Option<"use64bitIndex", "use-64bit-index",
"bool", /*default=*/"false",
- "Use 64-bit integers to convert index types">
+ "Use 64-bit integers to convert index types">,
+ Option<"useCoopMatrixNV", "use-coop-matrix-nv",
+ "bool", /*default=*/"false",
+ "Use the NV cooperative matrix extension insted of the KHR extension"
+ " to lower GPU WMMA ops">,
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index b5ea0774f589d16..34c76c5e9382302 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
let results = (outs
SPIRV_AnyCooperativeMatrix:$result
);
+
+ let builders = [
+ OpBuilder<(ins "Type":$result, "Value":$pointer,
+ "spirv::ConstantOp":$stride,
+ "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+ build($_builder, $_state, result, pointer, layout, stride,
+ spirv::MemoryAccessAttr{});
+ }]>
+ ];
}
// -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
);
let results = (outs);
+
+ let builders = [
+ OpBuilder<(ins "Value":$pointer, "Value":$object,
+ "spirv::ConstantOp":$stride,
+ "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+ build($_builder, $_state, pointer, object, layout, stride,
+ spirv::MemoryAccessAttr{});
+ }]>
+ ];
}
// -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
let results = (outs
SPIRV_AnyCooperativeMatrix:$result
);
+
+ let builders = [
+ OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
+ build($_builder, $_state, a, b, c,
+ spirv::CooperativeMatrixOperandsKHRAttr{});
+ }]>
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index d0ce58597f980d4..5b05c45bf602509 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
- typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
- return convertMMAToSPIRVCoopMatrixNVType(type);
+
+ typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
+ gpu::MMAMatrixType type) -> Type {
+ if (useNV)
+ return convertMMAToSPIRVCoopMatrixNVType(type);
+
+ return convertMMAToSPIRVCoopMatrixType(type);
});
+
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
- populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
- patterns);
+ if (this->useCoopMatrixNV) {
+ populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+ patterns);
+ } else {
+ populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+ patterns);
+ }
+
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index bf3fff027fe384a..eb7fcb63d920d8f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -18,12 +18,22 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
-namespace mlir::nv {
-namespace {
+#include <cassert>
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// Patterns and helpers used by both the KHR and the NV lowering paths.
+//===----------------------------------------------------------------------===//
/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports with cooperative matrix type.
@@ -31,9 +41,11 @@ namespace {
///
/// See SPV_NV_cooperative_matrix for supported elementwise ops.
static bool createElementwiseOp(ConversionPatternRewriter &builder,
- gpu::SubgroupMmaElementwiseOp op,
- spirv::CooperativeMatrixNVType coopType,
+ gpu::SubgroupMmaElementwiseOp op, Type coopType,
ValueRange operands) {
+ assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+ coopType)));
+
switch (op.getOpType()) {
case gpu::MMAElementwiseOp::ADDF:
builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +83,223 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}
+bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
+ assert(!operands.empty());
+ if (!llvm::all_equal(
+ llvm::map_range(operands, [](Value v) { return v.getType(); })))
+ return false;
+
+ return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+ operands.front().getType());
+}
+
+namespace {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaConstantOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(adaptor.getOperands().size() == 1);
+ Value cst = adaptor.getOperands().front();
+ auto coopType = getTypeConverter()->convertType(op.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
+ return success();
+ }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+ : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // All operands should be of cooperative matrix types.
+ if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+ return rewriter.notifyMatchFailure(op,
+ "not all operands are coop matrices");
+ }
+
+ auto coopType = getTypeConverter()->convertType(op.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ return success(
+ createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
+ }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// matrix times scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+ : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (adaptor.getOperands().size() != 2)
+ return failure();
+
+ // All operands should be of cooperative matrix types.
+ if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+ return rewriter.notifyMatchFailure(op,
+ "not all operands are coop matrices");
+ }
+
+ if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
+ return failure();
+
+ // Use the original operands to check whether one of the operands is a splat
+ // scalar value.
+ Value lhs = op.getOperands().front();
+ Value rhs = op.getOperands().back();
+ Value splat = nullptr;
+ Value matrix = nullptr;
+ if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+ splat = adaptor.getOperands().front();
+ matrix = adaptor.getOperands().back();
+ } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+ matrix = adaptor.getOperands().front();
+ splat = adaptor.getOperands().back();
+ }
+ if (!splat || !matrix)
+ return rewriter.notifyMatchFailure(op, "no splat operand");
+
+ // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
+ Value scalar;
+ auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+ if (!cc) {
+ return rewriter.notifyMatchFailure(op,
+ "splat is not a composite construct");
+ }
+
+ assert(cc.getConstituents().size() == 1);
+ scalar = cc.getConstituents().front();
+
+ auto coopType = getTypeConverter()->convertType(op.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+ rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+ op, coopType, ValueRange{matrix, scalar});
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace khr {
+namespace {
+
+/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Location loc = op->getLoc();
+
+ auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
+ MemRefType memrefType = op.getSrcMemref().getType();
+ Value bufferPtr =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+ adaptor.getIndices(), loc, rewriter);
+
+ auto coopType =
+ typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ int64_t stride = op.getLeadDimension().getSExtValue();
+ IntegerType i32Type = rewriter.getI32Type();
+ auto strideValue = rewriter.create<spirv::ConstantOp>(
+ loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+ bool isColMajor = op.getTranspose().value_or(false);
+ auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+ : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+ rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
+ op, coopType, bufferPtr, strideValue, layout);
+ return success();
+ }
+};
+
+/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Location loc = op->getLoc();
+
+ auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
+ Value bufferPtr =
+ spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
+ adaptor.getIndices(), loc, rewriter);
+
+ int64_t stride = op.getLeadDimension().getSExtValue();
+ IntegerType i32Type = rewriter.getI32Type();
+ auto strideValue = rewriter.create<spirv::ConstantOp>(
+ loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+ bool isColMajor = op.getTranspose().value_or(false);
+ auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+ : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+ rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
+ op, bufferPtr, adaptor.getSrc(), strideValue, layout);
+ return success();
+ }
+};
+
+/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
+/// dialect.
+struct WmmaMmaOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
+ subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
+ adaptor.getOpC());
+ return success();
+ }
+};
+
+} // namespace
+} // namespace khr
+
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace nv {
+namespace {
+
/// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
/// dialect.
struct WmmaLoadOpToSPIRVLowering final
@@ -152,102 +381,9 @@ struct WmmaMmaOpToSPIRVLowering final
}
};
-/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
-/// ops.
-struct WmmaConstantOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Value cst = adaptor.getOperands()[0];
- auto coopType = convertMMAToSPIRVCoopMatrixNVType(
- cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
- rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
- subgroupMmaConstantMatrixOp, coopType, cst);
- return success();
- }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering final
- : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
- OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // All operands should be of cooperative matrix types.
- for (Value operand : adaptor.getOperands()) {
- if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
- return failure();
- }
- auto coopType = convertMMAToSPIRVCoopMatrixNVType(
- cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
- return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
- adaptor.getOperands()));
- }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering final
- : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(gp...
<truncated>
</pre>
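Beyond the truncated portion of the diff, the shared (extension-agnostic) patterns described in the summary cover the constant-matrix and elementwise ops. A short sketch of the behavior they implement, again with illustrative types and only approximate printed forms for the expected KHR output:
<pre>
// Input (illustrative): a splat constant matrix, then an elementwise mulf where
// one operand is that splat.
func.func @wmma_shared_patterns_sketch(%cst: f16, %m: !gpu.mma_matrix<16x16xf16, "COp">) {
  // Lowered by the shared constant pattern to a one-element
  // spirv.CompositeConstruct of the converted cooperative matrix type.
  %splat = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
  // Matched by the matrix-times-scalar pattern because one operand is a splat
  // constant matrix; lowered to spirv.MatrixTimesScalar on the other operand.
  %scaled = gpu.subgroup_mma_elementwise mulf %m, %splat
      : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
      -> !gpu.mma_matrix<16x16xf16, "COp">
  return
}

// Roughly expected output under the KHR type conversion:
//   %splat  = spirv.CompositeConstruct %cst
//               : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
//   %scaled = spirv.MatrixTimesScalar %m, %cst
//               : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
</pre>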
https://github.com/llvm/llvm-project/pull/66455