[Mlir-commits] [mlir] 77d2c81 - [MLIR][GPU] Add GPU ops nvvm.mma.sync, nvvm.mma.ldmatrix, lane_id
Thomas Raoux
llvmlistbot at llvm.org
Wed Apr 13 15:50:25 PDT 2022
Author: Christopher Bate
Date: 2022-04-13T22:50:07Z
New Revision: 77d2c815f50b20d18f1207e4f442e2cf8eb8cec0
URL: https://github.com/llvm/llvm-project/commit/77d2c815f50b20d18f1207e4f442e2cf8eb8cec0
DIFF: https://github.com/llvm/llvm-project/commit/77d2c815f50b20d18f1207e4f442e2cf8eb8cec0.diff
LOG: [MLIR][GPU] Add GPU ops nvvm.mma.sync, nvvm.mma.ldmatrix, lane_id
This change adds three new operations to the GPU dialect: gpu.mma.sync,
gpu.mma.ldmatrix, and gpu.lane_id. The first two are meant to target the
lower-level nvvm.mma.sync and nvvm.ldmatrix instructions, respectively.
Lowerings to NVVM are added for the new GPU operations.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D123647
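For orientation, the sketch below (not part of the commit) shows how the three new ops might be composed before running `mlir-opt --convert-gpu-to-nvvm`; the function name, buffer shapes, and attribute values are illustrative, patterned on the test file added by this change.

```mlir
gpu.module @example {
  // Hypothetical m16n8k16 f16 tile: load A and B fragments from shared
  // memory with gpu.mma.ldmatrix, then multiply-accumulate with
  // gpu.mma.sync. Shapes and attribute values mirror the new tests.
  func @mma_m16n8k16(%a_mem: memref<128x128xf16, 3>,
                     %b_mem: memref<128x128xf16, 3>,
                     %c: vector<2x2xf16>) -> vector<2x2xf16> {
    %c0 = arith.constant 0 : index
    // Lane id within the subgroup (warp); shown for syntax, unused here.
    %lane = gpu.lane_id
    %a = gpu.mma.ldmatrix %a_mem[%c0, %c0] {transpose = false, numTiles = 4 : i32}
           : memref<128x128xf16, 3> -> vector<4x2xf16>
    %b = gpu.mma.ldmatrix %b_mem[%c0, %c0] {transpose = false, numTiles = 2 : i32}
           : memref<128x128xf16, 3> -> vector<2x2xf16>
    %d = gpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
           : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
    return %d : vector<2x2xf16>
  }
}
```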
Added:
mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 9b42ecfaae215..eaee42bb8e319 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -96,6 +96,19 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
}];
}
+def GPU_LaneIdOp : GPU_Op<"lane_id", [NoSideEffect]> {
+ let description = [{
+ Returns the lane id within the subgroup (warp/wave).
+
+ Example:
+ ```mlir
+ %laneId = gpu.lane_id
+ ```
+ }];
+ let results = (outs Index:$result);
+ let assemblyFormat = "attr-dict";
+}
+
def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [NoSideEffect]>,
Arguments<(ins)>, Results<(outs Index:$result)> {
let description = [{
@@ -1354,4 +1367,58 @@ def GPU_DeviceAsyncWaitOp : GPU_Op<"device_async_wait", []> {
}];
}
+def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix",
+ [MemoryEffects<[MemRead]>]> {
+ let description = [{
+ The `gpu.mma.ldmatrix` op represents loading a matrix fragment from
+ memory. The load source and result type must be compatible with lowering
+ to the `nvvm.ldmatrix` instruction. This op represents the distributed
+ version of a `vector.transfer_read` and serves as an intermediate step
+ when lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
+
+ Example:
+
+ ```mlir
+ gpu.mma.ldmatrix %shm_buffer[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<16x16xf16, 3> -> vector<4x2xf16>
+ ```
+ }];
+
+ let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+ Variadic<Index>:$indices, BoolAttr:$transpose,
+ I32Attr:$numTiles);
+ let results = (outs AnyVector:$res);
+ let assemblyFormat = [{
+ $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
+ }];
+}
+
+def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
+ let description = [{
+ The `gpu.mma.sync` op represents the distributed form of a collective
+ matrix-multiply-and-accumulate (mma) operation that is compatible with
+ `nvvm.mma.sync`. The operands and results are fragments of the full matrix
+ operands. The full shape of the distributed mma operation is given by the
+ `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
+
+ This operation is meant to be lowered to the `nvvm.mma.sync` instruction and
+ serves as an intermediate step when lowering from `vector.contract` to
+ `nvvm.mma.sync`.
+
+ Example:
+
+ ```mlir
+ gpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ ```
+ }];
+ let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, AnyVector:$matrixC,
+ I64ArrayAttr:$mmaShape);
+
+ let results = (outs AnyVector:$res);
+
+ let assemblyFormat = [{
+ `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
+ `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+ }];
+}
+
#endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9efdc8a832a35..e5145f6513fdf 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -208,6 +208,314 @@ struct GPUAsyncWaitLowering
}
};
+struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp> {
+ using ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = getContext();
+ Location loc = op->getLoc();
+
+ // The result type of ldmatrix is always a struct of 32-bit integer
+ // registers when more than one 32-bit value is returned; otherwise, the
+ // result is a single i32. The result type of the GPU operation is always a
+ // 2-D vector of shape (NumRegisters, VectorRegister), where each row is a
+ // 32-bit-wide vector register. We bitcast each i32 returned by
+ // NVVM::LdMatrixOp to this row vector type.
+ auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ if (!vectorResultType) {
+ return failure();
+ }
+ Type innerVectorType = LLVM::getFixedVectorType(
+ vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+
+ int64_t num32BitRegs = vectorResultType.getDimSize(0);
+
+ Type ldMatrixResultType;
+ if (num32BitRegs > 1) {
+ ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
+ ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
+ } else {
+ ldMatrixResultType = rewriter.getI32Type();
+ }
+
+ auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+ Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
+ adaptor.indices(), rewriter);
+ Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
+ loc, ldMatrixResultType, srcPtr,
+ /*num=*/op.numTiles(),
+ /*layout=*/op.transpose() ? NVVM::MMALayout::col
+ : NVVM::MMALayout::row);
+
+ // The ldmatrix operation returns either a single i32 value or a struct of
+ // i32 values. Here we unpack those values, bitcast them back to their
+ // actual vector type (still 32 bits wide), and repack them into the result
+ // array.
+ Type finalResultType = typeConverter->convertType(vectorResultType);
+ Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+ for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
+ Value i32Register = num32BitRegs > 1
+ ? rewriter.create<LLVM::ExtractValueOp>(
+ loc, rewriter.getI32Type(), ldMatrixResult,
+ rewriter.getI64ArrayAttr(i))
+ : ldMatrixResult;
+ Value casted =
+ rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/// Checks if all the operands of the op being lowered are of LLVM types. The
+/// types are expected to be converted by the `LLVMTypeConverter` before the
+/// op is actually lowered. If the type of an operand is not already
+/// converted, this indicates a missing type conversion and failure is
+/// returned.
+LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+ ConversionPatternRewriter &rewriter) {
+ if (!llvm::all_of(operands, [](Value value) {
+ return LLVM::isCompatibleType(value.getType());
+ })) {
+ return rewriter.notifyMatchFailure(
+ op, "cannot convert if operands aren't of LLVM type.");
+ }
+
+ return success();
+}
+
+/// Returns the result type of the NVVM intrinsic given the `vectorResultType`
+/// of the `gpu.mma.sync` operation.
+Type inferIntrinsicResultType(Type vectorResultType) {
+ MLIRContext *ctx = vectorResultType.getContext();
+ auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
+ auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+ auto i32Ty = IntegerType::get(ctx, 32);
+ auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ Type f64Ty = Float64Type::get(ctx);
+ Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+ if (a.getElementType() == f16x2Ty) {
+ return LLVM::LLVMStructType::getLiteral(
+ ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
+ }
+ if (a.getElementType() == i32x2Ty) {
+ return LLVM::LLVMStructType::getLiteral(
+ ctx,
+ SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
+ }
+ if (a.getElementType() == f64x2Ty) {
+ return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
+ }
+ return vectorResultType;
+}
+
+/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
+/// always an LLVM struct) into a fragment that is compatible with the vector
+/// type of this operation. This involves extracting elements from the struct
+/// and inserting them into an LLVM array. These extra data-movement
+/// operations should be canonicalized away by the LLVM backend.
+Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
+ Type resultType, Value intrinsicResult,
+ RewriterBase &rewriter) {
+ MLIRContext *ctx = rewriter.getContext();
+ auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
+ auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+ Type i32Ty = rewriter.getI32Type();
+ Type f64Ty = rewriter.getF64Type();
+ Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
+ Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+
+ auto makeConst = [&](int32_t index) -> Value {
+ return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
+ rewriter.getI32IntegerAttr(index));
+ };
+
+ if (arrayType) {
+ SmallVector<Value, 4> elements;
+
+ if (arrayType.getElementType() == f16x2Ty) {
+ for (unsigned i = 0; i < structType.getBody().size(); i++) {
+ elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, structType.getBody()[i], intrinsicResult,
+ rewriter.getI64ArrayAttr(i)));
+ }
+ }
+
+ // The intrinsic returns i32 and f64 values as individual scalars. We need
+ // to extract them from the struct and pack them into vectors.
+ if (arrayType.getElementType() == i32x2Ty ||
+ arrayType.getElementType() == f64x2Ty) {
+ Value vec =
+ rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+ for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+ Value x1 = rewriter.create<LLVM::ExtractValueOp>(
+ loc, structType.getBody()[i * 2], intrinsicResult,
+ rewriter.getI64ArrayAttr(i * 2));
+ Value x2 = rewriter.create<LLVM::ExtractValueOp>(
+ loc, structType.getBody()[i * 2 + 1], intrinsicResult,
+ rewriter.getI64ArrayAttr(i * 2 + 1));
+ vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+ x1, makeConst(0));
+ vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+ x2, makeConst(1));
+ }
+ elements.push_back(vec);
+ }
+
+ // Create the final vectorized result.
+ Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
+ for (const auto &el : llvm::enumerate(elements)) {
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, arrayType, result, el.value(),
+ rewriter.getI64ArrayAttr(el.index()));
+ }
+ return result;
+ }
+
+ return intrinsicResult;
+}
+
+/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
+/// given as 2-D `vector`s whose rows are 32 or 64 bits wide. The
+/// `nvvm.mma.sync` op expects these arguments to be given as a flat list of
+/// scalars of certain types. This function unpacks the `vector` arguments
+/// and casts them to the types expected by `nvvm.mma.sync`.
+SmallVector<Value> unpackOperandVector(RewriterBase &rewriter, Location loc,
+ Value operand) {
+ SmallVector<Value> result;
+ Type i32Ty = rewriter.getI32Type();
+ Type f64Ty = rewriter.getF64Type();
+ Type i8Ty = rewriter.getI8Type();
+ Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+ auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+
+ for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
+ Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+ loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
+
+ // For 4xi8 vectors, the intrinsic expects these to be provided as i32
+ // scalar types.
+ if (arrayTy.getElementType() == i8x4Ty) {
+ result.push_back(
+ rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+ continue;
+ }
+
+ // For some element types (i32, f64), we need to unpack the inner
+ // vector/array type as well because the intrinsic expects individual
+ // scalars to be provided.
+ VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+ if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
+ innerArrayTy.getElementType() == f64Ty)) {
+ for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
+ idx < innerSize; idx++) {
+ result.push_back(rewriter.create<LLVM::ExtractElementOp>(
+ loc, toUse,
+ rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+ }
+ continue;
+ }
+ result.push_back(toUse);
+ }
+ return result;
+}
+
+struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<gpu::MmaSyncOp> {
+ using ConvertOpToLLVMPattern<gpu::MmaSyncOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) {
+ return failure();
+ }
+
+ // Get the shape of the MMA operation. The shape determines which
+ // intrinsic this op will be lowered to.
+ auto aType = op.matrixA().getType().cast<VectorType>();
+
+ int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
+ int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
+ int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
+ std::array<int64_t, 3> gemmShape{m, n, k};
+
+ SmallVector<Value> matA =
+ unpackOperandVector(rewriter, loc, adaptor.matrixA());
+ SmallVector<Value> matB =
+ unpackOperandVector(rewriter, loc, adaptor.matrixB());
+ SmallVector<Value> matC =
+ unpackOperandVector(rewriter, loc, adaptor.matrixC());
+
+ NVVM::MMATypes ptxTypeA;
+ NVVM::MMATypes ptxTypeB;
+ Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+ if (aType.getElementType().isInteger(8)) {
+ ptxTypeA = NVVM::MMATypes::s8;
+ ptxTypeB = NVVM::MMATypes::s8;
+ overflow = NVVM::MMAIntOverflow::satfinite;
+
+ } else if (aType.getElementType().isF16()) {
+ ptxTypeA = NVVM::MMATypes::f16;
+ ptxTypeB = NVVM::MMATypes::f16;
+ } else if (aType.getElementType().isF64()) {
+ ptxTypeA = NVVM::MMATypes::f64;
+ ptxTypeB = NVVM::MMATypes::f64;
+ } else {
+ return op->emitError("could not deduce operand PTX types");
+ }
+
+ Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+ Type intrinsicResTy = inferIntrinsicResultType(
+ typeConverter->convertType(op->getResultTypes()[0]));
+ Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
+ op.getLoc(), intrinsicResTy, matA, matB, matC,
+ /*shape=*/gemmShape,
+ /*b1Op=*/llvm::None,
+ /*intOverflow=*/overflow,
+ /*multiplicandPtxTypes=*/
+ std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+ /*multiplicandLayouts=*/
+ std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
+ NVVM::MMALayout::col});
+ rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
+ desiredRetTy, intrinsicResult,
+ rewriter));
+ return success();
+ }
+};
+
+struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
+ using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ MLIRContext *context = rewriter.getContext();
+ Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
+ // Truncate or extend the result depending on the index bitwidth specified
+ // by the LLVMTypeConverter options.
+ const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
+ if (indexBitwidth > 32) {
+ newOp = rewriter.create<LLVM::SExtOp>(
+ loc, IntegerType::get(context, indexBitwidth), newOp);
+ } else if (indexBitwidth < 32) {
+ newOp = rewriter.create<LLVM::TruncOp>(
+ loc, IntegerType::get(context, indexBitwidth), newOp);
+ }
+ rewriter.replaceOp(op, {newOp});
+ return success();
+ }
+};
+
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"
@@ -303,7 +611,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
- GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+ GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering,
+ MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 973166b2b5b4d..771e2f10d3bdb 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -6,7 +6,8 @@ gpu.module @test_module {
// CHECK32-LABEL: func @gpu_index_ops()
func.func @gpu_index_ops()
-> (index, index, index, index, index, index,
- index, index, index, index, index, index) {
+ index, index, index, index, index, index,
+ index) {
// CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
// CHECK: = nvvm.read.ptx.sreg.tid.x : i32
@@ -49,10 +50,17 @@ gpu.module @test_module {
// CHECK: = llvm.sext %{{.*}} : i32 to i64
%gDimZ = gpu.grid_dim z
+
+ // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+ // CHECK: = llvm.sext %{{.*}} : i32 to i64
+ %laneId = gpu.lane_id
+
func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
- %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+ %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
+ %laneId
: index, index, index, index, index, index,
- index, index, index, index, index, index
+ index, index, index, index, index, index,
+ index
}
}
diff --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
new file mode 100644
index 0000000000000..2f70a15badb7d
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+ // CHECK-LABEL: @m16n8k16_fp16
+ func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-NOT: llvm.extractvalue
+ // CHECK: [[d:%.+]] = nvvm.mma.sync
+ // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+ %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+ return %d : vector<2x2xf16>
+ }
+
+ // CHECK-LABEL: @m16n8k8_fp16
+ func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-NOT: llvm.extractvalue
+ // CHECK: [[d:%.+]] = nvvm.mma.sync
+ // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+ %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+ return %d : vector<2x2xf16>
+ }
+
+ // CHECK-LABEL: @m16n8k32_int8
+ func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+ // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+ // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+
+ // CHECK: [[d:%.+]] = nvvm.mma.sync
+ // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+ // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+ // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+ // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+ %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+
+ // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>>
+ return %d : vector<2x2xi32>
+ }
+
+ // CHECK-LABEL: @m8n8k4_f64
+ func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+ // CHECK: llvm.extractvalue %arg0
+ // CHECK: llvm.extractvalue %arg1
+ // CHECK: llvm.extractvalue %arg2
+
+ // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+ // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+ %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ // CHECK: llvm.mlir.undef : vector<2xf64>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+ // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+ // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+ // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>>
+ return %d : vector<1x2xf64>
+ }
+
+ // CHECK-LABEL: @ldmatrix_x4
+ func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
+ %c0 = arith.constant 0 : index
+ // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+ %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ // CHECK: llvm.extractvalue
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ return %a : vector<4x2xf16>
+ }
+
+ // CHECK-LABEL: @ldmatrix_x1
+ func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
+ %c0 = arith.constant 0 : index
+ // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+ %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
+ // CHECK: llvm.bitcast
+ // CHECK: llvm.insertvalue
+ return %a : vector<1x2xf16>
+ }
+}