[Mlir-commits] [mlir] 0a600c3 - [mlir][nvgpu] Make `phaseParity` of `mbarrier.try_wait` `i1` (#81460)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 13 00:50:38 PST 2024


Author: Guray Ozen
Date: 2024-02-13T09:50:34+01:00
New Revision: 0a600c34c8c1fe87c9661b6020e5044b24da3dc7

URL: https://github.com/llvm/llvm-project/commit/0a600c34c8c1fe87c9661b6020e5044b24da3dc7
DIFF: https://github.com/llvm/llvm-project/commit/0a600c34c8c1fe87c9661b6020e5044b24da3dc7.diff

LOG: [mlir][nvgpu] Make `phaseParity` of `mbarrier.try_wait` `i1` (#81460)

Currently, the `phaseParity` argument of `nvgpu.mbarrier.try_wait.parity` is
of type `index`. This can cause a problem if it is passed any value other than
0 or 1, because the PTX instruction only accepts an even or an odd phase. This
PR changes the `phaseParity` argument to `i1` to avoid misuse.

Here is the relevant information from the PTX documentation:

```
The .parity variant of the instructions test for the completion of the phase indicated 
by the operand phaseParity, which is the integer parity of either the current phase or 
the immediately preceding phase of the mbarrier object. An even phase has integer 
parity 0 and an odd phase has integer parity of 1. So the valid values of phaseParity 
operand are 0 and 1.
```
For more information, see:

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Dialect/NVGPU/tmaload-transform.mlir
    mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
    mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
    mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a0c0d4cfd8714b..dda8f31e688fe9 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -609,14 +609,16 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
     phase. Suspended thread resumes execution when the specified phase completes 
     OR before the phase completes following a system-dependent time limit. 
 
+    The `$phaseParity` specifies either even phase (0) or odd phase (1) to 
+    wait.
+
     Example:
     ```mlir
-      nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+      nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
     ```
-
   }];
-  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$phase, Index:$ticks, Index:$mbarId);
-  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phase `,` $ticks attr-dict `:` type($barriers)";  
+  let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
+  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";  
 }
 
 def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5080956a458982..9b5d19ebd783a9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -956,7 +956,8 @@ struct NVGPUMBarrierTryWaitParityLowering
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Value ticks = truncToI32(b, adaptor.getTicks());
-    Value phase = truncToI32(b, adaptor.getPhase());
+    Value phase =
+        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(

diff  --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index c81742233e6d0e..1635297a5447d4 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -1010,7 +1010,8 @@ void HopperBuilder::buildBarrierArriveTx(
 
 void HopperBuilder::buildTryWaitParity(
     TypedValue<nvgpu::MBarrierGroupType> barrier) {
-  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Type i1 = rewriter.getI1Type();
+  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
   // 10M is an arbitrary, not too small or too big number to specify the number
   // of ticks before retry.
   // TODO: hoist this in a default dialect constant.

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 09a873fa331d59..dbf8ead49f78db 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -590,12 +590,12 @@ func.func @mbarrier_txcount() {
     }
       
 
-    %phase = arith.constant 0 : index
+    %phase_c0 = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
     // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
-    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
+    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
 
     func.return 
 }
@@ -626,12 +626,12 @@ func.func @mbarrier_txcount_pred() {
     // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
     nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
 
-    %phase = arith.constant 0 : index
+    %phase_c0 = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
     // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
-    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
+    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
 
     func.return 
 }

diff  --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index 29e300a992d3ad..40acd82cd05585 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -62,7 +62,7 @@ func.func @main() {
     //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
     //      CHECK: }
     //
-    //      CHECK: %[[c0_6:.*]] = arith.constant 0 : index
+    //      CHECK: %[[c0_6:.*]] = llvm.mlir.constant(false) : i1 
     //      CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
     //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
index 35ca0ee8677cca..51bcf459d83d5c 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir
@@ -197,7 +197,8 @@ func.func @main() {
     {
       %ticks = arith.constant 10000000 : index
       // TMA wait
-      nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
+      %phase_c0 = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
       %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
       %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
       // Descriptor WGMMA

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
index 5a10bbba26d8cf..85bdb38d67f0fe 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir
@@ -206,7 +206,8 @@ func.func @main() {
     {
       %ticks = arith.constant 10000000 : index
       // TMA wait
-      nvgpu.mbarrier.try_wait.parity %barrier[%i], %c0, %ticks : !barrierType
+      %phase_c0 = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %barrier[%i], %phase_c0, %ticks : !barrierType
       %lhsSlice = memref.subview %lhsShmem [%i, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: ?>, 3>
       %rhsSlice = memref.subview %rhsShmem [%i, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: ?>, 3>
       // Descriptor WGMMA

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
index 9c5aacf96b0d69..b50772f8249fb7 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x64_swizzle128b.mlir
@@ -93,7 +93,8 @@ module @mymod {
       }
 
       // Step 8. Wait until TMA is done
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType
+      %phase_c0 = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase_c0, %c10000000 : !barrierType
 
       // Step 9. Print loaded data in 128b swizzled
       scf.if %10 {        

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
index 536e71d260f568..65e5fc0aff6aa3 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x64_swizzle128b.mlir
@@ -119,7 +119,8 @@ module @mymod {
       }
 
       // Step 7. Wait until TMA is done
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : !barrierType
+      %phase_c0 = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase_c0, %c10000000 : !barrierType
 
       // Step 8. Print loaded data in 128b swizzled
       scf.if %10 {        

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
index aee265e3faf175..2e59b7234e53df 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_64x8_8x128_noswizzle.mlir
@@ -96,7 +96,8 @@ module @mymod {
       } else {
         nvgpu.mbarrier.arrive.expect_tx %9[%c0], %c0 : <memorySpace = #gpu.address_space<workgroup>>
       }
-      nvgpu.mbarrier.try_wait.parity %9[%c0], %c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+      %phase_c0 = arith.constant 0 : i1
+      nvgpu.mbarrier.try_wait.parity %9[%c0], %phase_c0, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
       scf.if %10 {
         %11 = memref.load %7[%c45, %c7] : memref<64x8xf32, 3>
         %12 = memref.load %8[%c7, %c0] : memref<8x128xf32, 3>


        


More information about the Mlir-commits mailing list