[Mlir-commits] [mlir] [mlir][nvgpu] Select warpgroup id on `warpgroup.mma.store` (PR #85820)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 19 09:59:20 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

The `warpgroup.mma.store` Op is run by a warpgroup that stores fragmented registers to the destination memref. Currently, this op always uses warpgroup 0.

This PR adds a new operand to the `warpgroup.mma.store` Op that allows selecting a different warpgroup. For example:
```
nvgpu.warpgroup.mma.store 
%res_m64n16k, 
%shmem_m64n16k, 
warpgroup_id = %id // <-- PR adds this one
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> 
  to memref<64x16xf32,3>
```

---
Full diff: https://github.com/llvm/llvm-project/pull/85820.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+8-4) 
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+3) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+18-12) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+46) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe9..d22b8fd28582c7 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -775,17 +775,21 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result 
     in $matrixD to given memref. 
 
+    Note that, the op must be run with warp group. The operand `warpgroupId` 
+    allow to select the warp group to run the operation. When it is not present,
+    the first warp group runs the operation.
+
     [See the details of register fragment layout for accumulator matrix D]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
-
-    Note that, the op must be run with warp group.
   }];
 
   let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
-                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+                       Optional<I32>:$warpgroupId);
   
   let assemblyFormat = [{
-    $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+    $matrixD `,` $dstMemref (`,` `warpgroup_id` `=` $warpgroupId^)? 
+    attr-dict `:` type($matrixD) `to` type($dstMemref)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 19070f6f062a02..0d16b825821252 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
 
 constexpr int kWarpSize = 32;
 
+/// Number of threads in warpgroup
+constexpr int kWarpgroupSize = 128;
+
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9b5d19ebd783a9..5a8f095b84c791 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1087,10 +1087,9 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
     // // [0,14)   start_address
     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
 
-    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
-                      << "leading_off:" << leadDimVal << "\t"
-                      << "stride_off :" << strideDimVal << "\t"
-                      << "base_offset:" << offsetVal << "\t"
+    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:"
+                      << leadDimVal << "\t" << "stride_off :" << strideDimVal
+                      << "\t" << "base_offset:" << offsetVal << "\t"
                       << "layout_type:" << swizzle << " ("
                       << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                       << ")\n start_addr :  " << baseAddr << "\n");
@@ -1382,13 +1381,12 @@ struct NVGPUWarpgroupMmaOpLowering
     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
     /// descriptors and arranges them based on induction variables: i, j, and k.
     Value generateWgmma(int i, int j, int k, Value matrixC) {
-      LLVM_DEBUG(DBGS() << "\t wgmma."
-                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
-                        << "(A[" << (iterationM * wgmmaM) << ":"
+      LLVM_DEBUG(DBGS() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k"
+                        << wgmmaK << "(A[" << (iterationM * wgmmaM) << ":"
                         << (iterationM * wgmmaM) + wgmmaM << "]["
                         << (iterationK * wgmmaK) << ":"
-                        << (iterationK * wgmmaK + wgmmaK) << "] * "
-                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+                        << (iterationK * wgmmaK) << ":"
                         << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                         << wgmmaN << "])\n");
 
@@ -1535,8 +1533,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   /// \param offset: the offset within the memref where the registers will be
   /// stored.
   void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
-                             TypedValue<MemRefType> dstMemref,
-                             int offset) const {
+                             TypedValue<MemRefType> dstMemref, int offset,
+                             Value warpgroupId) const {
     Type i32 = b.getI32Type();
 
     auto makeConst = [&](int32_t index) -> Value {
@@ -1569,6 +1567,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     };
 
     Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+    // Normalize the thread index to the beginning of the warpgroup
+    if (warpgroupId) {
+      Value s1 =
+          b.create<arith::MulIOp>(warpgroupId, makeConst(kWarpgroupSize));
+      tidx = b.create<arith::SubIOp>(tidx, s1);
+    }
+
     Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
     Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
@@ -1610,7 +1615,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
       auto structType = matrixD.cast<LLVM::LLVMStructType>();
       Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
-      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
+      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset,
+                            adaptor.getWarpgroupId());
       offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78db..ae9fcd287bdded 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1182,6 +1182,52 @@ func.func @warpgroup_mma_store_multiple(
   return 
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiple_with_id(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x16xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: i32)
+func.func @warpgroup_mma_store_multiple_with_id(
+  %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+  %shmem_m64n16k : memref<64x16xf32, 3>, 
+  %id : i32) 
+{    
+  // CHECK: %[[s0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+  // CHECK: %[[s1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : memref<64x16xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)> 
+  // CHECK: %[[s3:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[s5:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[s6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[s7:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[s8:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: %[[s9:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+  // CHECK: %[[s10:.*]] = llvm.mlir.constant(128 : i32) : i32
+  // CHECK: %[[s11:.*]] = arith.muli %[[arg2]], %[[s10]] : i32
+  // CHECK: %[[s12:.*]] = arith.subi %[[s9]], %[[s11]] : i32
+  // CHECK: %[[s13:.*]] = llvm.urem %12, %8  : i32
+  // CHECK: %[[s14:.*]] = llvm.udiv %[[s12]], %[[s8]]  : i32
+  // CHECK: %[[s15:.*]] = llvm.udiv %[[s13]], %[[s5]]  : i32
+  // CHECK: %[[s16:.*]] = llvm.urem %[[s13]], %[[s5]]  : i32
+  // CHECK: %[[s17:.*]] = llvm.mul %[[s16]], %[[s4]]  : i32
+  // CHECK: %[[s18:.*]] = llvm.mul %[[s14]], %[[s7]]  : i32
+  // CHECK: %[[s19:.*]] = llvm.add %[[s15]], %[[s18]]  : i32
+  // CHECK: %[[s20:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[s21:.*]] = llvm.mul %[[s20]], %[[s6]]  : i32
+  // CHECK: %[[s22:.*]] = llvm.add %[[s19]], %[[s21]]  : i32
+  // CHECK: %[[s23:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[s24:.*]] = llvm.mul %[[s23]], %[[s6]]  : i32
+  // CHECK: %[[s25:.*]] = llvm.add %[[s17]], %[[s24]]  : i32
+  // CHECK: %[[s26:.*]] = arith.index_cast %[[s22]] : i32 to index
+  // CHECK: %[[s27:.*]] = arith.index_cast %[[s25]] : i32 to index
+  // CHECK: %[[s28:.*]] = llvm.add %[[s25]], %[[s3]]  : i32
+  // CHECK: %[[s29:.*]] = arith.index_cast %[[s28]] : i32 to index
+  // CHECK: %[[s30:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[s31:.*]] = llvm.extractvalue %[[s2]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: memref.store %[[s30]], %[[arg1]][%[[s26]], %[[s27]]] : memref<64x16xf32, 3>
+  // CHECK: memref.store %[[s31]], %[[arg1]][%[[s26]], %[[s29]]] : memref<64x16xf32, 3>
+  // CHECK-COUNT-6: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32, 3>
+  nvgpu.warpgroup.mma.store  %res_m64n16k, %shmem_m64n16k, warpgroup_id = %id : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
+  return
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>

``````````

</details>


https://github.com/llvm/llvm-project/pull/85820


More information about the Mlir-commits mailing list