[Mlir-commits] [mlir] cce3e8e - [MLIR][NVGPU] Introduction of wgmma.generate.descriptor Op
Guray Ozen
llvmlistbot at llvm.org
Tue Aug 22 07:12:38 PDT 2023
Author: Guray Ozen
Date: 2023-08-22T16:12:25+02:00
New Revision: cce3e8ed895b2d4c1396929c363c071e15fdbf8b
URL: https://github.com/llvm/llvm-project/commit/cce3e8ed895b2d4c1396929c363c071e15fdbf8b
DIFF: https://github.com/llvm/llvm-project/commit/cce3e8ed895b2d4c1396929c363c071e15fdbf8b.diff
LOG: [MLIR][NVGPU] Introduction of wgmma.generate.descriptor Op
This work introduces a new Op, `wgmma.generate.descriptor`, designed to create a wgmma descriptor for the inputs of matrix multiply and accumulate operations performed with the `wgmma.mma_async` PTX instruction.
The descriptor format specifications can be found in the following link:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
This op is in its initial phase and comes with limitations: it supports only 128B swizzling and does not handle interleaving. The remaining descriptor calculations will be addressed in separate follow-up work, expanding the capabilities of the op.
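For reference, a minimal usage sketch adapted from the test added in this patch (the shared-memory global, the reinterpret_cast and the tensor map type are taken directly from that test; names are illustrative):
```
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
memref.global "private" @dynamicShmem : memref<0xf16,3>

func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64 {
  // View the dynamic shared-memory buffer as the 128x64 f16 operand tile.
  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
  %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
  // Build the 64-bit wgmma matrix descriptor for this shared-memory operand.
  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
  func.return %descA : i64
}
```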
Reviewed By: qcolombet
Differential Revision: https://reviews.llvm.org/D157382
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 85c03d85a7d9aa..8d9237337401a3 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -625,4 +625,35 @@ def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
let hasVerifier = 1;
}
+def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
+ let summary = "Generate a wgmma matrix descriptor";
+ let description = [{
+ This Op builds a wgmma descriptor that is used by the wgmma matrix multiply
+ and accumulate operation.
+
+ The descriptor specifies the properties of the matrix in shared memory that
+ is a multiplicand in the matrix multiply and accumulate operation.
+
+ The descriptor is a 64-bit value contained in a register with the following layout:
+ ```
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ |  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+ +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+ ```
+
+ For more details, see:
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+
+ }];
+ let results = (outs I64:$descriptor);
+ let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor,
+ NVGPU_TensorMapDescriptor:$tensorMap);
+ let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+ let hasVerifier = 1;
+}
+
#endif // NVGPU
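As a worked example (not part of the patch): for the `memref<128x64xf16,3>` operand with `swizzle_128b` used in the test below, the lowering that follows picks layout = 128, so the stride field becomes (128 << 3) >> 4 = 64, the leading-dimension field becomes (128 * 128) >> 4 = 1024, and the swizzle field is 1; the constants 64 and 1024 are exactly the `llvm.mlir.constant` values checked in the updated nvgpu-to-nvvm.mlir test.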
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 21c6780cc7887f..f826f20db69ebf 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -934,6 +934,76 @@ struct NVGPUTmaAsyncLoadOpLowering
return success();
}
};
+struct NVGPUGenerateGmmaDescriptorLowering
+ : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Location loc = op->getLoc();
+
+ nvgpu::TensorMapSwizzleKind swizzleKind =
+ op.getTensorMap().getType().getSwizzle();
+
+ unsigned layout =
+ (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
+ : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
+ : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
+ : 1;
+ unsigned swizzle =
+ (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
+ : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
+ : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
+ : 0;
+
+ auto ti64 = rewriter.getIntegerType(64);
+ auto makeConst = [&](uint64_t index) -> Value {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, ti64, rewriter.getI64IntegerAttr(index));
+ };
+ auto shiftLeft = [&](Value value, unsigned shift) -> Value {
+ return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+ };
+ auto shiftRight = [&](Value value, unsigned shift) -> Value {
+ return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+ };
+ auto insertBit = [&](Value desc, Value val, int startBit) {
+ return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
+ shiftLeft(val, startBit));
+ };
+
+ int ex4LSB = 4;
+ Value strideDim = makeConst((layout << 3) >> ex4LSB);
+ int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
+ Value leadDim = makeConst((sizeN * layout) >> ex4LSB);
+ Value baseAddr = getStridedElementPtr(
+ op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
+ adaptor.getTensor(), {}, rewriter);
+ Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
+ // Keep only 14 bits of the base address (address bits [17:4], i.e. the address >> 4)
+ Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
+
+ int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
+ startLeadBit = 16, startBaseAddrBit = 0;
+ Value dsc = makeConst(0);
+ // [62,64) swizzle type
+ dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
+ // [49,52) base_offset
+ dsc = insertBit(dsc, makeConst(0), startOffsetBit);
+ // [32,46) stride
+ dsc = insertBit(dsc, strideDim, startStrideBit);
+ // [16,30) leading dimension
+ dsc = insertBit(dsc, leadDim, startLeadBit);
+ // [0,14) start_address
+ dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
+
+ rewriter.replaceOp(op, dsc);
+ return success();
+ }
+};
static Value makeI64Const(RewriterBase &rewriter, Operation *op,
int32_t index) {
@@ -1064,6 +1134,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
+ NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 24c490568a4383..d832a983a132d6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -366,6 +366,42 @@ LogicalResult TmaCreateDescriptorOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// NVGPU_GenerateGmmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenerateGmmaDescriptorOp::verify() {
+ MemRefType memrefType = getTensor().getType();
+ MemRefType tensorMapType = getTensorMap().getType().getTensor();
+
+ if (memrefType != tensorMapType)
+ return emitError() << "memref and tensor map type mismatch";
+
+ if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
+ return emitError() << "supports only static shapes";
+
+ if (memrefType.getRank() != 2)
+ return emitError() << "supports only 2d memref is supported for now";
+
+ if (getTensorMap().getType().getSwizzle() !=
+ TensorMapSwizzleKind::SWIZZLE_128B) {
+ return emitError() << "supports only "
+ << stringifyTensorMapSwizzleKind(
+ TensorMapSwizzleKind::SWIZZLE_128B)
+ << " is supported for the time being";
+ }
+
+ if (getTensorMap().getType().getInterleave() !=
+ TensorMapInterleaveKind::INTERLEAVE_NONE) {
+ return emitError() << "supports only "
+ << stringifyTensorMapInterleaveKind(
+ TensorMapInterleaveKind::INTERLEAVE_NONE)
+ << " is supported for the time being";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 9721b371aa90bf..5604f5027fd3cd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -631,6 +631,47 @@ module @mymodule {
}
}
+!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+// CHECK-LABEL: func @create_wgmma_descriptor(
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
+ %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+ %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+ // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+ // CHECK: %[[Sre:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+ // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sre]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64
+ // CHECK: %[[c1024:.+]] = llvm.mlir.constant(1024 : i64) : i64
+ // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S3:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
+ // CHECK: %[[S4:.+]] = llvm.mlir.constant(46 : i64) : i64
+ // CHECK: %[[S5:.+]] = llvm.shl %[[S3]], %[[S4]] : i64
+ // CHECK: %[[S6:.+]] = llvm.mlir.constant(50 : i64) : i64
+ // CHECK: %[[S7:.+]] = llvm.lshr %[[S5]], %[[S6]] : i64
+ // CHECK: %[[S8:.+]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[S9:.+]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[S10:.+]] = llvm.mlir.constant(62 : i64) : i64
+ // CHECK: %[[S11:.+]] = llvm.shl %[[S9]], %[[S10]] : i64
+ // CHECK: %[[S12:.+]] = llvm.or %[[S8]], %[[S11]] : i64
+ // CHECK: %[[S13:.+]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[S14:.+]] = llvm.mlir.constant(49 : i64) : i64
+ // CHECK: %[[S15:.+]] = llvm.shl %[[S13]], %[[S14]] : i64
+ // CHECK: %[[S16:.+]] = llvm.or %[[S12]], %[[S15]] : i64
+ // CHECK: %[[S18:.+]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK: %[[S19:.+]] = llvm.shl %[[c64]], %[[S18]] : i64
+ // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]] : i64
+ // CHECK: %[[S22:.+]] = llvm.mlir.constant(16 : i64) : i64
+ // CHECK: %[[S23:.+]] = llvm.shl %[[c1024]], %[[S22]] : i64
+ // CHECK: %[[S24:.+]] = llvm.or %[[S20]], %[[S23]] : i64
+ // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
+ // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
+ // CHECK: return %[[S27]] : i64
+ %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
+ func.return %descA : i64
+}
+
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1