[Mlir-commits] [mlir] [mlir][nvgpu] Make `phaseParity` of `mbarrier.try_wait` `i1` (PR #81460)

Guray Ozen llvmlistbot at llvm.org
Mon Feb 12 05:35:21 PST 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/81460

>From fa0f9c2616f1a82df054edaa6280c329cc41ef26 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 10:10:47 +0000
Subject: [PATCH 1/4] [mlir][nvgpu] Make `phaseParity` argument of
 nvgpu.mbarrier.try_wait.parity i1

Currently, the `phaseParity` argument of `nvgpu.mbarrier.try_wait.parity` is of type `index`. This can cause a problem if it is passed any value other than 0 or 1, because the PTX instruction only accepts an even or an odd phase. This PR changes the `phaseParity` argument to `i1` to avoid misuse.

Here is the relevant information from the PTX documentation:

```
The .parity variant of the instructions test for the completion of the phase indicated by the operand phaseParity, which is the integer parity of either the current phase or the immediately preceding phase of the mbarrier object. An even phase has integer parity 0 and an odd phase has integer parity of 1. So the valid values of phaseParity operand are 0 and 1.
```
For more information, see:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td            | 10 ++++++----
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp        |  3 ++-
 .../Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp   |  3 ++-
 mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir    |  4 ++--
 mlir/test/Dialect/NVGPU/tmaload-transform.mlir         |  2 +-
 .../GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir    |  3 ++-
 .../CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir   |  3 ++-
 .../GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir     |  3 ++-
 .../GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir      |  3 ++-
 .../GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir   |  3 ++-
 10 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a0c0d4cfd8714b..dda8f31e688fe9 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -609,14 +609,16 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
     phase. Suspended thread resumes execution when the specified phase completes 
     OR before the phase completes following a system-dependent time limit. 
 
+    The `$phaseParity` specifies either even phase (0) or odd phase (1) to 
+    wait.
+
     Example:
     ```mlir
-      nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+      nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
     ```
-
   }];
-  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$phase, Index:$ticks, Index:$mbarId);
-  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phase `,` $ticks attr-dict `:` type($barriers)";  
+  let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
+  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";  
 }
 
 def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5080956a458982..9b5d19ebd783a9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -956,7 +956,8 @@ struct NVGPUMBarrierTryWaitParityLowering
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Value ticks = truncToI32(b, adaptor.getTicks());
-    Value phase = truncToI32(b, adaptor.getPhase());
+    Value phase =
+        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index c81742233e6d0e..1635297a5447d4 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -1010,7 +1010,8 @@ void HopperBuilder::buildBarrierArriveTx(
 
 void HopperBuilder::buildTryWaitParity(
     TypedValue<nvgpu::MBarrierGroupType> barrier) {
-  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Type i1 = rewriter.getI1Type();
+  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
   // 10M is an arbitrary, not too small or too big number to specify the number
   // of ticks before retry.
   // TODO: hoist this in a default dialect constant.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 09a873fa331d59..625c5b480c74be 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -590,7 +590,7 @@ func.func @mbarrier_txcount() {
     }
       
 
-    %phase = arith.constant 0 : index
+    %phase = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
     // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
@@ -626,7 +626,7 @@ func.func @mbarrier_txcount_pred() {
     // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
     nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
 
-    %phase = arith.constant 0 : index
+    %phase = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
     // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index 29e300a992d3ad..40acd82cd05585 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -62,7 +62,7 @@ func.func @main() {
     //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
     //      CHECK: }
     //
-    //      CHECK: %[[c0_6:.*]] = arith.constant 0 : index
+    //      CHECK: %[[c0_6:.*]] = llvm.mlir.constant(false) : i1 
     //      CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
     //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
index 35ca0ee8677cca..c786d621213ec2 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
@@ -197,7 +197,8 @@ func.func @main() {
     {
       %ticks = arith.constant 10000000 : index
       // TMA wait
-      nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
+      %phase = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase, %ticks : !barrierType
       %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
       %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
       // Descriptor WGMMA
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
index 5a10bbba26d8cf..655ec92b104353 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
@@ -206,7 +206,8 @@ func.func @main() {
     {
       %ticks = arith.constant 10000000 : index
       // TMA wait
-      nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
+      %phase = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase, %ticks : !barrierType
       %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
       %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
       // Descriptor WGMMA
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
index 9c5aacf96b0d69..ad91873707a909 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
@@ -93,7 +93,8 @@ module @mymod {
       }
 
       // Step 8. Wait until TMA is done
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType
+      %phase = arith.constant 0 : o1
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase, %c10000000 : !barrierType
 
       // Step 9. Print loaded data in 128b swizzled
       scf.if %10 {        
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
index 536e71d260f568..cddd035c1ca802 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
@@ -119,7 +119,8 @@ module @mymod {
       }
 
       // Step 7. Wait until TMA is done
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType
+      %phase = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase, %c10000000 : !barrierType
 
       // Step 8. Print loaded data in 128b swizzled
       scf.if %10 {        
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index aee265e3faf175..e68e50fba51601 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -96,7 +96,8 @@ module @mymod {
       } else {
         nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : <memorySpace = #gpu.address_space<workgroup>>
       }
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+      %phase = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
       scf.if %10 {
         %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
         %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>

>From 4e9c9449eccaa98b47023cf069d44208c4291fbb Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 10:32:03 +0000
Subject: [PATCH 2/4] fix

---
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp                     | 3 ++-
 .../Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir | 2 +-
 .../GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir           | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 4b6327479a219c..66118eff6343e7 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -362,7 +362,8 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
                              << kMaxTMADimension << " but it is " << dim;
     }
   }
-  if (descMemref.getRank() > 1) {
+  if (descMemref.getRank() > 1 && descType.getSwizzle() !=
+                                      TensorMapSwizzleKind::SWIZZLE_NONE) {
     unsigned lastDimensionByte =
         descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
     if (lastDimensionByte != kMaxTMALastdimByte)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
index ad91873707a909..272b9cfec0cf8c 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
@@ -93,7 +93,7 @@ module @mymod {
       }
 
       // Step 8. Wait until TMA is done
-      %phase = arith.constant 0 : o1
+      %phase = arith.constant 0 : i1
       nvgpu.mbarrier.try_wait.parity %9[%c0], %phase, %c10000000 : !barrierType
 
       // Step 9. Print loaded data in 128b swizzled
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index e68e50fba51601..7ddb9266bbdcfd 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -97,7 +97,7 @@ module @mymod {
         nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : <memorySpace = #gpu.address_space<workgroup>>
       }
       %phase = arith.constant 0 : i1
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
       scf.if %10 {
         %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
         %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>

>From 16bda065c778b451ee29b59ca3ff5a9ca3e228c3 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 10:33:50 +0000
Subject: [PATCH 3/4] typo fix

---
 .../GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir            | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index 7ddb9266bbdcfd..e68e50fba51601 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -97,7 +97,7 @@ module @mymod {
         nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : <memorySpace = #gpu.address_space<workgroup>>
       }
       %phase = arith.constant 0 : i1
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
       scf.if %10 {
         %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
         %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>

>From f6599d819f79ede1f23b0ef808f673d4445e4ac6 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:35:08 +0000
Subject: [PATCH 4/4] format

---
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 66118eff6343e7..26f831f10a4e40 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -362,8 +362,8 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
                              << kMaxTMADimension << " but it is " << dim;
     }
   }
-  if (descMemref.getRank() > 1 && descType.getSwizzle() !=
-                                      TensorMapSwizzleKind::SWIZZLE_NONE) {
+  if (descMemref.getRank() > 1 &&
+      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
     unsigned lastDimensionByte =
         descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
     if (lastDimensionByte != kMaxTMALastdimByte)



More information about the Mlir-commits mailing list