[Mlir-commits] [mlir] 21830c9 - [mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (#78413)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 21 23:33:00 PST 2024


Author: Guray Ozen
Date: 2024-01-22T08:32:56+01:00
New Revision: 21830c913505b1fd2cf10e454253483180c7e10b

URL: https://github.com/llvm/llvm-project/commit/21830c913505b1fd2cf10e454253483180c7e10b
DIFF: https://github.com/llvm/llvm-project/commit/21830c913505b1fd2cf10e454253483180c7e10b.diff

LOG: [mlir][nvgpu] Fix 'warpgroup.mma.store' index calculation (#78413)

This PR fixes the index calculation in 'nvgpu.warpgroup.mma.store'. When
the destination memref and the accumulator matrix were small, the previous
code computed indices that were out of range.
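
For orientation, the per-thread (row, column) base offsets that the lowering
derives from the thread id are unchanged by this patch; only the loop bound
over the accumulator struct changes. Below is a minimal, self-contained C++
sketch of that offset arithmetic (illustrative names only, not the actual
NVGPUWarpgroupMmaStoreOpLowering code):

    // Reproduces the offset math from the lowering: each thread of the
    // 128-thread warpgroup starts storing at row `ti`, column `tj`.
    #include <cstdio>

    int main() {
      const int warpSize = 32;
      for (int tidx = 0; tidx < 128; tidx += 37) { // a few sample thread ids
        int laneId = tidx % warpSize;    // lane within the warp
        int warpId = tidx / warpSize;    // warp within the warpgroup
        int lane4Id = laneId / 4;        // row group inside the 8x8 tile
        int lane4modId = laneId % 4;     // column group inside the 8x8 tile
        int ti = lane4Id + warpId * 16;  // starting row for this thread
        int tj = lane4modId * 2;         // starting column for this thread
        std::printf("tidx=%3d -> row %2d, col %2d\n", tidx, ti, tj);
      }
      return 0;
    }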

Added: 
    

Modified: 
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 759766275de4a5..ab4dea9d5618d5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1548,12 +1548,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
-    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
-    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
-    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
-    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
-
     auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                    TypedValue<::mlir::MemRefType> memref) {
       Type it = b.getIndexType();
@@ -1566,16 +1560,34 @@ struct NVGPUWarpgroupMmaStoreOpLowering
       b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
     };
 
+    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
+    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
+    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
+    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);
+
     Value tj = makeMul(lane4modId, c2);
     Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
     if (offset)
       ti = makeAdd(ti, makeConst(offset));
-    for (int i = 0; i < 2; ++i) {
+
+    auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+
+    // Number of 32-bit registers owned per thread
+    constexpr unsigned numAdjacentRegisters = 2;
+    // Number of 8x8 matrices stacked one below another per warp
+    constexpr unsigned numStackedMatrices = 2;
+
+    size_t storeCount = (structType.getBody().size() /
+                         (numStackedMatrices * numAdjacentRegisters));
+
+    for (size_t i = 0; i < numStackedMatrices; ++i) {
       Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
-      for (int j = 0; j < 16; ++j) {
+      for (size_t j = 0; j < storeCount; ++j) {
         Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
-        int sIndex = i * 2 + j * 4;
-        makeExtractAndStore(sIndex, matrixD, idx, idy, dstMemref);
+        size_t structIndex = (i * numAdjacentRegisters) +
+                             (j * (numStackedMatrices * numAdjacentRegisters));
+        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
       }
     }
   }

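As a hand-worked sanity check of the new bound (not part of the patch): for the
m64n16k accumulator, each thread of the 128-thread warpgroup owns 64*16/128 = 8
f32 registers, so the accumulator struct has 8 members and storeCount becomes
8 / (2 * 2) = 2. The loops then extract struct indices 0, 4, 2, 6 (each extract
also reads the adjacent index), which stays in range and yields the 8 stores
checked by CHECK-COUNT-8 in the test below. The previous fixed bound of 16 would
have requested index 1*2 + 15*4 = 62 from that 8-element struct.
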
diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603bd..b495363e228d8f 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1055,6 +1055,136 @@ func.func @warpgroup_mma_store(
   return 
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiple
+func.func @warpgroup_mma_store_multiple(
+    %shmem_m64n8k : memref<64x8xf32>, 
+    %shmem_m64n16k : memref<64x16xf32>, 
+    %shmem_m64n24k : memref<64x24xf32>, 
+    %shmem_m64n32k : memref<64x32xf32>, 
+    %shmem_m64n40k : memref<64x40xf32>, 
+    %shmem_m64n48k : memref<64x48xf32>, 
+    %shmem_m64n56k : memref<64x56xf32>, 
+    %shmem_m64n64k : memref<64x64xf32>, 
+    %shmem_m64n72k : memref<64x72xf32>, 
+    %shmem_m64n80k : memref<64x80xf32>, 
+    %shmem_m64n88k : memref<64x88xf32>, 
+    %shmem_m64n96k : memref<64x96xf32>, 
+    %shmem_m64n104k : memref<64x104xf32>, 
+    %shmem_m64n112k : memref<64x112xf32>, 
+    %shmem_m64n120k : memref<64x120xf32>, 
+    %shmem_m64n128k : memref<64x128xf32>, 
+    %shmem_m64n136k : memref<64x136xf32>, 
+    %shmem_m64n144k : memref<64x144xf32>, 
+    %shmem_m64n152k : memref<64x152xf32>, 
+    %shmem_m64n160k : memref<64x160xf32>, 
+    %shmem_m64n168k : memref<64x168xf32>, 
+    %shmem_m64n176k : memref<64x176xf32>, 
+    %shmem_m64n184k : memref<64x184xf32>, 
+    %shmem_m64n192k : memref<64x192xf32>, 
+    %shmem_m64n200k : memref<64x200xf32>, 
+    %shmem_m64n208k : memref<64x208xf32>, 
+    %shmem_m64n216k : memref<64x216xf32>, 
+    %shmem_m64n224k : memref<64x224xf32>, 
+    %shmem_m64n232k : memref<64x232xf32>, 
+    %shmem_m64n240k : memref<64x240xf32>, 
+    %shmem_m64n248k : memref<64x248xf32>, 
+    %shmem_m64n256k : memref<64x256xf32>, 
+    %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, 
+    %res_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>>, 
+    %res_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>, 
+    %res_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>>, 
+    %res_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>>, 
+    %res_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>>, 
+    %res_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>, 
+    %res_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>>, 
+    %res_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>>, 
+    %res_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>>, 
+    %res_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>>, 
+    %res_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>>, 
+    %res_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>>, 
+    %res_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>>, 
+    %res_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
+    %res_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>>, 
+    %res_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>>, 
+    %res_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>>, 
+    %res_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>>, 
+    %res_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>>, 
+    %res_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>>, 
+    %res_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>>, 
+    %res_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>>, 
+    %res_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>>, 
+    %res_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>>, 
+    %res_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>>, 
+    %res_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>>, 
+    %res_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>>, 
+    %res_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>>, 
+    %res_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>>, 
+    %res_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>>) {
+    // CHECK-COUNT-8: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32>
+    // CHECK-COUNT-12: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x24xf32>
+    // CHECK-COUNT-16: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x32xf32>
+    // CHECK-COUNT-20: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x40xf32>
+    // CHECK-COUNT-24: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x48xf32>
+    // CHECK-COUNT-28: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x56xf32>
+    // CHECK-COUNT-32: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf32>
+    // CHECK-COUNT-36: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x72xf32>
+    // CHECK-COUNT-40: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x80xf32>
+    // CHECK-COUNT-44: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x88xf32>
+    // CHECK-COUNT-48: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x96xf32>
+    // CHECK-COUNT-52: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x104xf32>
+    // CHECK-COUNT-56: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x112xf32>
+    // CHECK-COUNT-60: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x120xf32>
+    // CHECK-COUNT-64: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x128xf32>
+    // CHECK-COUNT-68: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x136xf32>
+    // CHECK-COUNT-72: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x144xf32>
+    // CHECK-COUNT-76: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x152xf32>
+    // CHECK-COUNT-80: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x160xf32>
+    // CHECK-COUNT-84: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x168xf32>
+    // CHECK-COUNT-88: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x176xf32>
+    // CHECK-COUNT-92: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x184xf32>
+    // CHECK-COUNT-96: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x192xf32>
+    // CHECK-COUNT-100: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x200xf32>
+    // CHECK-COUNT-104: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x208xf32>
+    // CHECK-COUNT-108: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x216xf32>
+    // CHECK-COUNT-112: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x224xf32>
+    // CHECK-COUNT-116: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x232xf32>
+    // CHECK-COUNT-120: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x240xf32>
+    // CHECK-COUNT-124: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x248xf32>
+    // CHECK-COUNT-128: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x256xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n16k, %shmem_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n24k, %shmem_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>> to memref<64x24xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n32k, %shmem_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>> to memref<64x32xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n40k, %shmem_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>> to memref<64x40xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n48k, %shmem_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>> to memref<64x48xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n56k, %shmem_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>> to memref<64x56xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n64k, %shmem_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>> to memref<64x64xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n72k, %shmem_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>> to memref<64x72xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n80k, %shmem_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>> to memref<64x80xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n88k, %shmem_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>> to memref<64x88xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n96k, %shmem_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>> to memref<64x96xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n104k, %shmem_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>> to memref<64x104xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n112k, %shmem_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>> to memref<64x112xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n120k, %shmem_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>> to memref<64x120xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n128k, %shmem_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to memref<64x128xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n136k, %shmem_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>> to memref<64x136xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n144k, %shmem_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>> to memref<64x144xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n152k, %shmem_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>> to memref<64x152xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n160k, %shmem_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>> to memref<64x160xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n168k, %shmem_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>> to memref<64x168xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n176k, %shmem_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>> to memref<64x176xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n184k, %shmem_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>> to memref<64x184xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n192k, %shmem_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>> to memref<64x192xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n200k, %shmem_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>> to memref<64x200xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n208k, %shmem_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>> to memref<64x208xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n216k, %shmem_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>> to memref<64x216xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n224k, %shmem_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>> to memref<64x224xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n232k, %shmem_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>> to memref<64x232xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n240k, %shmem_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>> to memref<64x240xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n248k, %shmem_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>> to memref<64x248xf32>
+    nvgpu.warpgroup.mma.store  %res_m64n256k, %shmem_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>> to memref<64x256xf32>
+  return 
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>


        

