[Mlir-commits] [mlir] [mlir][nvgpu] Add predicate argument to NVGPU Ops (PR #69322)

Guray Ozen llvmlistbot at llvm.org
Tue Oct 17 04:47:50 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/69322

PR #67102 introduced predication support in BasicPtxBuilderInterface. Predication is available for any NVVM op, just as it is in PTX.

This PR adds an optional predicate argument to the following NVGPU ops. When present, the argument is forwarded to the BasicPtxBuilderInterface during lowering.

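- mbarrier.init
- mbarrier.arrive.expect_tx
- tma.async.load

The assembly format of tma.prefetch.descriptor is also updated so that its existing optional predicate is written as `predicate = %p`, consistent with the ops listed here.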

From 680e5bb7225956e2c780bfa4c8742813fc385d40 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 17 Oct 2023 13:47:20 +0200
Subject: [PATCH] [mlir][nvgpu] Add predicate argument to NVGPU Ops

PR #67102 introduced predication support in BasicPtxBuilderInterface. Predication is available for any NVVM op, just as it is in PTX.

This PR adds an optional predicate argument to the following NVGPU ops. When present, the argument is forwarded to the BasicPtxBuilderInterface during lowering.

- mbarrier.init
- mbarrier.arrive.expect_tx
- tma.async.load
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 22 ++++---
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 13 ++--
 .../NVGPU/TransformOps/NVGPUTransformOps.cpp  |  8 ++-
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 64 ++++++++++++++++++-
 4 files changed, 88 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dd00355b6d77e33..440f7d0380eb17e 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -522,8 +522,8 @@ def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier.init", []> {
       nvgpu.mbarrier.init %barrier, %num_threads : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
     ```
   }];
-  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$count, Index:$mbarId);
-  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count attr-dict `:` type($barriers)";
+  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$count, Index:$mbarId, Optional<I1>:$predicate);
+  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)";
 }
 
 def NVGPU_MBarrierTestWaitOp : NVGPU_Op<"mbarrier.test.wait", []> {
@@ -597,8 +597,8 @@ def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> {
       nvgpu.mbarrier.arrive.expect_tx %barrier, %ic0 : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
     ```
   }];
-  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$txcount, Index:$mbarId);
-  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $txcount  attr-dict `:` type($barriers)";
+  let arguments = (ins NVGPU_MBarrierGroup:$barriers, Index:$txcount, Index:$mbarId, Optional<I1>:$predicate);
+  let assemblyFormat = "$barriers `[` $mbarId `]` `,` $txcount  (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)";
 }
 
 def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
@@ -627,11 +627,11 @@ def NVGPU_TmaPrefetchOp : NVGPU_Op<"tma.prefetch.descriptor", []> {
   }];
   let arguments = (ins NVGPU_TensorMapDescriptor:$tensorMapDescriptor, Optional<I1>:$predicate);
   let assemblyFormat = [{
-    $tensorMapDescriptor (`,` $predicate^)? attr-dict `:` type($tensorMapDescriptor)
+    $tensorMapDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type($tensorMapDescriptor)
   }];
 }
 
-def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
+def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]> {
   let summary = "TMA asynchronous load";
   let description = [{
     The Op loads a tile memory region from global memory to shared memory by 
@@ -646,10 +646,14 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
                        NVGPU_MBarrierGroup:$barriers,
                        NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
                        Variadic<Index>:$coordinates, 
-                       Index:$mbarId);
+                       Index:$mbarId,
+                       Optional<I1>:$predicate);
   let assemblyFormat = [{
-    $tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]` `to` $dst
-      attr-dict `:` type($tensorMapDescriptor) `,` type($barriers) `->` type($dst)
+    $tensorMapDescriptor `[` $coordinates `]` `,` $barriers `[` $mbarId `]` 
+      `to` $dst
+      (`,` `predicate` `=` $predicate^)? 
+      attr-dict `:` type($tensorMapDescriptor) `,` type($barriers) 
+      `->` type($dst)
   }];
   let hasVerifier = 1;
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 7eb6f42d2788e35..efcde2ba58bd685 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -830,11 +830,11 @@ struct NVGPUMBarrierInitLowering
                                    adaptor.getMbarId(), rewriter);
     Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(mbarrierType)) {
-      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
-                                                              count, Value());
+      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
+          op, barrier, count, adaptor.getPredicate());
     } else {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
-                                                        Value());
+                                                        adaptor.getPredicate());
     }
     return success();
   }
@@ -929,12 +929,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
-          op, barrier, txcount, Value());
+          op, barrier, txcount, adaptor.getPredicate());
       return success();
     }
 
     rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
-        op, barrier, txcount, Value());
+        op, barrier, txcount, adaptor.getPredicate());
     return success();
   }
 };
@@ -985,7 +985,8 @@ struct NVGPUTmaAsyncLoadOpLowering
     }
 
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
-        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
+        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords,
+        adaptor.getPredicate());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index eaaadbbea4d0a75..408c1dc798feeb4 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -922,7 +922,7 @@ HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   rewriter.create<nvgpu::MBarrierInitOp>(
       loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
-      zero);
+      zero, Value());
   rewriter.create<gpu::BarrierOp>(loc);
   return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
 }
@@ -964,7 +964,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
   MLIRContext *ctx = rewriter.getContext();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
-      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero);
+      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
+      Value());
   loadOps.push_back(loadOp);
   auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
   SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -989,7 +990,8 @@ void HopperBuilder::buildBarrierArriveTx(
       affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
   Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero);
+  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
+                                                   Value());
 }
 
 void HopperBuilder::buildTryWaitParity(
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index a344578def39e06..c7d28e7443695fc 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -600,6 +600,42 @@ func.func @mbarrier_txcount() {
     func.return 
 }
 
+// CHECK-LABEL: func @mbarrier_txcount_pred
+func.func @mbarrier_txcount_pred() {
+    %mine = arith.constant 1 : index
+    // CHECK: %[[c0:.+]] = arith.constant 0 : index
+    // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
+    // CHECK: %[[S2:.+]] = gpu.thread_id  x
+    // CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
+    %c0 = arith.constant 0 : index  
+    %tidx = gpu.thread_id x
+    %pred = arith.cmpi eq, %tidx, %c0 : index
+
+    // CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier{{.*}} : memref<1xi64, 3>
+    %barrier = nvgpu.mbarrier.create -> !barrierType
+
+    // CHECK: %[[barStr:.+]] =  builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
+    nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
+    
+    %txcount = arith.constant 256 : index
+    // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
+    nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
+
+    %phase = arith.constant 0 : index
+    %ticks = arith.constant 10000000 : index
+    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks : !barrierType
+
+    func.return 
+}
+
 // CHECK-LABEL: func @async_tma_load
 !tensorMap1d = !nvgpu.tensormap.descriptor<tensor = memref<128xf32,3>,         swizzle=none,        l2promo = none,        oob = nan,  interleave = none>
 !tensorMap2d = !nvgpu.tensormap.descriptor<tensor = memref<32x32xf32,3>,       swizzle=swizzle_32b, l2promo = none,        oob = zero, interleave = none>
@@ -630,6 +666,32 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
   func.return 
 }
 
+// CHECK-LABEL: func @async_tma_load_pred
+func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
+                              %buffer1d: memref<128xf32,3>,      
+                              %buffer2d: memref<32x32xf32,3>,    
+                              %buffer3d: memref<2x32x32xf32,3>,  
+                              %buffer4d: memref<2x2x32x32xf32,3>,  
+                              %buffer5d: memref<2x2x2x32x32xf32,3>,
+                              %mbarrier: !mbarrier,
+                              %p: i1) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d, predicate = %p : !tensorMap1d, !mbarrier -> memref<128xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.load %tensorMap2d[%crd0, %crd1], %mbarrier[%c0] to %buffer2d, predicate = %p : !tensorMap2d, !mbarrier -> memref<32x32xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.load %tensorMap3d[%crd0, %crd1, %crd0], %mbarrier[%c0] to %buffer3d, predicate = %p : !tensorMap3d, !mbarrier -> memref<2x32x32xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.load %tensorMap4d[%crd0, %crd1, %crd1, %crd0], %mbarrier[%c0] to %buffer4d, predicate = %p : !tensorMap4d, !mbarrier -> memref<2x2x32x32xf32,3>
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.load %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], %mbarrier[%c0] to %buffer5d, predicate = %p : !tensorMap5d, !mbarrier -> memref<2x2x2x32x32xf32,3>
+  func.return 
+}
+
+
 func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
   %crd0 = arith.constant 64 : index
   %crd1 = arith.constant 128 : index
@@ -650,7 +712,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
   // CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
   nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
   // CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
-  nvgpu.tma.prefetch.descriptor %tensorMap1d, %p: !tensorMap1d
+  nvgpu.tma.prefetch.descriptor %tensorMap1d, predicate = %p: !tensorMap1d
   func.return
 }
 



More information about the Mlir-commits mailing list