[Mlir-commits] [mlir] [mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (PR #78413)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 17 01:06:10 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
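This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destination memref and the current accumulator matrix were small, the previous code indexed out of range, because it always iterated over a fixed 16 column steps instead of deriving the count from the accumulator's size.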
---
Full diff: https://github.com/llvm/llvm-project/pull/78413.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-3)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+28)
``````````diff
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5f..9e4ae219eefd60b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+ auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
TypedValue<::mlir::MemRefType> memref) {
Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
if (offset)
ti = makeAdd(ti, makeConst(offset));
- for (int i = 0; i < 2; ++i) {
+ for (size_t i = 0; i < 2; ++i) {
Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
- for (int j = 0; j < 16; ++j) {
+ for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
- int sIndex = i * 2 + j * 4;
+ size_t sIndex = i * 2 + j * 4;
makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
}
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bdd..ce81fd859fd02ae 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
return
}
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+ %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ %matrixD128: memref<64x128xf32,3>,
+ %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
+ %matrixD32: memref<64x32xf32,3>,
+ %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
+ %matrixD64: memref<64x64xf32,3>) {
+
+ // CHECK-COUNT-32: memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+ nvgpu.warpgroup.mma.store %result128, %matrixD128 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ to memref<64x128xf32,3>
+
+
+ // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+ nvgpu.warpgroup.mma.store %result32, %matrixD32 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>>
+ to memref<64x32xf32,3>
+
+ // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+ nvgpu.warpgroup.mma.store %result64, %matrixD64 :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>>
+ to memref<64x64xf32,3>
+ return
+}
+
func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
``````````
</details>
https://github.com/llvm/llvm-project/pull/78413
More information about the Mlir-commits mailing list