[Mlir-commits] [mlir] [mlir][nvvm] Support predicates in `BasicPtxBuilder` (PR #67102)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 22 01:56:19 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-llvm

<details>
<summary>Changes</summary>

This PR extends `BasicPtxBuilder` to support predicates in PTX code generation. The `BasicPtxBuilder` interface was originally introduced to generate PTX code automatically for Ops that LLVM core does not support. Predicates, which LLVM core typically does not support either, are now handled through the same mechanism.

In PTX, an instruction can be guarded by a predicate, as shown below. Here `@p` names the predicate register `p`, which guards the instruction: it executes only when `p` is true.

```
@p ptx.code op1, op2, op3
```

This PR adds a `getPredicate` method to the `BasicPtxBuilder` interface for supplying an optional predicate. When a predicate is provided, the generated instruction is guarded by it; otherwise, no predicate is emitted. Note that the predicate value must always be the last argument in the Op definition.

Additionally, this PR implements predicate usage for the following ops:

- mbarrier.init
- mbarrier.init.shared
- mbarrier.arrive.expect_tx
- mbarrier.arrive.expect_tx.shared
- cp.async.bulk.tensor.shared.cluster.global
- cp.async.bulk.tensor.global.shared.cta

For more details, see the PTX programming model documentation:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions

---

Patch is 27.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67102.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+93-47) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+8-6) 
- (modified) mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp (+9) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+46-12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a528e015523e174..a1cfb305d8d5e50 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -86,6 +86,8 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
 // Basic PTX Builder Interface
 //===----------------------------------------------------------------------===//
 
+def PtxPredicate : Optional<I1>;
+
 // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
 def Read : I32EnumAttrCase<"Read", 0, "read">;
 def Write : I32EnumAttrCase<"Write", 2, "write">;
@@ -118,36 +120,49 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
     is started from the results and they are used as write, followed by the 
     operands and attributes.
 
+    `getPredicate` is an optional function for setting a predicate, which 
+    always returns a `PtxPredicate` value of type i1. If no predicate is 
+    provided, the instruction is unguarded; otherwise, it's guarded by the 
+    predicate value. The `PtxPredicate` value must always be the last argument. 
+    The provided PTX code by `getPtx` should not include the predicate usage.
+    The interface automatically handles predicate usage in the generated
+    PTX code when necessary.
+
     Example:
     If we have following Op definition that returns PTX code by `getPtx`. 
     
     ```tablegen
-      def NVVM_MyOp : NVVM_Op<"myop",
-          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-        Results<(outs LLVM_Type:$res)>,
-        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
-        ...
+      def NVVM_OpCode : NVVM_PTXBuilder_Op<"opcode">,  
+        Arguments<(ins I32:$op1, I32:$op2, PtxPredicate:$predicate)> {
+        let assemblyFormat = "$op1 `,` $op2 (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
         let extraClassDefinition = [{
-          std::string $cppClass::getPtx() { 
-            return std::string("my.ptx.code %0, %1, %2;"); 
-          }
-      } ];
+          std::string $cppClass::getPtx() { return std::string("opcode [%0], %1"); }
+        } ];
+      }
     ```
 
-    The NVVM Op will look like below:
+    The NVVM Op can look like one of these:
     ```mlir
-      %0 = my.ptx.code %1, %2 : !llvm.ptr, i32 -> i32
+      %s1 = nvvm.opcode %1, %2 : i32, i32
+      %s2 = nvvm.opcode %1, %2, predicate = %p : i32, i32, i1
     ```
 
-    The `convert-nvvm-to-llvm` Pass generates the PTX code below. The order of 
-    arguments are kept the same. The read and write modifiers are set based on
-    the input and result types.
+    The `convert-nvvm-to-llvm` Pass generates PTX code with preserved argument 
+    order and sets read and write modifiers based on input and result types.
     ```mlir
-      %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.ptx.code %0, %1, %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+      %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.opcode %0, %1;", "r, r" %0, %1 : (i32, i32)
+      %0 = llvm.inline_asm has_side_effects asm_dialect = att "@%2 my.opcode %0, %1;", "r, r, b" %0, %1, %p : (i32, i32, i1)
     ```
-
   }];
   let methods = [
+    InterfaceMethod<
+        /*desc=*/[{Returns an optional predicate value.}],
+        /*retType=*/"std::optional<::mlir::Value>",
+        /*methodName=*/"getPredicate",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
     InterfaceMethod<
         /*desc=*/[{
           Returns whether the operation has intrinsic support in LLVM.
@@ -211,6 +226,12 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
   ];
 }
 
+/// Base class that defines BasicPtxBuilderOpInterface. 
+class NVVM_PTXBuilder_Op<string mnemonic, 
+  list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
+  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
@@ -334,21 +355,31 @@ def NVVM_ReduxOp :
 //===----------------------------------------------------------------------===//
 
 /// mbarrier.init instruction with generic pointer type
-def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
-  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
+def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
   }];
-  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { if(getPredicate()) return false; return true; }
+  }];
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
+  }];
 }
 
 /// mbarrier.init instruction with shared pointer type
-def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
-  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
+def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
   }];
-  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
+  }];
 }
 
 def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -403,26 +434,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,  
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {    
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,  
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {    
+  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {  
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -441,8 +469,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
   }];
 }
 
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {  
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -596,7 +623,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
 
 def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
 
-def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
   Arguments<(ins LLVM_i8Ptr_shared:$dst,
                  LLVM_i8Ptr_global:$src,
                  I32Attr:$size,
@@ -1467,12 +1494,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_i64ptr_shared:$dstMem,
                   LLVM_i64ptr_any:$tmaDescriptor,
                   LLVM_i64ptr_shared:$mbar,
-                  Variadic<I32>:$coordinates)> {
-  let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       int dim = getCoordinates().size();
@@ -1490,11 +1529,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
   let hasVerifier = 1;
 }
 
-def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 
+  NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_i64ptr_any:$tmaDescriptor,
                   LLVM_i64ptr_shared:$srcMem,
-                  Variadic<I32>:$coordinates)> {
-  let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $tmaDescriptor `,` 
+    $srcMem `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)?  
+    attr-dict  `:` type(operands)
+  }];
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       int dim = getCoordinates().size();
@@ -1516,8 +1565,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -1531,8 +1579,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
   let description = [{
@@ -1545,8 +1592,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
   let arguments = (ins I32Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..1f866039e20e564 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include <optional>
 
 #define DEBUG_TYPE "nvgpu-to-nvvm"
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -793,9 +794,10 @@ struct NVGPUMBarrierInitLowering
 
     if (isMbarrierShared(op.getBarrier().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
-                                                              count);
+                                                              count, Value());
     } else {
-      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count);
+      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+                                                        Value());
     }
     return success();
   }
@@ -886,12 +888,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
 
     if (isMbarrierShared(op.getBarrier().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
-          op, barrier, txcount);
+          op, barrier, txcount, Value());
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
-                                                                txcount);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
+        op, barrier, txcount, Value());
     return success();
   }
 };
@@ -939,7 +941,7 @@ struct NVGPUTmaAsyncLoadOpLowering
     }
 
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
-        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
+        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 2d7a441e950045c..df3b1850e8d343f 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -63,6 +63,8 @@ class PtxBuilder {
 
   // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
   char getRegisterType(Type type) {
+    if (type.isInteger(1))
+      return 'b';
     if (type.isInteger(16))
       return 'h';
     if (type.isInteger(32))
@@ -158,6 +160,13 @@ class PtxBuilder {
         asmConstraints[asmConstraints.size() - 1] == ',')
       asmConstraints.pop_back();
 
+    // Add the predicate to the asm string.
+    if (op.getPredicate().has_value() && op.getPredicate().value()) {
+      std::string predicateStr = "@%";
+      predicateStr += std::to_string((asmVals.size() - 1));
+      asmStr = predicateStr + " " + asmStr;
+    }
+
     // asm keywords expects %, but inline assembly uses $. Replace all % with $
     std::replace(asmStr.begin(), asmStr.end(), '%', '$');
 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 7ffe1ad2bb2b111..228f249db0a0700 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -4,17 +4,30 @@
 // and the generic `convert-to-llvm` pass.
 // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
 
+// CHECK-LABEL: @init_mbarrier
+llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b" 
+  nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1 
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" 
+  nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
+  llvm.return
+}
+
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
   nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+  //CHECK :  llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b "
+  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 
   llvm.return
 }
 
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" 
   nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 
   llvm.return
 }
 
@@ -51,72 +64,93 @@ func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
 }
 
 // CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32) {
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3}], [$2];", "l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32,i1
   return
 }
 
 // CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4}],...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67102

