[Mlir-commits] [mlir] 894a591 - [mlir][nvgpu] Move mma.sync and ldmatrix in nvgpu dialect
Thomas Raoux
llvmlistbot at llvm.org
Thu Apr 14 16:46:25 PDT 2022
Author: Thomas Raoux
Date: 2022-04-14T23:44:52Z
New Revision: 894a591cf6fc542e6fc5d84222c839495a3d832f
URL: https://github.com/llvm/llvm-project/commit/894a591cf6fc542e6fc5d84222c839495a3d832f
DIFF: https://github.com/llvm/llvm-project/commit/894a591cf6fc542e6fc5d84222c839495a3d832f.diff
LOG: [mlir][nvgpu] Move mma.sync and ldmatrix in nvgpu dialect
Move the gpu operations mma.sync and ldmatrix into the nvgpu dialect, as they
are specific to the NVIDIA target.
Differential Revision: https://reviews.llvm.org/D123824
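At the IR level this is a rename of the two ops plus a dedicated lowering pass;
a minimal sketch (op spellings taken from the tests in this patch, with the
operands %a, %b, %c, %src, %i, %j assumed to be defined elsewhere):

```mlir
// Previously spelled gpu.mma.sync / gpu.mma.ldmatrix and lowered by
// --convert-gpu-to-nvvm; now spelled nvgpu.mma.sync / nvgpu.ldmatrix and
// lowered by the new --convert-nvgpu-to-nvvm pass.
%d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} :
    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
%m = nvgpu.ldmatrix %src[%i, %j] {transpose = false, numTiles = 4 : i32} :
    memref<128x128xf16, 3> -> vector<4x2xf16>
```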
Added:
mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/NVGPU/NVGPU.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/test/Dialect/NVGPU/roundtrip.mlir
Removed:
mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h b/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
new file mode 100644
index 0000000000000..05dd975322622
--- /dev/null
+++ b/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
@@ -0,0 +1,26 @@
+//===- NVGPUToNVVMPass.h - Convert NVGPU to NVVM dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
+#define MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createConvertNVGPUToNVVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ae05446b7b236..0e9b4fc9f233e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -35,6 +35,7 @@
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5dcf9d7415964..d42e7dd4be871 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -506,6 +506,22 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// NVGPUToNVVM
+//===----------------------------------------------------------------------===//
+
+def ConvertNVGPUToNVVM : Pass<"convert-nvgpu-to-nvvm"> {
+ let summary = "Convert NVGPU dialect to NVVM dialect";
+ let description = [{
+ This pass converts supported NVGPU ops to NVVM dialect intrinsics.
+ }];
+ let constructor = "mlir::createConvertNVGPUToNVVMPass()";
+ let dependentDialects = [
+ "NVVM::NVVMDialect",
+ ];
+}
+
+
//===----------------------------------------------------------------------===//
// OpenACCToSCF
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 078a66dc821b2..f93a32384becc 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1369,58 +1369,4 @@ def GPU_DeviceAsyncWaitOp : GPU_Op<"device_async_wait", []> {
}];
}
-def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix",
- [MemoryEffects<[MemRead]>]> {
- let description = [{
- The `gpu.mma.ldmatrix` op represents loading a matrix fragment from
- memory. The load source and result type must be compatible with lowering
- to the `nvvm.ldmatrix` instruction. This op is meant to represent
- the distributed version of a `vector.transfer_read` as an intermediate
- step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
-
- Example:
-
- ```mlir
- gpu.mma.ldmatrix %shm_buffer[%c0, %c0] : memref<16x16xf16, 3> -> vector<4x2xf16>
- ```
- }];
-
- let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
- Variadic<Index>:$indices, BoolAttr:$transpose,
- I32Attr:$numTiles);
- let results = (outs AnyVector:$res);
- let assemblyFormat = [{
- $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
- }];
-}
-
-def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
- let description = [{
- The `gpu.mma.sync` op represents the distributed form of a collective
- matrix-multiply-and-accumulate (mma) operation that is compatible with
- `nvvm.mma.sync`. The operands and results are fragments of the full matrix
- operands. The full shape of the distributed mma operation is given by the
- `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
-
- This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
- is an intermediate point between lowering from `vector.contract` to
- `nvvm.mma.sync`.
-
- Example:
-
- ```mlir
- gpu.mma.sync (%a, %b, %c) : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
- ```
- }];
- let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, AnyVector:$matrixC,
- I64ArrayAttr:$mmaShape);
-
- let results = (outs AnyVector:$res);
-
- let assemblyFormat = [{
- `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
- `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
- }];
-}
-
#endif // GPU_OPS
diff --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
index 9ed34ace009b6..28147e3fe4aae 100644
--- a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
@@ -69,4 +69,37 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
}];
}
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+ let description = [{
+ The `nvgpu.mma.sync` op represents the distributed form of a collective
+ matrix-multiply-and-accumulate (mma) operation that is compatible with
+ `nvvm.mma.sync`. The operands and results are fragments of the full matrix
+ operands. The full shape of the distributed mma operation is given by the
+ `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
+
+ This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
+ is an intermediate point between lowering from `vector.contract` to
+ `nvvm.mma.sync`.
+
+ This operation is meant to follow the semantics described here:
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma
+
+ Example:
+
+ ```mlir
+ nvgpu.mma.sync (%a, %b, %c) :
+ (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ ```
+ }];
+ let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
+ AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+
+ let results = (outs AnyVector:$res);
+
+ let assemblyFormat = [{
+ `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
+ `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+ }];
+}
+
#endif // NVGPU
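For context on how the op defined above maps onto NVVM, here is the m16n8k16
f16 case exercised by the new conversion test added later in this patch; the
comments summarize the lowering produced by --convert-nvgpu-to-nvvm (a
simplified reading of the test's CHECK lines, not the verbatim output):

```mlir
func @m16n8k16_fp16(%a: vector<4x2xf16>, %b: vector<2x2xf16>,
                    %c: vector<2x2xf16>) -> vector<2x2xf16> {
  // Lowering: %a is unpacked with four llvm.extractvalue ops (two each for
  // %b and %c), a single nvvm.mma.sync with shape = {k = 16, m = 16, n = 8}
  // returns !llvm.struct<(vector<2xf16>, vector<2xf16>)>, and that struct is
  // repacked into !llvm.array<2 x vector<2xf16>>.
  %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} :
      (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}
```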
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 5ef84273cbe0c..533d61acea414 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(NVGPUToNVVM)
add_subdirectory(OpenACCToLLVM)
add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e5145f6513fdf..6f0e585365e29 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -208,290 +208,6 @@ struct GPUAsyncWaitLowering
}
};
-struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp> {
- using ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- MLIRContext *ctx = getContext();
- Location loc = op->getLoc();
-
- // The result type of ldmatrix will always be a struct of 32bit integer
- // registers if more than one 32bit value is returned. Otherwise, the result
- // is a single i32. The result type of the GPU operation is always a vector
- // of shape (NumRegisters, VectorRegister) where VectorRegister is the
- // vector type of the result and always 32 bits long. We bitcast the result
- // of the NVVM::LdMatrix to this vector type.
- auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
- if (!vectorResultType) {
- return failure();
- }
- Type innerVectorType = LLVM::getFixedVectorType(
- vectorResultType.getElementType(), vectorResultType.getDimSize(1));
-
- int64_t num32BitRegs = vectorResultType.getDimSize(0);
-
- Type ldMatrixResultType;
- if (num32BitRegs > 1) {
- ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
- ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
- } else {
- ldMatrixResultType = rewriter.getI32Type();
- }
-
- auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
- Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
- adaptor.indices(), rewriter);
- Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
- loc, ldMatrixResultType, srcPtr,
- /*num=*/op.numTiles(),
- /*layout=*/op.transpose() ? NVVM::MMALayout::col
- : NVVM::MMALayout::row);
-
- // The ldmatrix operation returns either a single i32 value or a struct of
- // i32 values. Here we unpack those values and cast them back to their
- // actual vector type (still of width 32b) and repack them into a result
- // struct.
- Type finalResultType = typeConverter->convertType(vectorResultType);
- Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
- for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
- Value i32Register = num32BitRegs > 1
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, rewriter.getI32Type(), ldMatrixResult,
- rewriter.getI64ArrayAttr(i))
- : ldMatrixResult;
- Value casted =
- rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
- }
-
- rewriter.replaceOp(op, result);
- return success();
- }
-};
-
-/// Checks if all the operands of the op being lowered are of LLVM Types. The
-/// types are expected to be converted by the `LLVMTypeConverter` before the
-/// op is actually lowered. If the type of an operands is not already
-/// converted it hints a missing typeConversion and failure is returned in
-/// that case.
-LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
- ConversionPatternRewriter &rewriter) {
- if (!llvm::all_of(operands, [](Value value) {
- return LLVM::isCompatibleType(value.getType());
- })) {
- return rewriter.notifyMatchFailure(
- op, "cannot convert if operands aren't of LLVM type.");
- }
-
- return success();
-}
-
-/// Returns the type for the intrinsic given the vectorResultType of the
-/// `gpu.mma.sync` operation.
-Type inferIntrinsicResultType(Type vectorResultType) {
- MLIRContext *ctx = vectorResultType.getContext();
- auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
- auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
- auto i32Ty = IntegerType::get(ctx, 32);
- auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
- Type f64Ty = Float64Type::get(ctx);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
- if (a.getElementType() == f16x2Ty) {
- return LLVM::LLVMStructType::getLiteral(
- ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
- }
- if (a.getElementType() == i32x2Ty) {
- return LLVM::LLVMStructType::getLiteral(
- ctx,
- SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
- }
- if (a.getElementType() == f64x2Ty) {
- return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
- }
- return vectorResultType;
-}
-
-/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
-/// always an LLVM struct) into a fragment that is compatible with the vector
-/// type of this operation. This involves extracting elements from the struct
-/// and inserting them into an LLVM array. These extra data-movement
-/// operations should be canonicalized away by the LLVM backend.
-Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
- Type resultType, Value intrinsicResult,
- RewriterBase &rewriter) {
- MLIRContext *ctx = rewriter.getContext();
- auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
- auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
- Type i32Ty = rewriter.getI32Type();
- Type f64Ty = rewriter.getF64Type();
- Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
- Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
-
- auto makeConst = [&](int32_t index) -> Value {
- return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
- rewriter.getI32IntegerAttr(index));
- };
-
- if (arrayType) {
- SmallVector<Value, 4> elements;
-
- if (arrayType.getElementType() == f16x2Ty) {
- for (unsigned i = 0; i < structType.getBody().size(); i++) {
- elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
- loc, structType.getBody()[i], intrinsicResult,
- rewriter.getI64ArrayAttr(i)));
- }
- }
-
- // The intrinsic returns i32 and f64 values as individual scalars. We need
- // to extract them from the struct and pack them into vectors.
- if (arrayType.getElementType() == i32x2Ty ||
- arrayType.getElementType() == f64x2Ty) {
- Value vec =
- rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
- for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
- Value x1 = rewriter.create<LLVM::ExtractValueOp>(
- loc, structType.getBody()[i * 2], intrinsicResult,
- rewriter.getI64ArrayAttr(i * 2));
- Value x2 = rewriter.create<LLVM::ExtractValueOp>(
- loc, structType.getBody()[i * 2 + 1], intrinsicResult,
- rewriter.getI64ArrayAttr(i * 2 + 1));
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x1, makeConst(0));
- vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
- x2, makeConst(1));
- }
- elements.push_back(vec);
- }
-
- // Create the final vectorized result.
- Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
- for (const auto &el : llvm::enumerate(elements)) {
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, arrayType, result, el.value(),
- rewriter.getI64ArrayAttr(el.index()));
- }
- return result;
- }
-
- return intrinsicResult;
-}
-
-/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
-/// given as 2D `vectors` where the rows are 32b or 64b wide. The
-/// `nvvm.mma.sync` op expects these argments to be a given in a long list of
-/// scalars of certain types. This function helps unpack the `vector` arguments
-/// and cast them to the types expected by `nvvm.mma.sync`.
-SmallVector<Value> unpackOperandVector(RewriterBase &rewriter, Location loc,
- Value operand) {
- SmallVector<Value> result;
- Type i32Ty = rewriter.getI32Type();
- Type f64Ty = rewriter.getF64Type();
- Type i8Ty = rewriter.getI8Type();
- Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
- auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
-
- for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(
- loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
-
- // For 4xi8 vectors, the intrinsic expects these to be provided as i32
- // scalar types.
- if (arrayTy.getElementType() == i8x4Ty) {
- result.push_back(
- rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
- continue;
- }
-
- // For some element types (i32, f64), we need to unpack the inner
- // vector/array type as well because the intrinsic expects individual
- // scalars to be provided.
- VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
- if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
- innerArrayTy.getElementType() == f64Ty)) {
- for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
- idx < innerSize; idx++) {
- result.push_back(rewriter.create<LLVM::ExtractElementOp>(
- loc, toUse,
- rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
- }
- continue;
- }
- result.push_back(toUse);
- }
- return result;
-}
-
-struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<gpu::MmaSyncOp> {
- using ConvertOpToLLVMPattern<gpu::MmaSyncOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) {
- return failure();
- }
-
- // Get the shapes of the MMAMatrix type being used. The shapes will
- // choose which intrinsic this op will be lowered to.
- auto aType = op.matrixA().getType().cast<VectorType>();
-
- int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
- int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
- int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
- std::array<int64_t, 3> gemmShape{m, n, k};
-
- SmallVector<Value> matA =
- unpackOperandVector(rewriter, loc, adaptor.matrixA());
- SmallVector<Value> matB =
- unpackOperandVector(rewriter, loc, adaptor.matrixB());
- SmallVector<Value> matC =
- unpackOperandVector(rewriter, loc, adaptor.matrixC());
-
- NVVM::MMATypes ptxTypeA;
- NVVM::MMATypes ptxTypeB;
- Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
- if (aType.getElementType().isInteger(8)) {
- ptxTypeA = NVVM::MMATypes::s8;
- ptxTypeB = NVVM::MMATypes::s8;
- overflow = NVVM::MMAIntOverflow::satfinite;
-
- } else if (aType.getElementType().isF16()) {
- ptxTypeA = NVVM::MMATypes::f16;
- ptxTypeB = NVVM::MMATypes::f16;
- } else if (aType.getElementType().isF64()) {
- ptxTypeA = NVVM::MMATypes::f64;
- ptxTypeB = NVVM::MMATypes::f64;
- } else {
- return op->emitError("could not deduce operand PTX types");
- }
-
- Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
- Type intrinsicResTy = inferIntrinsicResultType(
- typeConverter->convertType(op->getResultTypes()[0]));
- Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
- op.getLoc(), intrinsicResTy, matA, matB, matC,
- /*shape=*/gemmShape,
- /*b1Op=*/llvm::None,
- /*intOverflow=*/overflow,
- /*multiplicandPtxTypes=*/
- std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
- /*multiplicandLayouts=*/
- std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
- NVVM::MMALayout::col});
- rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
- desiredRetTy, intrinsicResult,
- rewriter));
- return success();
- }
-};
-
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
@@ -611,8 +327,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
- GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering,
- MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
+ GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
+ converter);
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
new file mode 100644
index 0000000000000..50750e478095b
--- /dev/null
+++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRNVGPUToNVVM
+ NVGPUToNVVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVGPUToNVVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRLLVMCommonConversion
+ MLIRLLVMIR
+ MLIRNVVMIR
+ MLIRNVGPU
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
new file mode 100644
index 0000000000000..a304e49e58387
--- /dev/null
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -0,0 +1,308 @@
+//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
+#include "../PassDetail.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+
+using namespace mlir;
+
+/// Returns the type for the intrinsic given the vectorResultType of the
+/// `nvgpu.mma.sync` operation.
+static Type inferIntrinsicResultType(Type vectorResultType) {
+ MLIRContext *ctx = vectorResultType.getContext();
+ auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
+ auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+ auto i32Ty = IntegerType::get(ctx, 32);
+ auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ Type f64Ty = Float64Type::get(ctx);
+ Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+ if (a.getElementType() == f16x2Ty) {
+ return LLVM::LLVMStructType::getLiteral(
+ ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
+ }
+ if (a.getElementType() == i32x2Ty) {
+ return LLVM::LLVMStructType::getLiteral(
+ ctx,
+ SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
+ }
+ if (a.getElementType() == f64x2Ty) {
+ return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
+ }
+ return vectorResultType;
+}
+
+/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
+/// always an LLVM struct) into a fragment that is compatible with the vector
+/// type of this operation. This involves extracting elements from the struct
+/// and inserting them into an LLVM array. These extra data-movement
+/// operations should be canonicalized away by the LLVM backend.
+static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
+ Type resultType, Value intrinsicResult,
+ RewriterBase &rewriter) {
+ MLIRContext *ctx = rewriter.getContext();
+ auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
+ auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+ Type i32Ty = rewriter.getI32Type();
+ Type f64Ty = rewriter.getF64Type();
+ Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
+ Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+
+ auto makeConst = [&](int32_t index) -> Value {
+ return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
+ rewriter.getI32IntegerAttr(index));
+ };
+
+ if (arrayType) {
+ SmallVector<Value, 4> elements;
+
+ if (arrayType.getElementType() == f16x2Ty) {
+ for (unsigned i = 0; i < structType.getBody().size(); i++) {
+ elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, structType.getBody()[i], intrinsicResult,
+ rewriter.getI64ArrayAttr(i)));
+ }
+ }
+
+ // The intrinsic returns i32 and f64 values as individual scalars. We need
+ // to extract them from the struct and pack them into vectors.
+ if (arrayType.getElementType() == i32x2Ty ||
+ arrayType.getElementType() == f64x2Ty) {
+ Value vec =
+ rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+ for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+ Value x1 = rewriter.create<LLVM::ExtractValueOp>(
+ loc, structType.getBody()[i * 2], intrinsicResult,
+ rewriter.getI64ArrayAttr(i * 2));
+ Value x2 = rewriter.create<LLVM::ExtractValueOp>(
+ loc, structType.getBody()[i * 2 + 1], intrinsicResult,
+ rewriter.getI64ArrayAttr(i * 2 + 1));
+ vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+ x1, makeConst(0));
+ vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+ x2, makeConst(1));
+ }
+ elements.push_back(vec);
+ }
+
+ // Create the final vectorized result.
+ Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
+ for (const auto &el : llvm::enumerate(elements)) {
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, arrayType, result, el.value(),
+ rewriter.getI64ArrayAttr(el.index()));
+ }
+ return result;
+ }
+
+ return intrinsicResult;
+}
+
+/// The `nvgpu.mma.sync` converter below expects matrix fragment operands to be
+/// given as 2D `vectors` where the rows are 32b or 64b wide. The
+/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
+/// scalars of certain types. This function helps unpack the `vector` arguments
+/// and cast them to the types expected by `nvvm.mma.sync`.
+static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
+ Location loc, Value operand) {
+ SmallVector<Value> result;
+ Type i32Ty = rewriter.getI32Type();
+ Type f64Ty = rewriter.getF64Type();
+ Type i8Ty = rewriter.getI8Type();
+ Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+ auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+
+ for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
+ Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+ loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
+
+ // For 4xi8 vectors, the intrinsic expects these to be provided as i32
+ // scalar types.
+ if (arrayTy.getElementType() == i8x4Ty) {
+ result.push_back(
+ rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+ continue;
+ }
+
+ // For some element types (i32, f64), we need to unpack the inner
+ // vector/array type as well because the intrinsic expects individual
+ // scalars to be provided.
+ VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+ if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
+ innerArrayTy.getElementType() == f64Ty)) {
+ for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
+ idx < innerSize; idx++) {
+ result.push_back(rewriter.create<LLVM::ExtractElementOp>(
+ loc, toUse,
+ rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+ }
+ continue;
+ }
+ result.push_back(toUse);
+ }
+ return result;
+}
+
+namespace {
+
+struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
+ using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = getContext();
+ Location loc = op->getLoc();
+
+ // The result type of ldmatrix will always be a struct of 32bit integer
+ // registers if more than one 32bit value is returned. Otherwise, the result
+ // is a single i32. The result type of the nvgpu op is always a vector
+ // of shape (NumRegisters, VectorRegister) where VectorRegister is the
+ // vector type of the result and always 32 bits long. We bitcast the result
+ // of the NVVM::LdMatrix to this vector type.
+ auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ if (!vectorResultType) {
+ return failure();
+ }
+ Type innerVectorType = LLVM::getFixedVectorType(
+ vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+
+ int64_t num32BitRegs = vectorResultType.getDimSize(0);
+
+ Type ldMatrixResultType;
+ if (num32BitRegs > 1) {
+ ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
+ ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
+ } else {
+ ldMatrixResultType = rewriter.getI32Type();
+ }
+
+ auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+ Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
+ adaptor.indices(), rewriter);
+ Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
+ loc, ldMatrixResultType, srcPtr,
+ /*num=*/op.numTiles(),
+ /*layout=*/op.transpose() ? NVVM::MMALayout::col
+ : NVVM::MMALayout::row);
+
+ // The ldmatrix operation returns either a single i32 value or a struct of
+ // i32 values. Here we unpack those values and cast them back to their
+ // actual vector type (still of width 32b) and repack them into a result
+ // struct.
+ Type finalResultType = typeConverter->convertType(vectorResultType);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+ for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
+ Value i32Register = num32BitRegs > 1
+ ? rewriter.create<LLVM::ExtractValueOp>(
+ loc, rewriter.getI32Type(), ldMatrixResult,
+ rewriter.getI64ArrayAttr(i))
+ : ldMatrixResult;
+ Value casted =
+ rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
+ using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ // Get the shapes of the MMAMatrix type being used. The shapes will
+ // choose which intrinsic this op will be lowered to.
+ auto aType = op.matrixA().getType().cast<VectorType>();
+
+ int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
+ int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
+ int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
+ std::array<int64_t, 3> gemmShape{m, n, k};
+
+ SmallVector<Value> matA =
+ unpackOperandVector(rewriter, loc, adaptor.matrixA());
+ SmallVector<Value> matB =
+ unpackOperandVector(rewriter, loc, adaptor.matrixB());
+ SmallVector<Value> matC =
+ unpackOperandVector(rewriter, loc, adaptor.matrixC());
+
+ NVVM::MMATypes ptxTypeA;
+ NVVM::MMATypes ptxTypeB;
+ Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+ if (aType.getElementType().isInteger(8)) {
+ ptxTypeA = NVVM::MMATypes::s8;
+ ptxTypeB = NVVM::MMATypes::s8;
+ overflow = NVVM::MMAIntOverflow::satfinite;
+
+ } else if (aType.getElementType().isF16()) {
+ ptxTypeA = NVVM::MMATypes::f16;
+ ptxTypeB = NVVM::MMATypes::f16;
+ } else if (aType.getElementType().isF64()) {
+ ptxTypeA = NVVM::MMATypes::f64;
+ ptxTypeB = NVVM::MMATypes::f64;
+ } else {
+ return op->emitError("could not deduce operand PTX types");
+ }
+
+ Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+ Type intrinsicResTy = inferIntrinsicResultType(
+ typeConverter->convertType(op->getResultTypes()[0]));
+ Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
+ op.getLoc(), intrinsicResTy, matA, matB, matC,
+ /*shape=*/gemmShape,
+ /*b1Op=*/llvm::None,
+ /*intOverflow=*/overflow,
+ /*multiplicandPtxTypes=*/
+ std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+ /*multiplicandLayouts=*/
+ std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
+ NVVM::MMALayout::col});
+ rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
+ desiredRetTy, intrinsicResult,
+ rewriter));
+ return success();
+ }
+};
+
+struct ConvertNVGPUToNVVMPass
+ : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
+ ConvertNVGPUToNVVMPass() = default;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ LLVMTypeConverter converter(&getContext());
+ populateNVGPUToNVVMConversionPatterns(converter, patterns);
+ LLVMConversionTarget target(getContext());
+ target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+ target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
+}
+
+std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
+ return std::make_unique<ConvertNVGPUToNVVMPass>();
+}
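To make the ldmatrix pattern above concrete, a sketch of its effect based on
the numTiles = 4 case from the new test below (the comments paraphrase the
CHECK lines rather than reproducing the exact lowered IR):

```mlir
func @ldmatrix_x4(%src: memref<128x128xf16, 3>) -> vector<4x2xf16> {
  %c0 = arith.constant 0 : index
  // Lowering: one nvvm.ldmatrix with layout = row and num = 4 returning
  // !llvm.struct<(i32, i32, i32, i32)>; each i32 register is then bitcast
  // to vector<2xf16> and inserted into !llvm.array<4 x vector<2xf16>>.
  %a = nvgpu.ldmatrix %src[%c0, %c0] {transpose = false, numTiles = 4 : i32}
      : memref<128x128xf16, 3> -> vector<4x2xf16>
  return %a : vector<4x2xf16>
}
```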
diff --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
deleted file mode 100644
index 2f70a15badb7d..0000000000000
--- a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
+++ /dev/null
@@ -1,129 +0,0 @@
-// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
-
-gpu.module @test_module {
- // CHECK-LABEL: @m16n8k16_fp16
- func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
- // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>>
-
- // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>>
-
- // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
- // CHECK-NOT llvm.extractvalue
- // CHECK: [[d:%.+]] = nvvm.mma.sync
- // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
- %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
- // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
- // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
- // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
- return %d : vector<2x2xf16>
- }
-
- // CHECK-LABEL: @m16n8k8_fp16
- func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
- // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>>
-
- // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>>
-
- // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
- // CHECK-NOT llvm.extractvalue
- // CHECK: [[d:%.+]] = nvvm.mma.sync
- // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
- %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
- // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
- // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
- // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
- // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
- return %d : vector<2x2xf16>
- }
-
- // CHECK-LABEL: @m16n8k32_int8
- func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
-
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
- // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
- // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
- // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
- // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
- // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
- // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
- // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
-
- // CHECK: [[d:%.+]] = nvvm.mma.sync
- // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
- // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
- // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
- // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
- %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
-
- // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>>
- return %d : vector<2x2xi32>
- }
-
- // CHECK-LABEL: @m8n8k4_f64
- func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
- // CHECK: llvm.extractvalue %arg0
- // CHECK: llvm.extractvalue %arg1
- // CHECK: llvm.extractvalue %arg2
-
- // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
- // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
- %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
- // CHECK: llvm.mlir.undef : vector<2xf64>
- // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
- // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
- // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
- // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
- // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>>
- return %d : vector<1x2xf64>
- }
-
- // CHECK-LABEL: @ldmatrix_x4
- func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
- %c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
- %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
- // CHECK: llvm.extractvalue
- // CHECK: llvm.bitcast
- // CHECK: llvm.insertvalue
- // CHECK: llvm.extractvalue
- // CHECK: llvm.bitcast
- // CHECK: llvm.insertvalue
- // CHECK: llvm.extractvalue
- // CHECK: llvm.bitcast
- // CHECK: llvm.insertvalue
- // CHECK: llvm.extractvalue
- // CHECK: llvm.bitcast
- // CHECK: llvm.insertvalue
- return %a : vector<4x2xf16>
- }
-
- // CHECK-LABEL: @ldmatrix_x1
- func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
- %c0 = arith.constant 0 : index
- // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
- %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
- // CHECK: llvm.bitcast
- // CHECK: llvm.insertvalue
- return %a : vector<1x2xf16>
- }
-}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
new file mode 100644
index 0000000000000..e095fecdfa87e
--- /dev/null
+++ b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt --convert-nvgpu-to-nvvm --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @m16n8k16_fp16
+func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-NOT: llvm.extractvalue
+ // CHECK: [[d:%.+]] = nvvm.mma.sync
+ // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+ return %d : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @m16n8k8_fp16
+func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<1 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-NOT: llvm.extractvalue
+ // CHECK: [[d:%.+]] = nvvm.mma.sync
+ // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: return
+ return %d : vector<2x2xf16>
+}
+
+// -----
+
+
+// CHECK-LABEL: @m16n8k32_int8
+func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+ // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+ // CHECK: [[d:%.+]] = nvvm.mma.sync
+ // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+ // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+ // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+ // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ return %d : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @m8n8k4_f64
+func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+ // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ // CHECK: llvm.mlir.undef : vector<2xf64>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+ // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+ // CHECK: return
+ return %d : vector<1x2xf64>
+}
+
+// -----
+
+
+// CHECK-LABEL: @ldmatrix_x4
+func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
+ %c0 = arith.constant 0 : index
+ // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+ %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ return %a : vector<4x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @ldmatrix_x1
+func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
+ %c0 = arith.constant 0 : index
+ // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+ %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ return %a : vector<1x2xf16>
+}
diff --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
index 8a52180676445..5a35d39f1acca 100644
--- a/mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -8,3 +8,13 @@ func @ldmatrix(%arg0: memref<?x?xf16, 3>, %x: index, %y: index) {
memref<?x?xf16, 3> -> vector<4x2xf16>
return
}
+
+// CHECK-LABEL: func @mma_sync(
+func @mma_sync(%arg0: vector<4x2xf16>,
+ %arg1: vector<2x2xf16>,
+ %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+// CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} :
+ (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}