[Mlir-commits] [mlir] [mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (PR #78413)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 17 01:06:10 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

This PR fixes the 'nvgpu.warpgroup.mma.store' index calculation. When the destination memref and the current accumulator matrix were small, the previous code, which always iterated over a fixed 16 column tiles, indexed out of range.

---
Full diff: https://github.com/llvm/llvm-project/pull/78413.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-3) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5f..9e4ae219eefd60b 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1554,6 +1554,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
     Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
 
+    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                    TypedValue<::mlir::MemRefType> memref) {
       Type it = b.getIndexType();
@@ -1570,11 +1571,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
     if (offset)
       ti = makeAdd(ti, makeConst(offset));
-    for (int i = 0; i < 2; ++i) {
+    for (size_t i = 0; i < 2; ++i) {
       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
-      for (int j = 0; j < 16; ++j) {
+      for (size_t j = 0; j < (structType.getBody().size() / 8); ++j) {
         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
-        int sIndex = i * 2 + j * 4;
+        size_t sIndex = i * 2 + j * 4;
         makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
       }
     }
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bdd..ce81fd859fd02ae 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,34 @@ func.func @warpgroup_mma_store(
   return 
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiplie(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x128xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: memref<64x32xf32, 3>, %[[arg4:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, %[[arg5:[a-zA-Z0-9_]+]]: memref<64x64xf32, 3>)
+func.func @warpgroup_mma_store_multiplie(
+    %result128 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
+    %matrixD128: memref<64x128xf32,3>,
+    %result32 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, 
+    %matrixD32: memref<64x32xf32,3>,
+    %result64 : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, 
+    %matrixD64: memref<64x64xf32,3>) {
+  
+  // CHECK-COUNT-32:  memref.store %{{.*}}, %[[arg1]][%{{.*}}, %{{.*}}] : memref<64x128xf32, 3>
+  nvgpu.warpgroup.mma.store %result128, %matrixD128 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
+    to memref<64x128xf32,3>
+
+
+  // CHECK-COUNT-8: memref.store %{{.*}}, %[[arg3]][%{{.*}}, %{{.*}}] : memref<64x32xf32, 3>
+  nvgpu.warpgroup.mma.store %result32, %matrixD32 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x32xf32>> 
+    to memref<64x32xf32,3>
+
+  // CHECK-COUNT-16: memref.store %{{.*}}, %[[arg5]][%{{.*}}, %{{.*}}] : memref<64x64xf32, 3>
+  nvgpu.warpgroup.mma.store %result64, %matrixD64 : 
+    !nvgpu.warpgroup.accumulator< fragmented = vector<64x64xf32>> 
+    to memref<64x64xf32,3>
+  return 
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>

``````````

</details>


https://github.com/llvm/llvm-project/pull/78413


More information about the Mlir-commits mailing list