[Mlir-commits] [mlir] 114ba72 - [mlir][NVGPU] Handle native mma.sync and ldmatrix(x4) sizes
Manish Gupta
llvmlistbot at llvm.org
Wed Oct 19 17:16:59 PDT 2022
Author: Manish Gupta
Date: 2022-10-19T17:10:21-07:00
New Revision: 114ba722c1e58d23bafdf3654e4f8e537150c318
URL: https://github.com/llvm/llvm-project/commit/114ba722c1e58d23bafdf3654e4f8e537150c318
DIFF: https://github.com/llvm/llvm-project/commit/114ba722c1e58d23bafdf3654e4f8e537150c318.diff
LOG: [mlir][NVGPU] Handle native mma.sync and ldmatrix(x4) sizes
This patch handles native `mma.sync` sizes and enables issuing `ldmatrix` on
the largest possible tiles for matrixB. It requires handling
`vector.extract_strided_slice` in the vector to nvgpu lowering.
Differential Revision: https://reviews.llvm.org/D135749
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
index 699e9fdb25a0b..fac99dc048dba 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -13,25 +13,22 @@
#ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
#define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
namespace mlir {
-namespace vector {
-enum class IteratorType : uint32_t;
-class ContractionOp;
-} // namespace vector
-
-namespace NVVM {
-enum class MMALayout : uint32_t;
-} // namespace NVVM
-
namespace nvgpu {
/// Represents the role of an operand in an MMA instruction:
/// `result := matmul(A, B) + C`
enum class MatMulOperandRole : int32_t { A = 0, B, C };
+/// Returns the first user of the `op` that is vector.contract. If no
+/// vector.contract user exists, return failure.
+FailureOr<vector::ContractionOp> getUserContract(Operation *op);
+
/// Collects information about a warp-level matrix operand represented by a
/// VectorType.
struct WarpMatrixInfo {
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index f4528b178e656..01654fdd6024a 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -192,6 +192,33 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
return convertElementwiseOpToMMA(op).has_value();
}
+/// Decides whether a `vector.extract_strided_slice` can be lowered on the
+/// `mma.sync` path.
+///
+/// The slice must have a vector.contract user (the first one reported by
+/// `nvgpu::getUserContract`). Its operand role must be matrixB (RHS) or
+/// matrixC (accumulator). Its result type must equal the type of that
+/// contract operand exactly. Slices feeding matrixA (LHS) are rejected.
+static bool
+extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
+  FailureOr<nvgpu::WarpMatrixInfo> info = nvgpu::getWarpMatrixInfo(op);
+  if (failed(info))
+    return false;
+
+  FailureOr<vector::ContractionOp> userContract = nvgpu::getUserContract(op);
+  if (failed(userContract))
+    return false;
+
+  // getWarpMatrixInfo already cast result 0 to VectorType, so this is safe to
+  // compute up front.
+  VectorType sliceType = op->getResult(0).getType().cast<VectorType>();
+  switch (info->operandRole) {
+  case nvgpu::MatMulOperandRole::B:
+    return sliceType == userContract->getRhs().getType().cast<VectorType>();
+  case nvgpu::MatMulOperandRole::C:
+    return sliceType == userContract->getAcc().getType().cast<VectorType>();
+  default:
+    // Registers holding matrixA operands cannot be sliced.
+    return false;
+  }
+}
+
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
@@ -199,6 +226,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite);
+ if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
+ return useNvGpu &&
+ extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
if (auto contract = dyn_cast<vector::ContractionOp>(op))
return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
@@ -338,8 +368,10 @@ struct PrepareContractToGPUMMA
}
};
-// Merge transpose op into the transfer read op. Transpose are not supported on
-// MMA types but MMA load can transpose the matrix when loading.
+// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
+// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
+// respectively. We can fold the transpose operation when loading the data from
+// Shared Memory to registers.
struct CombineTransferReadOpTranspose final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
@@ -620,7 +652,7 @@ convertTransferReadToLoads(vector::TransferReadOp op,
int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
// When we are transposing the B operand, ldmatrix will only work if we have
- // at least 8 rows to read and the width to read for the transpose is 128
+ // at least 8 rows to read and the width to read for the transpose is 128
// bits.
if (!op.getPermutationMap().isMinorIdentity() &&
(bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
@@ -671,6 +703,83 @@ convertTransferWriteToStores(vector::TransferWriteOp op,
return success();
}
+/// Appends the integer value of every element of `arrayAttr`, in order, to
+/// `results`. Every element must be an IntegerAttr (the cast asserts
+/// otherwise). Existing contents of `results` are kept.
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  results.reserve(results.size() + arrayAttr.size());
+  for (Attribute element : arrayAttr) {
+    IntegerAttr intElement = element.cast<IntegerAttr>();
+    results.push_back(intElement.getInt());
+  }
+}
+
+/// Lowers a warp-level `vector.extract_strided_slice` whose source is a
+/// `vector.transfer_read`. The result is a `vector.extract_strided_slice` on
+/// the per-thread register fragment that the transfer_read was already
+/// lowered to, which is looked up in `valueMapping`.
+///
+/// The per-thread fragment is shaped numRegisters x elementsPerRegister, and
+/// only whole registers are selected. The warp-level slice may therefore move
+/// along at most one dimension. The register offset is
+/// `(offset / size) * numRegistersPerFragment`. This assumes the source
+/// fragment stores consecutive `size`-element warp-level slices back to back,
+/// as in the ldmatrix x4 test cases. NOTE(review): confirm for other load
+/// paths.
+///
+/// On success, records the new per-thread value in `valueMapping[op]`.
+static LogicalResult
+convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(mmaSyncFragmentInfo))
+    return failure();
+
+  // Find the vector.transfer_read whose result vector is being sliced.
+  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+  if (!transferReadOp)
+    return failure();
+
+  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(ldFragmentInfo))
+    return failure();
+
+  assert(
+      (mmaSyncFragmentInfo->elementsPerRegister ==
+       ldFragmentInfo->elementsPerRegister) &&
+      "Number of elements per register should be same for load and mma.sync");
+
+  // The transfer_read must already have been lowered. Never dereference end().
+  auto sourceIt = valueMapping.find(transferReadOp);
+  if (sourceIt == valueMapping.end())
+    return op->emitError()
+           << "source vector.transfer_read has not been lowered to mma.sync "
+              "fragments";
+  Value sourceVector = sourceIt->second;
+
+  // Warp-level offsets and sizes of the slice. The attributes may cover only a
+  // leading prefix of the dimensions. Missing trailing dimensions are taken
+  // whole (offset 0, full size).
+  VectorType warpSourceType = op.getVector().getType().cast<VectorType>();
+  if (warpSourceType.getRank() != 2)
+    return op->emitError() << "expected a 2-D source vector";
+  SmallVector<int64_t> offsets;
+  populateFromInt64AttrArray(op.getOffsets(), offsets);
+  SmallVector<int64_t> sizes;
+  populateFromInt64AttrArray(op.getSizes(), sizes);
+  for (int64_t dim = offsets.size(); dim < 2; ++dim)
+    offsets.push_back(0);
+  for (int64_t dim = sizes.size(); dim < 2; ++dim)
+    sizes.push_back(warpSourceType.getDimSize(dim));
+
+  // Compute the offset in vector registers. The mma.sync registers are shaped
+  // as numberOfRegisters x elementsPerRegister. They can only be sliced along
+  // the register dimension, i.e. sliceOffset[0]. A warp-level slice of
+  // `sizes[dim]` elements corresponds to numRegistersPerFragment registers.
+  // Sizes are >= 1 by the op verifier, so the division is safe.
+  std::array<int64_t, 2> sliceOffset = {0, 0};
+  if (offsets[0] && offsets[1])
+    return op->emitError() << "Slicing fragments in 2D is not supported. ";
+  if (offsets[0] || offsets[1]) {
+    unsigned dim = offsets[0] ? 0 : 1;
+    if (offsets[dim] % sizes[dim] != 0)
+      return op->emitError()
+             << "slice offset must be a multiple of the slice size";
+    sliceOffset[0] = offsets[dim] / sizes[dim] *
+                     mmaSyncFragmentInfo->numRegistersPerFragment;
+  }
+
+  // The register-level slice covers one mma.sync fragment, with unit strides.
+  std::array<int64_t, 2> sliceShape = {
+      mmaSyncFragmentInfo->numRegistersPerFragment,
+      mmaSyncFragmentInfo->elementsPerRegister};
+  std::array<int64_t, 2> strides = {1, 1};
+
+  // Reject slices that would run past the source registers instead of
+  // building an op that fails verification.
+  auto sourceType = sourceVector.getType().dyn_cast<VectorType>();
+  if (!sourceType || sourceType.getRank() != 2 ||
+      sliceOffset[0] + sliceShape[0] > sourceType.getDimSize(0))
+    return op->emitError()
+           << "slice exceeds the registers of the source fragment";
+
+  Value newOp = b.create<vector::ExtractStridedSliceOp>(
+      loc, sourceVector, sliceOffset, sliceShape, strides);
+
+  valueMapping[op] = newOp;
+  return success();
+}
+
static void convertContractOp(vector::ContractionOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
OpBuilder b(op);
@@ -858,6 +967,10 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
return convertTransferWriteToStores(transferWriteOp,
valueMapping);
})
+ .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
+ return convertExtractStridedSlice(extractStridedSliceOp,
+ valueMapping);
+ })
.Case([&](vector::ContractionOp contractionOp) {
return convertContractOpToMmaSync(contractionOp, valueMapping);
})
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 18fc4e600db9e..6de16f84668ad 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -45,14 +45,24 @@ static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
lineSizeBits};
}
+/// Walks the users of `op` in the order produced by `op->getUsers()` and
+/// returns the first one that is a vector.contract. Returns failure when none
+/// of the users is a vector.contract.
+FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
+  for (Operation *candidate : op->getUsers()) {
+    auto contract = dyn_cast<vector::ContractionOp>(candidate);
+    if (!contract)
+      continue;
+    return contract;
+  }
+  return failure();
+}
+
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
WarpMatrixInfo info;
- // Determine the vector type.
+ // Determine the vector type at warp-level.
if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
info.vectorType = writeOp.getVectorType();
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
- arith::ConstantOp>(op)) {
+ vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
info.vectorType = op->getResult(0).getType().cast<VectorType>();
} else {
return op->emitError()
@@ -62,19 +72,15 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
// Determine the operand role. We assume it is an accumulator/result unless it
// is directly consumed by a `vector.contract` op.
info.operandRole = MatMulOperandRole::C;
- for (Operation *user : op->getUsers()) {
- auto contract = dyn_cast<vector::ContractionOp>(user);
- if (!contract)
- continue;
- if (contract.getLhs() == op->getResult(0)) {
- info.operandRole = MatMulOperandRole::A;
- break;
- }
- if (contract.getRhs() == op->getResult(0)) {
- info.operandRole = MatMulOperandRole::B;
- break;
- }
- }
+ FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
+ if (failed(contractOp))
+ return info;
+
+ if ((*contractOp).getLhs() == op->getResult(0))
+ info.operandRole = MatMulOperandRole::A;
+ else if ((*contractOp).getRhs() == op->getResult(0))
+ info.operandRole = MatMulOperandRole::B;
+
return info;
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index 42dc06c937d40..ae6329c22eff7 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -164,9 +164,9 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
// -----
-//#########################################################
-// FP16 row-row-row
-//#########################################################
+//#########################################################################
+// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x2 for matrixB)
+//#########################################################################
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -203,6 +203,62 @@ func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
// -----
+//#########################################################################
+// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
+//#########################################################################
+// Computes a 16x16x16 f16 matmul as two m16n8k16 vector.contract ops.
+// A, B and C are each read once as 16x16 tiles (B transposed via #map0).
+// B and C are then split with vector.extract_strided_slice:
+// - B into row halves [0, 8) and [8, 16);
+// - C into column halves [0, 8) and [8, 16).
+// The expected lowering slices the 4x2 register fragments at register
+// offsets 0 and 2.
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
+
+// #map0 is the transpose permutation used to read B; #map1-#map3 are the
+// (m, n, k) indexing maps of the contractions.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n16k16_mmasync16816_fp16_f16_row_row_row
+func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16, 3>, %arg1: memref<32x64xf16, 3>, %arg2: memref<42x64xf16, 3>) {
+ // NOTE(review): %cst_0 and %c8 are unused by this function.
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %cst = arith.constant 0.000000e+00 : f16
+
+ // matrixA: full 16x16 tile.
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
+
+ // matrixB: 16x16 tile read transposed, later split into two 8x16 halves.
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+ // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
+
+ // matrixC: 16x16 tile, later split into two 16x8 halves.
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
+
+ // First m16n8k16 step: B rows [0, 8) and C columns [0, 8).
+ // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+ // CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+ %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
+ %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ // NOTE(review): D0 and D1 are both written to %arg2[0, 0], so D1 overwrites
+ // D0; harmless for a lowering test.
+ vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
+
+ // Second m16n8k16 step: B rows [8, 16) and C columns [8, 16).
+ // CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+ // CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+ %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
+ %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
+
+ return
+}
+// -----
+
// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
More information about the Mlir-commits
mailing list