[Mlir-commits] [mlir] eaaf7a6 - [MLIR][GPU][NVVM] Add conversion of warp synchronous matrix-multiply accumulate GPU ops
Uday Bondhugula
llvmlistbot at llvm.org
Fri May 21 08:51:52 PDT 2021
Author: Navdeep Kumar
Date: 2021-05-21T21:20:33+05:30
New Revision: eaaf7a6a09da905cc314201f93e2be11773726a0
URL: https://github.com/llvm/llvm-project/commit/eaaf7a6a09da905cc314201f93e2be11773726a0
DIFF: https://github.com/llvm/llvm-project/commit/eaaf7a6a09da905cc314201f93e2be11773726a0.diff
LOG: [MLIR][GPU][NVVM] Add conversion of warp synchronous matrix-multiply accumulate GPU ops
Add conversion of warp synchronous matrix-multiply accumulate GPU ops to
NVVM ops. The following conversions are added (an illustrative example follows
the list):
1.) subgroup_mma_load_matrix -> wmma.m16n16k16.load.[a,b,c]..row.stride
2.) subgroup_mma_store_matrix -> wmma.m16n16k16.store.d.[f16,f32].row.stride
3.) subgroup_mma_compute -> wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32]
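For illustration, here is a minimal sketch of the GPU-dialect ops these patterns
cover, written against the syntax used in the tests added below. The kernel and
the names in it (@wmma_example, @matmul, %a, %b, %c) are hypothetical; the
comment above each op names the NVVM op it lowers to in the f16, m16n16k16,
row-major case handled by this patch:

gpu.module @wmma_example {
  func @matmul(%a : memref<32x32xf16, 3>, %b : memref<32x32xf16, 3>, %c : memref<32x32xf16, 3>) {
    %i = constant 0 : index
    // -> nvvm.wmma.m16n16k16.load.a.f16.row.stride
    %A = gpu.subgroup_mma_load_matrix %a[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
    // -> nvvm.wmma.m16n16k16.load.b.f16.row.stride
    %B = gpu.subgroup_mma_load_matrix %b[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
    // -> nvvm.wmma.m16n16k16.load.c.f16.row.stride
    %C = gpu.subgroup_mma_load_matrix %c[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp">
    // -> nvvm.wmma.m16n16k16.mma.row.row.f16.f16
    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
    // -> nvvm.wmma.m16n16k16.store.d.f16.row.stride
    gpu.subgroup_mma_store_matrix %D, %c[%i, %i] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
    return
  }
}

Only the m16n16k16 shape is handled; other shapes are rejected with a match
failure, and the lowering requires the 32-bit index lowering
(index-bitwidth=32).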
Reviewed By: bondhugula, ftynse
Differential Revision: https://reviews.llvm.org/D95331
Added:
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
Modified:
mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index a005fb50226f5..e291a77e3a9be 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
+void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+
/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
/// index bitwidth used for the lowering of the device side index computations
/// is configurable.
diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
index e2647103ba22d..040b220135180 100644
--- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRGPUToNVVMIncGen)
add_mlir_conversion_library(MLIRGPUToNVVMTransforms
LowerGpuOpsToNVVMOps.cpp
+ WmmaOpsToNvvm.cpp
DEPENDS
MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 8471de0a04152..42e64e5fb3c6a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -126,6 +126,38 @@ struct LowerGpuOpsToNVVMOpsPass
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
});
+ // Lowering for MMAMatrixType.
+ converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+ // The number of items in structToReturn depends on the dataType and on
+ // the MMA operand that this operation is associated with.
+ llvm::DenseMap<StringRef, int64_t> numElemsPerThreadF16,
+ numElemsPerThreadF32;
+ numElemsPerThreadF16["AOp"] = 8;
+ numElemsPerThreadF16["BOp"] = 8;
+ numElemsPerThreadF16["COp"] = 4;
+ numElemsPerThreadF16["DOp"] = 4;
+ numElemsPerThreadF32["AOp"] = 8;
+ numElemsPerThreadF32["BOp"] = 8;
+ numElemsPerThreadF32["COp"] = 8;
+ numElemsPerThreadF32["DOp"] = 8;
+ Type structToReturn;
+ if (type.getElementType().isF16()) {
+ // Number of f16 values packed into a 32-bit chunk.
+ unsigned vecSize = 2;
+ Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
+ unsigned size = numElemsPerThreadF16[type.getOperand()];
+ SmallVector<Type> elements(size, vec);
+ structToReturn =
+ LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+ } else if (type.getElementType().isF32()) {
+ unsigned size = numElemsPerThreadF32[type.getOperand()];
+ SmallVector<Type> elements(size, FloatType::getF32(&getContext()));
+ structToReturn =
+ LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+ }
+ return structToReturn;
+ });
+
RewritePatternSet patterns(m.getContext());
RewritePatternSet llvmPatterns(m.getContext());
@@ -137,6 +169,7 @@ struct LowerGpuOpsToNVVMOpsPass
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
+ populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
LLVMConversionTarget target(getContext());
configureGpuToNVVMConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
new file mode 100644
index 0000000000000..c4458aa05f96d
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -0,0 +1,451 @@
+//===------ WmmaOpsToNvvm.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// NVVM Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Contains all the common LLVM types which are used across the lowerings of
+/// GPU subgroup ops to NVVM dialect.
+struct CommonLLVMAndBuiltInMLIRTypes {
+public:
+ CommonLLVMAndBuiltInMLIRTypes(MLIRContext *context) {
+ numHalfsInOpFrags.resize(4);
+ numHalfsInOpFrags[A] = 8;
+ numHalfsInOpFrags[B] = 8;
+ numHalfsInOpFrags[C] = 4;
+ numHalfsInOpFrags[D] = 4;
+ i32Ty = IntegerType::get(context, 32);
+ f16Ty = FloatType::getF16(context);
+ f32Ty = FloatType::getF32(context);
+ f16x2Ty = VectorType::get(2, f16Ty);
+ fragArrayABTy = LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(8, f16x2Ty));
+ fragArrayCDTy = LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(4, f16x2Ty));
+ fragArrayCDF32Ty =
+ LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+ };
+
+ Type i32Ty;
+ Type f16Ty;
+ Type f32Ty;
+ Type f16x2Ty;
+ /// Type for the fragment of A and B operands that a single thread holds for
+ /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+ /// (beta*C).
+ Type fragArrayABTy;
+ /// Type for the fragment of C and D operands that a single thread holds for
+ /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+ /// (beta*C).
+ Type fragArrayCDTy;
+ /// Type for the fragment of C and D operands that a single thread holds for
+ /// fp32 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+ /// (beta*C).
+ Type fragArrayCDF32Ty;
+ /// Represents the number of 2xf16 vector elements a single thread holds for
+ /// each operand in a WMMA operation of the form D = (alpha*(A*B)) + (beta*C).
+ SmallVector<unsigned, 4> numHalfsInOpFrags;
+ /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
+ /// (beta*C).
+ enum OperandMap { A, B, C, D };
+};
+
+/// Checks if all the operands of the op being lowered are of LLVM types. The
+/// types are expected to have been converted by the `LLVMTypeConverter` before
+/// the op is actually lowered. If the type of an operand is not already
+/// converted, it hints at a missing type conversion and failure is returned.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter) {
+ if (!llvm::all_of(operands, [](Value value) {
+ return LLVM::isCompatibleType(value.getType());
+ })) {
+ return rewriter.notifyMatchFailure(
+ op, "cannot convert if operands aren't of LLVM type.");
+ }
+
+ return success();
+}
+
+/// Error string to emit when an unimplemented WMMA variant is encountered.
+static constexpr StringRef kInvalidCaseStr =
+ "unimplemented WMMA variant; only the m16n16k16 version is implemented";
+
+/// This class implements the conversion of GPU MMA loadOp to wmma.load op
+/// in the NVVM dialect. The conversion not only emits the NVVM op but also
+/// emits the address computation needed to read the source memref at the
+/// requested offsets.
+struct WmmaLoadOpToNVVMLowering
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>,
+ private CommonLLVMAndBuiltInMLIRTypes {
+public:
+ explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
+ CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+ }
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Operation *op = subgroupMmaLoadMatrixOp.getOperation();
+ if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ return failure();
+
+ unsigned indexTypeBitwidth =
+ this->getTypeConverter()->getIndexTypeBitwidth();
+
+ // The corresponding intrinsic expects leadDimension to be a 32-bit
+ // integer, so all the calculations linearizing the load address
+ // must also follow this restriction.
+ if (indexTypeBitwidth != 32)
+ return rewriter.notifyMatchFailure(
+ op, "expected indices to the memref to be 32-bit wide.");
+
+ // Source memref of the original op.
+ MemRefType srcMemrefType =
+ subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
+ Location loc = op->getLoc();
+
+ auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
+
+ // MemRefDescriptor to extract alignedPtr and offset.
+ MemRefDescriptor promotedSrcOp(
+ gpu::SubgroupMmaLoadMatrixOpAdaptor(operands).srcMemref());
+
+ // Emit ops which compute the load offset using `srcOffsetI` and
+ // `srcOffsetJ`. The actual offset is memrefOffset + (leadDimension *
+ // srcOffsetI) + srcOffsetJ, applied to the aligned base pointer. The memrefs
+ // here are assumed to be normalized, so this simple computation suffices.
+ SmallVector<Value> indices(subgroupMmaLoadMatrixOp.indices());
+ Value srcOffsetIVal = indices[0];
+ Value srcOffsetJVal = indices[1];
+ Value leadingDim32 =
+ rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+ Value numElemsLeadDim =
+ rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
+ Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
+ srcOffsetJVal);
+
+ Value promotedSrcOpToUse;
+ promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
+ Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
+ promotedSrcOpToUse);
+ Value loadAddress = rewriter.create<LLVM::GEPOp>(
+ loc,
+ LLVM::LLVMPointerType::get(f16Ty, srcMemrefType.getMemorySpaceAsInt()),
+ promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
+
+ // Bitcast the computed address in the source memref so that values can be
+ // loaded in 32-bit chunks, matching the semantics of the intrinsic exposed
+ // by the NVPTX backend.
+ Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMPointerType::get(i32Ty, srcMemrefType.getMemorySpaceAsInt()),
+ loadAddress);
+
+ // Get the shape of the MMAMatrix type being returned. The shape determines
+ // which intrinsic this op will be lowered to.
+ gpu::MMAMatrixType retType =
+ subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
+ ArrayRef<int64_t> retTypeShape = retType.getShape();
+
+ Type resType;
+ StringRef operandStr = retType.getOperand();
+ if (operandStr.equals("AOp") || operandStr.equals("BOp")) {
+ resType = fragArrayABTy;
+ } else {
+ if (srcMemrefType.getElementType().isF16())
+ resType = fragArrayCDTy;
+ else if (srcMemrefType.getElementType().isF32())
+ resType = fragArrayCDF32Ty;
+ else
+ return failure();
+ }
+
+ // Create the appropriate NVVM WMMA load op according to the operand type and shape.
+ SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
+ if (operandStr.equals("AOp")) {
+ if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+ NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp =
+ rewriter.create<NVVM::WMMALoadAM16N16K16Op>(loc, resType,
+ loadOpOperands);
+ rewriter.replaceOp(op, wmmaLoadAOp.getResult());
+ } else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ } else if (operandStr.equals("BOp")) {
+ if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+ NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp =
+ rewriter.create<NVVM::WMMALoadBM16N16K16Op>(loc, resType,
+ loadOpOperands);
+ rewriter.replaceOp(op, wmmaLoadBOp.getResult());
+ } else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ } else {
+ if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+ if (srcMemrefType.getElementType().isF16()) {
+ NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp =
+ rewriter.create<NVVM::WMMALoadCF16M16N16K16Op>(loc, resType,
+ loadOpOperands);
+ rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+ } else if (srcMemrefType.getElementType().isF32()) {
+ NVVM::WMMALoadCF32M16N16K16Op wmmaLoadCOp =
+ rewriter.create<NVVM::WMMALoadCF32M16N16K16Op>(loc, resType,
+ loadOpOperands);
+ rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+ }
+ } else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ }
+ return success();
+ }
+};
+
+/// This class implements the conversion of GPU MMA storeOp to wmma.store op
+/// in the NVVM dialect. The conversion not only emits the NVVM op but also
+/// emits the code necessary to unpack the source fragment into the individual
+/// values expected by the NVVM op.
+struct WmmaStoreOpToNVVMLowering
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>,
+ private CommonLLVMAndBuiltInMLIRTypes {
+public:
+ explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
+ CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+ }
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Operation *op = subgroupMmaStoreMatrixOp.getOperation();
+ if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ return failure();
+
+ unsigned indexTypeBitwidth =
+ this->getTypeConverter()->getIndexTypeBitwidth();
+ // The corresponding intrinsic expects leadDimension to be a 32-bit
+ // integer, so all the calculations linearizing the store address
+ // must also follow this restriction.
+ if (indexTypeBitwidth != 32)
+ return rewriter.notifyMatchFailure(
+ op, "expected indices to the memref to be 32-bit wide.");
+
+ Location loc = op->getLoc();
+
+ // Destination memref of the original op.
+ MemRefType dstMemrefType =
+ subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>();
+
+ // MemRefDescriptor to extract alignedPtr and offset.
+ MemRefDescriptor promotedDstOp(
+ gpu::SubgroupMmaStoreMatrixOpAdaptor(operands).dstMemref());
+
+ auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
+
+ // Emit ops which compute the store offset using `dstOffsetI` and
+ // `dstOffsetJ`. The actual offset is memrefOffset + (leadDimension *
+ // dstOffsetI) + dstOffsetJ, applied to the aligned base pointer.
+ SmallVector<Value> indices(subgroupMmaStoreMatrixOp.indices());
+ Value dstOffsetIVal = indices[0];
+ Value dstOffsetJVal = indices[1];
+ Value leadingDim32 =
+ rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+ Value numElemsLeadDim =
+ rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
+ Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
+ dstOffsetJVal);
+
+ Value promotedDstOpToUse;
+ promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
+ Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
+ promotedDstOpToUse);
+ Value storeAddress = rewriter.create<LLVM::GEPOp>(
+ loc,
+ LLVM::LLVMPointerType::get(f16Ty, dstMemrefType.getMemorySpaceAsInt()),
+ promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
+
+ // Bitcast the base address pointer of the destination memref so that
+ // values can be stored in 32-bit chunks, matching the semantics of the
+ // intrinsic exposed by the NVPTX backend.
+ Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMPointerType::get(i32Ty, dstMemrefType.getMemorySpaceAsInt()),
+ storeAddress);
+
+ SmallVector<Value, 4> storeOpOperands;
+ storeOpOperands.push_back(storeAddressCasted);
+
+ // Get the shape of the MMAMatrix type being stored. The shape determines
+ // which intrinsic this op will be lowered to.
+ gpu::MMAMatrixType srcType =
+ subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
+ ArrayRef<int64_t> srcTypeShape = srcType.getShape();
+
+ // Unpack the results from the source.
+ if (subgroupMmaStoreMatrixOp.src()
+ .getType()
+ .cast<gpu::MMAMatrixType>()
+ .getElementType() == f16Ty) {
+ for (unsigned i = 0, e = numHalfsInOpFrags[D]; i < e; ++i) {
+ Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+ loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
+ storeOpOperands.push_back(toUse);
+ }
+ storeOpOperands.push_back(leadingDim32);
+
+ // Create nvvm.mma_store op.
+ if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
+ rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
+ } else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ } else if (subgroupMmaStoreMatrixOp.src()
+ .getType()
+ .cast<gpu::MMAMatrixType>()
+ .getElementType() == f32Ty) {
+ for (unsigned i = 0, e = 8; i < e; ++i) {
+ Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+ loc, f32Ty, operands[0], rewriter.getI32ArrayAttr(i));
+ storeOpOperands.push_back(toUse);
+ }
+ storeOpOperands.push_back(leadingDim32);
+
+ // Create nvvm.mma_store op.
+ if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
+ rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
+ else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
+/// in the NVVM dialect.
+struct WmmaMmaOpToNVVMLowering
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>,
+ private CommonLLVMAndBuiltInMLIRTypes {
+ explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
+ CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+ }
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Operation *op = subgroupMmaComputeOp.getOperation();
+ if (failed(areAllLLVMTypes(op, operands, rewriter)))
+ return failure();
+
+ Location loc = op->getLoc();
+
+ // The wmma.mma intrinsic in LLVM requires the operands as individual
+ // values, so the individual elements of each fragment struct need to be
+ // extracted and passed on to the intrinsic call. Emit LLVM ops to extract
+ // individual values from the lowered operands.
+ SmallVector<Value> unpackedOps;
+
+ auto unpackOp = [&](CommonLLVMAndBuiltInMLIRTypes::OperandMap op,
+ Value operand, unsigned numElems, Type elemType) {
+ for (unsigned i = 0; i < numElems; ++i) {
+ Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+ loc, elemType, operand, rewriter.getI32ArrayAttr(i));
+ unpackedOps.push_back(toUse);
+ }
+ };
+
+ // Get the shapes of the MMAMatrix types being used. The shapes determine
+ // which intrinsic this op will be lowered to.
+ gpu::MMAMatrixType aType =
+ subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
+ ArrayRef<int64_t> aTypeShape = aType.getShape();
+ gpu::MMAMatrixType bType =
+ subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
+ ArrayRef<int64_t> bTypeShape = bType.getShape();
+ gpu::MMAMatrixType cType =
+ subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
+ ArrayRef<int64_t> cTypeShape = cType.getShape();
+
+ gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
+ if (subgroupMmaComputeOp.opC()
+ .getType()
+ .cast<gpu::MMAMatrixType>()
+ .getElementType() == f16Ty) {
+ unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
+ unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
+ unpackOp(C, transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty);
+
+ if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
+ bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
+ // Create nvvm.wmma.mma op.
+ NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp =
+ rewriter.create<NVVM::WMMAMmaF16F16M16N16K16Op>(loc, fragArrayCDTy,
+ unpackedOps);
+
+ rewriter.replaceOp(op, wmmaMmaOp.getResult());
+ return success();
+ } else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ } else if (subgroupMmaComputeOp.opC()
+ .getType()
+ .cast<gpu::MMAMatrixType>()
+ .getElementType() == f32Ty) {
+ unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
+ unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
+ unpackOp(C, transformedOperands.opC(), 8, f32Ty);
+
+ if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
+ bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
+ // Create nvvm.wmma.mma op.
+ NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp =
+ rewriter.create<NVVM::WMMAMmaF32F32M16N16K16Op>(
+ loc, fragArrayCDF32Ty, unpackedOps);
+
+ rewriter.replaceOp(op, wmmaMmaOp.getResult());
+ return success();
+ } else {
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ }
+
+ return failure();
+ }
+};
+
+} // anonymous namespace
+
+namespace mlir {
+void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.insert<WmmaLoadOpToNVVMLowering>(converter);
+ patterns.insert<WmmaMmaOpToNVVMLowering>(converter);
+ patterns.insert<WmmaStoreOpToNVVMLowering>(converter);
+}
+} // namespace mlir
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
new file mode 100644
index 0000000000000..c44d7f0cfa301
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+
+ // CHECK-LABEL: func @gpu_wmma_load_op() ->
+ // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
+ func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ %j = constant 16 : index
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+ // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+ // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+ // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+ // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+ // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+ // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+ // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+ }
+}
+
+// -----
+
+gpu.module @test_module {
+
+ // CHECK-LABEL: func @gpu_wmma_store_op
+ // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+ func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+ %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
+ %i = constant 16 : index
+ %j = constant 16 : index
+ gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+ // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+ // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+ // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+ // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+ // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+ // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+ // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+ // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+ // CHECK: llvm.return
+ return
+ }
+}
+
+// -----
+
+gpu.module @test_module {
+
+ // CHECK-LABEL: func @gpu_wmma_mma_op
+ // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+ func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+ %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+ // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A5:.*]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A6:.*]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A7:.*]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[A8:.*]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B2:.*]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B3:.*]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B4:.*]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B5:.*]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B6:.*]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B7:.*]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[B8:.*]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.return
+ return
+ }
+}