[Mlir-commits] [mlir] [mlir][nvgpu] Select warpgroup id on `warpgroup.mma.store` (PR #85820)
Guray Ozen
llvmlistbot at llvm.org
Tue Mar 19 09:58:48 PDT 2024
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/85820
`warpgroup.mma.store` Op is run by a warpgroup that stores fragmented registers to destination memref. Currently, this op always uses warpgroup 0.
This PR adds a new operand to `warpgroup.mma.store` Op that allows selecting different warpgroup. For example:
```
nvgpu.warpgroup.mma.store
%res_m64n16k,
%shmem_m64n16k,
warpgroup_id = %id // <-- PR adds this one
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>
to memref<64x16xf32,3>
```
>From 9d4bb956dd4e78c0fb9bae69372064a5192a3762 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 19 Mar 2024 16:53:28 +0000
Subject: [PATCH] [mlir][nvgpu] Select warpgroup id on `warpgroup.mma.store`
`warpgroup.mma.store` Op is run by a warpgroup that stores fragmented registers to destination memref. Currently, this op always uses warpgroup 0.
This PR adds a new operand to `warpgroup.mma.store` Op that allows selecting different warpgroup.
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 12 +++--
.../mlir/Dialect/NVGPU/IR/NVGPUDialect.h | 3 ++
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 30 +++++++-----
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 46 +++++++++++++++++++
4 files changed, 75 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe9..d22b8fd28582c7 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -775,17 +775,21 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result
in $matrixD to given memref.
+ Note that, the op must be run with warp group. The operand `warpgroupId`
+ allow to select the warp group to run the operation. When it is not present,
+ the first warp group runs the operation.
+
[See the details of register fragment layout for accumulator matrix D]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
-
- Note that, the op must be run with warp group.
}];
let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
- Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+ Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+ Optional<I32>:$warpgroupId);
let assemblyFormat = [{
- $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ $matrixD `,` $dstMemref (`,` `warpgroup_id` `=` $warpgroupId^)?
+ attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 19070f6f062a02..0d16b825821252 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
constexpr int kWarpSize = 32;
+/// Number of threads in warpgroup
+constexpr int kWarpgroupSize = 128;
+
/// M size of wgmma.mma_async instruction
constexpr int kWgmmaSizeM = 64;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9b5d19ebd783a9..5a8f095b84c791 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1087,10 +1087,9 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
+ LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:"
+ << leadDimVal << "\t" << "stride_off :" << strideDimVal
+ << "\t" << "base_offset:" << offsetVal << "\t"
<< "layout_type:" << swizzle << " ("
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr << "\n");
@@ -1382,13 +1381,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
+ LLVM_DEBUG(DBGS() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k"
+ << wgmmaK << "(A[" << (iterationM * wgmmaM) << ":"
<< (iterationM * wgmmaM) + wgmmaM << "]["
<< (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+ << (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
<< wgmmaN << "])\n");
@@ -1535,8 +1533,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
/// \param offset: the offset within the memref where the registers will be
/// stored.
void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
- TypedValue<MemRefType> dstMemref,
- int offset) const {
+ TypedValue<MemRefType> dstMemref, int offset,
+ Value warpgroupId) const {
Type i32 = b.getI32Type();
auto makeConst = [&](int32_t index) -> Value {
@@ -1569,6 +1567,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
};
Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+ // Normalize the thread index to the beginning of the warpgroup
+ if (warpgroupId) {
+ Value s1 =
+ b.create<arith::MulIOp>(warpgroupId, makeConst(kWarpgroupSize));
+ tidx = b.create<arith::SubIOp>(tidx, s1);
+ }
+
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
@@ -1610,7 +1615,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = matrixD.cast<LLVM::LLVMStructType>();
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
- storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
+ storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset,
+ adaptor.getWarpgroupId());
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78db..ae9fcd287bdded 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1182,6 +1182,52 @@ func.func @warpgroup_mma_store_multiple(
return
}
+// CHECK-LABEL: @warpgroup_mma_store_multiple_with_id(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x16xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: i32)
+func.func @warpgroup_mma_store_multiple_with_id(
+ %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+ %shmem_m64n16k : memref<64x16xf32, 3>,
+ %id : i32)
+{
+ // CHECK: %[[s0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+ // CHECK: %[[s1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : memref<64x16xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+ // CHECK: %[[s3:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[s5:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[s6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[s7:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[s8:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[s9:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: %[[s10:.*]] = llvm.mlir.constant(128 : i32) : i32
+ // CHECK: %[[s11:.*]] = arith.muli %[[arg2]], %[[s10]] : i32
+ // CHECK: %[[s12:.*]] = arith.subi %[[s9]], %[[s11]] : i32
+ // CHECK: %[[s13:.*]] = llvm.urem %12, %8 : i32
+ // CHECK: %[[s14:.*]] = llvm.udiv %[[s12]], %[[s8]] : i32
+ // CHECK: %[[s15:.*]] = llvm.udiv %[[s13]], %[[s5]] : i32
+ // CHECK: %[[s16:.*]] = llvm.urem %[[s13]], %[[s5]] : i32
+ // CHECK: %[[s17:.*]] = llvm.mul %[[s16]], %[[s4]] : i32
+ // CHECK: %[[s18:.*]] = llvm.mul %[[s14]], %[[s7]] : i32
+ // CHECK: %[[s19:.*]] = llvm.add %[[s15]], %[[s18]] : i32
+ // CHECK: %[[s20:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[s21:.*]] = llvm.mul %[[s20]], %[[s6]] : i32
+ // CHECK: %[[s22:.*]] = llvm.add %[[s19]], %[[s21]] : i32
+ // CHECK: %[[s23:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[s24:.*]] = llvm.mul %[[s23]], %[[s6]] : i32
+ // CHECK: %[[s25:.*]] = llvm.add %[[s17]], %[[s24]] : i32
+ // CHECK: %[[s26:.*]] = arith.index_cast %[[s22]] : i32 to index
+ // CHECK: %[[s27:.*]] = arith.index_cast %[[s25]] : i32 to index
+ // CHECK: %[[s28:.*]] = llvm.add %[[s25]], %[[s3]] : i32
+ // CHECK: %[[s29:.*]] = arith.index_cast %[[s28]] : i32 to index
+ // CHECK: %[[s30:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[s31:.*]] = llvm.extractvalue %[[s2]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: memref.store %[[s30]], %[[arg1]][%[[s26]], %[[s27]]] : memref<64x16xf32, 3>
+ // CHECK: memref.store %[[s31]], %[[arg1]][%[[s26]], %[[s29]]] : memref<64x16xf32, 3>
+ // CHECK-COUNT-6: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32, 3>
+ nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k, warpgroup_id = %id : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
+ return
+}
+
func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
More information about the Mlir-commits
mailing list