[Mlir-commits] [mlir] 6338932 - [mlir][nvvm] Support predicates in `BasicPtxBuilder` (#67102)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 17 03:42:40 PDT 2023


Author: Guray Ozen
Date: 2023-10-17T12:42:36+02:00
New Revision: 63389326f529fd3e3019f8f8afae662e765a3b72

URL: https://github.com/llvm/llvm-project/commit/63389326f529fd3e3019f8f8afae662e765a3b72
DIFF: https://github.com/llvm/llvm-project/commit/63389326f529fd3e3019f8f8afae662e765a3b72.diff

LOG: [mlir][nvvm] Support predicates in `BasicPtxBuilder` (#67102)

This PR enhances `BasicPtxBuilder` to support predicates in PTX code
generation. The `BasicPtxBuilder` interface was initially introduced for
generating PTX code automatically for Ops that aren't supported by LLVM
core. Predicates, which are typically not supported in LLVM core, are
now supported using the same mechanism.

In PTX programming, instructions can be guarded by predicates as shown
below. Here `@p` is a predicate register that guards the execution of the
instruction.

```
@p ptx.code op1, op2, op3
```

This PR introduces the `getPredicate` function in the `BasicPtxBuilder`
interface to set an optional predicate. When a predicate is provided,
the instruction is generated with the predicate and guarded; otherwise,
no predicate is generated. Note that the predicate value must always
appear as the last argument on the Op definition.

Additionally, this PR implements predicate usage for the following ops:

- mbarrier.init
- mbarrier.init.shared
- mbarrier.arrive.expect_tx
- mbarrier.arrive.expect_tx.shared
- cp.async.bulk.tensor.shared.cluster.global
- cp.async.bulk.tensor.global.shared.cta

See the PTX programming model documentation for more detail:

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index 6f27c8eb47175e6..df5a2448bd77968 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -22,6 +22,8 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 // Basic PTX Builder Interface
 //===----------------------------------------------------------------------===//
 
+def PtxPredicate : Optional<I1>;
+
 def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
   let description = [{
     This interface is used to generate inline assembly with PTX for basic 
@@ -62,6 +64,22 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
   }];
   let cppNamespace = "::mlir::NVVM";
   let methods = [
+    InterfaceMethod<
+        /*desc=*/[{
+          Optional function for setting a predicate, which 
+          always returns a `PtxPredicate` value of type i1. If no predicate is 
+          provided, the instruction is unguarded; otherwise, it's guarded by the 
+          predicate value. The `PtxPredicate` value must always be the last argument. 
+          The provided PTX code by `getPtx` should not include the predicate usage.
+          The interface automatically handles predicate usage in the generated
+          PTX code when necessary.
+        }],
+        /*retType=*/"std::optional<::mlir::Value>",
+        /*methodName=*/"getPredicate",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
     InterfaceMethod<
         /*desc=*/[{ Returns PTX assembly with operand number. }],
         /*retType=*/"std::string",

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0a5d1f274a31566..d550fe1f33140ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -74,6 +74,12 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
   LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
 }
 
+/// Base class that defines BasicPtxBuilderOpInterface. 
+class NVVM_PTXBuilder_Op<string mnemonic, 
+  list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
+  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM attribute definitions
 //===----------------------------------------------------------------------===//
@@ -206,21 +212,31 @@ def NVVM_ReduxOp :
 //===----------------------------------------------------------------------===//
 
 /// mbarrier.init instruction with generic pointer type
-def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
-  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
+def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
   }];
-  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { if(getPredicate()) return false; return true; }
+  }];
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
+  }];
 }
 
 /// mbarrier.init instruction with shared pointer type
-def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
-  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
+def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
   }];
-  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
+  }];
 }
 
 def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -275,26 +291,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,  
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {    
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,  
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {    
+  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {  
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -313,8 +326,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
   }];
 }
 
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {  
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -488,7 +500,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
 
 def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
 
-def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
   Arguments<(ins LLVM_i8Ptr_shared:$dst,
                  LLVM_i8Ptr_global:$src,
                  I32Attr:$size,
@@ -1359,12 +1371,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_i64ptr_shared:$dstMem,
                   LLVM_i64ptr_any:$tmaDescriptor,
                   LLVM_i64ptr_shared:$mbar,
-                  Variadic<I32>:$coordinates)> {
-  let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       int dim = getCoordinates().size();
@@ -1382,11 +1406,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
   let hasVerifier = 1;
 }
 
-def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 
+  NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_i64ptr_any:$tmaDescriptor,
                   LLVM_i64ptr_shared:$srcMem,
-                  Variadic<I32>:$coordinates)> {
-  let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $tmaDescriptor `,` 
+    $srcMem `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)?  
+    attr-dict  `:` type(operands)
+  }];
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       int dim = getCoordinates().size();
@@ -1408,8 +1442,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -1423,8 +1456,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
   let description = [{
@@ -1437,8 +1469,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
   let arguments = (ins I32Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2d43230938526b9..00baf7b3c741565 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
+#include <optional>
 
 #define DEBUG_TYPE "nvgpu-to-nvvm"
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -830,9 +831,10 @@ struct NVGPUMBarrierInitLowering
     Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(mbarrierType)) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
-                                                              count);
+                                                              count, Value());
     } else {
-      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count);
+      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+                                                        Value());
     }
     return success();
   }
@@ -927,12 +929,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
-          op, barrier, txcount);
+          op, barrier, txcount, Value());
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
-                                                                txcount);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
+        op, barrier, txcount, Value());
     return success();
   }
 };
@@ -983,7 +985,7 @@ struct NVGPUTmaAsyncLoadOpLowering
     }
 
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
-        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
+        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index fa518cf33428b4c..d1d68e3c9c518c4 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -41,6 +41,7 @@ using namespace mlir;
 using namespace NVVM;
 
 namespace {
+
 struct PtxLowering
     : public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
   using OpInterfaceRewritePattern<

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 121504fc20c018f..f3b674fdb505012 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -123,6 +123,14 @@ LLVM::InlineAsmOp PtxBuilder::build() {
 
   std::string ptxInstruction = interfaceOp.getPtx();
 
+  // Add the predicate to the asm string.
+  if (interfaceOp.getPredicate().has_value() &&
+      interfaceOp.getPredicate().value()) {
+    std::string predicateStr = "@%";
+    predicateStr += std::to_string((ptxOperands.size() - 1));
+    ptxInstruction = predicateStr + " " + ptxInstruction;
+  }
+
   // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
   // Replace all % with $
   std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index c1549f9b9dba528..fcc882f562a4a95 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -4,17 +4,30 @@
 // and the generic `convert-to-llvm` pass.
 // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
 
+// CHECK-LABEL: @init_mbarrier
+llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b" 
+  nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1 
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" 
+  nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
+  llvm.return
+}
+
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"  
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
   nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+  //CHECK :  llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b "
+  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 
   llvm.return
 }
 
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" 
   nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 
   llvm.return
 }
 
@@ -73,82 +86,93 @@ func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
 }
 
 // CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r"
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3}], [$2];", "l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32,i1
   return
 }
 
 // CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r"
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4}], [$2];", "l,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_load_3d
-func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5} ], [$2];", "r,l,r,r,r,r"
+func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5} ], [$2];", "r,l,r,r,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5}], [$2];", "l,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_load_4d
-func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6} ], [$2];", "r,l,r,r,r,r,r"
+func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6} ], [$2];", "r,l,r,r,r,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5}], [$2];", "l,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_load_5d
-func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7} ], [$2];", "r,l,r,r,r,r,r,r"
+func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7} ], [$2];", "r,l,r,r,r,r,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7}], [$2];", "l,r,r,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p  : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_store_1d
-func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
+func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
   nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_store_2d
-func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
+func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
   nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_store_3d
-func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
+func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
   nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_store_4d
-func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
+func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
   nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
   return
 }
 
 // CHECK-LABEL: @tma_store_5d
-func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
+func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
   nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
+
+  // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
   return
 }
 


        


More information about the Mlir-commits mailing list