[Mlir-commits] [mlir] [mlir][nvvm] fix a bug when checking layout that cannot be transposed (PR #97538)

bangyu shen llvmlistbot at llvm.org
Wed Jul 3 07:33:40 PDT 2024


https://github.com/shubaoyu2 updated https://github.com/llvm/llvm-project/pull/97538

>From f2f072920b322857ed3ad623aac30cfbc0e27265 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 16:36:56 +0800
Subject: [PATCH 1/7] fix a bug when checking layout that cannot be transposed

WGMMA expects the layouts of A/B to be row/col, so the transposed layouts are col/row. When verifying that data types other than f16/bf16 do not use a transposed layout, the check should reject col-major for A and row-major for B.
---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 036a9a15af838..48f44165ccc58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -880,7 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   // Check transpose (only available for f16/bf16)
   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
-       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+       getLayoutB() == mlir::NVVM::MMALayout::row)) {
     return emitOpError()
            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
            << " and layout_b = " << stringifyMMALayout(getLayoutB())

>From aec1049febb1893a5cdb57ff7c2981001dbe355d Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:01:35 +0800
Subject: [PATCH 2/7] ch

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 1834 ++++++++------------
 1 file changed, 693 insertions(+), 1141 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 48f44165ccc58..375e2951a037c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1,1142 +1,694 @@
-//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the types and operation details for the NVVM IR dialect in
-// MLIR, and the LLVM IR dialect.  It also registers the dialect.
-//
-// The NVVM dialect only contains GPU specific additions on top of the general
-// LLVM dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-
-#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
-#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Types.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/AsmParser/Parser.h"
-#include "llvm/IR/Attributes.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Type.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include <cassert>
-#include <optional>
-#include <string>
-
-using namespace mlir;
-using namespace NVVM;
-
-#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
-#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
-  p << " " << op->getOperands();
-  if (op->getNumResults() > 0)
-    p << " : " << op->getResultTypes();
+// RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s
+
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @init_mbarrier
+llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b" 
+  nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1 
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" 
+  nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
+  llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
+  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+  //CHECK:  llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
+  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 
+  llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" 
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 
+  llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait_shared
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{
+  // CHECK-SAME: .reg .pred       P1;
+  // CHECK-SAME: LAB_WAIT: 
+  // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
+  // CHECK-SAME: @P1 bra.uni DONE;
+  // CHECK-SAME: bra.uni     LAB_WAIT;
+  // CHECK-SAME: DONE:
+  // CHECK-SAME: }",
+  // CHECK-SAME: "r,r,r"
+   nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+  llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+  // CHECK-SAME: "{
+  // CHECK-SAME: .reg .pred       P1;
+  // CHECK-SAME: LAB_WAIT: 
+  // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
+  // CHECK-SAME: @P1 bra.uni DONE;
+  // CHECK-SAME: bra.uni     LAB_WAIT;
+  // CHECK-SAME: DONE:
+  // CHECK-SAME: }",
+  // CHECK-SAME: "l,r,r"
+  nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
+  llvm.return
+}
+
+// CHECK-LABEL: @async_cp
+func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache =  cg : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg : !llvm.ptr<3>, !llvm.ptr<1>
+  return
+}
+
+// CHECK-LABEL: @async_cp_zfill
+func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", 
+  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", 
+  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+  nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  return
+}
+
+// CHECK-LABEL: @cp_async_mbarrier_arrive
+func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
+  // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}}
+  nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
+  // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true}
+  nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
+  // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}}
+  nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3>
+  // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} {noinc = true}
+  nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3>
+  llvm.return
+}
+
+// CHECK-LABEL: @tma_load_3d_all
+func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_4d_all
+func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_5d_all
+func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_1d
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_2d
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_3d
+func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_4d
+func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_5d
+func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask  predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask  predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask: !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @tma_store_1d
+func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_store_2d
+func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_store_3d
+func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_store_4d
+func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @tma_store_5d
+func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+  // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
+
+  // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
+  return
+}
+
+// CHECK-LABEL: @wgmma_execute
+func.func @wgmma_execute() {  
+  nvvm.wgmma.fence.aligned
+  nvvm.wgmma.commit.group.sync.aligned
+  nvvm.wgmma.wait.group.sync.aligned 0
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
+  // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
+  
+
+  nvvm.wgmma.fence.aligned
+  nvvm.wgmma.commit.group.sync.aligned
+  nvvm.wgmma.wait.group.sync.aligned 5
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
+  // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
+  return
+}
+
+
+// -----
+
+!mat64f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_f16_f16(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{  
+  // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+  // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  
+  // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{
+  // CHECK-SAME: .reg .pred p;
+  // CHECK-SAME: setp.ne.b32 p, $34, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 
+  // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35,  $36, $37,  $38;\0A}\0A", 
+  // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" 
+  // CHECK-SAME: %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] 
+  // CHECK-SAME: : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) 
+  // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+  // CHECK: %[[DESCa:.+]] = llvm.add %[[ARG0]], %[[C2]] : i64
+  // CHECK: %[[DESCb:.+]] = llvm.add %[[ARG1]], %[[C2]] : i64
+  // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  
+  // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{
+    // CHECK-SAME: .reg .pred p;
+    // CHECK-SAME: setp.ne.b32 p, $34, 0;
+    // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 
+    // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35,  $36, $37,  $38;\0A}\0A", 
+    // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" 
+    // CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} 
+  %result = llvm.mlir.undef : !mat64f32
+  %result1 = nvvm.wgmma.mma_async 
+      %descA, %descB, %result,
+      #nvvm.shape<m = 64, n = 32, k = 16>, 
+      D [<f32>, #nvvm.wgmma_scale_out<zero>],
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      :!mat64f32 -> !mat64f32
+  %c2 = arith.constant 2 : i64
+  %descAnext = arith.addi %descA, %c2 : i64
+  %descBnext = arith.addi %descB, %c2 : i64
+  %result2 = nvvm.wgmma.mma_async 
+      %descAnext, %descBnext, %result1,
+      #nvvm.shape<m = 64, n = 32, k = 16>, 
+      D [<f32>, #nvvm.wgmma_scale_out<zero>],
+      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
+      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+      : !mat64f32 -> !mat64f32
+  return %result2 : !mat64f32
+}
+
+// -----
+
+!mat16i32 = !llvm.struct<(i32, i32, i32, i32)>
+
+// Chains three s8 x s8 -> s32 nvvm.wgmma.mma_async ops (m64n8k32, satfinite).
+// The expectations below require each op to become one llvm.inline_asm call
+// whose accumulator inputs are extracted from the previous call's result.
+// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{  
+  %result = llvm.mlir.undef : !mat16i32
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] =  llvm.inline_asm has_side_effects asm_dialect = att 
+// CHECK-SAME: "{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite 
+// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" 
+// CHECK-SAME: %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : 
+// CHECK-SAME: (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
+// CHECK-SAME: "{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite 
+// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", 
+// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" 
+// CHECK-SAME: %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
+// CHECK-SAME:"{
+// CHECK-SAME:.reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
+// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" 
+// CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
+      : !mat16i32 -> !mat16i32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
+      : !mat16i32 -> !mat16i32
+  %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, 
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
+      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
+      : !mat16i32 -> !mat16i32
+  return %result3 : !mat16i32
+}
+
+// Same three-op chain as the s8 test, but with u8 multiplicands and no
+// satfinite modifier. It uses the !mat16i32 alias, which is declared earlier in
+// this split section (there is no `// -----` separator in between).
+// CHECK-LABEL: @wgmma_s32_u8_u8(
+  // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {  
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] =  llvm.inline_asm has_side_effects asm_dialect = att 
+// CHECK-SAME: "{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
+// CHECK-SAME: }\0A",
+// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : 
+// CHECK-SAME:(i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
+// CHECK-SAME:"{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
+// CHECK-SAME: }\0A",
+// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
+// CHECK-SAME:"{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
+// CHECK-SAME:}\0A", 
+// CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
+  %result = llvm.mlir.undef : !mat16i32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [<s32>, #nvvm.wgmma_scale_out<one>],
+      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
+      : !mat16i32 -> !mat16i32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [<s32>, #nvvm.wgmma_scale_out<one>],
+      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
+      : !mat16i32 -> !mat16i32
+  %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
+      #nvvm.shape<m = 64, n = 8, k = 32>, 
+      D [<s32>, #nvvm.wgmma_scale_out<one>],
+      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
+      : !mat16i32 -> !mat16i32
+  return %result3 : !mat16i32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// Two chained tf32 x tf32 -> f32 nvvm.wgmma.mma_async ops (m64n64k8) written
+// with fully spelled-out attributes; each is expected to become an inline-asm
+// call with 32 f32 outputs tied to 32 inputs.
+// CHECK-LABEL: @wgmma_f32_tf32_tf32
+func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {  
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME:"{
+  // CHECK-SAME: .reg .pred p;
+  // CHECK-SAME: setp.ne.b32 p, $66, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{
+  // CHECK-SAME: .reg .pred p;
+  // CHECK-SAME: setp.ne.b32 p, $66, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  %result = llvm.mlir.undef : !mat32f32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+      #nvvm.shape<m = 64, n = 64, k = 8>, 
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+       : !mat32f32 -> !mat32f32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+      #nvvm.shape<m = 64, n = 64, k = 8>, 
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}
+
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// Two chained e4m3 x e4m3 -> f32 nvvm.wgmma.mma_async ops (m64n64k32); each is
+// expected to become an inline-asm call naming the .f32.e4m3.e4m3 variant.
+// CHECK-LABEL: @wgmma_f32_e4m3_e4m3
+func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {  
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  %result = llvm.mlir.undef : !mat32f32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+       : !mat32f32 -> !mat32f32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32, 
+  f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// Mixed-type variant: A is e5m2 and B is e4m3 (m64n64k32, f32 accumulator).
+// Each of the two chained ops is expected to become an inline-asm call naming
+// the .f32.e5m2.e4m3 variant.
+// CHECK-LABEL: @wgmma_f32_e5m2_e4m3
+func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {  
+  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
+  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+  %result = llvm.mlir.undef : !mat32f32
+  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+       : !mat32f32 -> !mat32f32
+  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+      #nvvm.shape<m = 64, n = 64, k = 32>, 
+      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+      A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+      : !mat32f32 -> !mat32f32
+  return %result2 : !mat32f32
+}
+
+// -----
+
+// nvvm.elect.sync is expected to become an inline-asm sequence that yields an
+// i1 predicate. The label line anchors the patterns to this function so they
+// cannot match output produced for a neighbouring test.
+// CHECK-LABEL: @elect_one_leader_sync(
+func.func @elect_one_leader_sync() {  
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
+  // CHECK-SAME: .reg .u32 rx;
+  // CHECK-SAME: .reg .pred px;
+  // CHECK-SAME: mov.pred $0, 0;
+  // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
+  // CHECK-SAME: @px mov.pred $0, 1;
+  // CHECK-SAME: "=b"  : () -> i1
+  %cnd = nvvm.elect.sync -> i1 
+  return 
+}
+
+// -----
+
+// Covers nvvm.stmatrix with 1, 2 and 4 source registers in both layouts: the
+// three `row` ops are expected to produce the plain x1/x2/x4 instructions and
+// the three `col` ops the `.trans` forms, in the same order as the ops below.
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, 
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+  llvm.return 
+}
+
+// -----
+
+// Covers nvvm.prefetch.tensormap in its unpredicated and predicated forms.
+// NOTE(review): the function name mentions mbarrier arrive/expect_tx, but the
+// body only exercises prefetch.tensormap; the name looks like a leftover from a
+// copied test and could be renamed together with its label.
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
+llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
+  nvvm.prefetch.tensormap %desc : !llvm.ptr
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
+  nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
+  llvm.return
+}
+
+// -----
+
+// nvvm.setmaxregister is expected to be left untouched in both its `increase`
+// and `decrease` forms. The label line anchors the patterns to this function.
+// CHECK-LABEL: @set_max_register(
+func.func @set_max_register() {
+  // CHECK: nvvm.setmaxregister increase 232
+  nvvm.setmaxregister increase 232
+
+  // CHECK: nvvm.setmaxregister decrease 40
+  nvvm.setmaxregister decrease 40
+  func.return
+}
+
+// -----
+
+// nvvm.cp.async.bulk.commit.group is expected to be left untouched. The label
+// line anchors the pattern to this function.
+// CHECK-LABEL: @cp_async_bulk_commit(
+func.func @cp_async_bulk_commit() {
+  // CHECK: nvvm.cp.async.bulk.commit.group
+  nvvm.cp.async.bulk.commit.group
+  func.return
+}
+
+// -----
+
+// nvvm.cp.async.bulk.wait_group is expected to be left untouched, with and
+// without the `read` attribute. The label line anchors the patterns to this
+// function.
+// CHECK-LABEL: @cp_async_bulk_wait_group(
+func.func @cp_async_bulk_wait_group() {
+  // CHECK: nvvm.cp.async.bulk.wait_group 1
+  // CHECK: nvvm.cp.async.bulk.wait_group 0
+  // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+  // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
+  nvvm.cp.async.bulk.wait_group 1
+  nvvm.cp.async.bulk.wait_group 0
+  nvvm.cp.async.bulk.wait_group 5 {read}
+  nvvm.cp.async.bulk.wait_group 0 {read}
+  func.return
+}
+
+// -----
+
+// nvvm.fence.mbarrier.init is expected to become a single inline-asm fence.
+// The label line anchors the pattern to this function.
+// CHECK-LABEL: @fence_mbarrier_init(
+func.func @fence_mbarrier_init() {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;"
+  nvvm.fence.mbarrier.init
+  func.return 
+}
+// -----
+
+// Each nvvm.fence.proxy kind (alias, async, async.global, async.shared with
+// cta/cluster space) is expected to become its own inline-asm fence. The label
+// line anchors the patterns to this function.
+// CHECK-LABEL: @fence_proxy(
+func.func @fence_proxy() {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", ""  : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>}
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", ""  : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>}
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", ""  : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>}
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", ""  : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>}
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", ""  : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
+  func.return
+}
+
+// -----
+
+// Covers nvvm.barrier.arrive without and with an explicit barrier id: without
+// one the expected asm uses the literal barrier 0, with one the id is passed as
+// the first asm operand.
+// CHECK-LABEL: @llvm_nvvm_barrier_arrive
+// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
+llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0, $0;", "r" %[[numberOfThreads]] : (i32) -> ()
+  nvvm.barrier.arrive number_of_threads = %numberOfThreads
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
+  nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
+  llvm.return
+ }
-
-// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
-//
-// Reads the operand list, an optional attribute dictionary and the trailing
-// `: type`, records that type as the op's result type, and finally resolves the
-// operands against the fixed signature (i32, i1).
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
-  Type resultType;
-  if (parser.parseOperandList(operands) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(resultType) ||
-      parser.addTypeToList(resultType, result.types))
-    return failure();
-
-  // Operand types are not spelled in the assembly; they are always i32 and i1.
-  MLIRContext *ctx = parser.getContext();
-  auto maskTy = IntegerType::get(ctx, 32);
-  auto predTy = IntegerType::get(ctx, 1);
-  return parser.resolveOperands(operands, {maskTy, predTy},
-                                parser.getNameLoc(), result.operands);
-}
-
-// Printing is delegated entirely to the `printNVVMIntrinsicOp` helper.
-void VoteBallotOp::print(OpAsmPrinter &p) {
-  printNVVMIntrinsicOp(p, *this);
-}
-
-// Verifies the coordinate count (1 to 5) and, when im2col offsets are present,
-// that the tensor is at least 3-dimensional and that there are exactly two
-// fewer offsets than coordinates.
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  const size_t numCoordinates = getCoordinates().size();
-  if (numCoordinates == 0 || numCoordinates > 5)
-    return emitError("expects coordinates between 1 to 5 dimension");
-
-  // Without im2col offsets there is nothing further to check.
-  const size_t numIm2colOffsets = getIm2colOffsets().size();
-  if (numIm2colOffsets == 0)
-    return success();
-
-  if (numCoordinates < 3)
-    return emitError(
-        "to use im2col mode, the tensor has to be at least 3-dimensional");
-  if (numCoordinates != numIm2colOffsets + 2)
-    return emitError(
-        "im2col offsets must be 2 less than number of coordinates");
-  return success();
-}
-
-// Accepts the op as long as it carries no more than five coordinates.
-LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
-  if (getCoordinates().size() <= 5)
-    return success();
-  return emitError("Maximum 5 coordinates and dimension is supported.");
-}
-
-// Verifies the cache modifier / copy size combination:
-//   * only the CG and CA modifiers are accepted,
-//   * the size must be 4, 8 or 16 bytes,
-//   * CG additionally requires a 16-byte copy.
-LogicalResult CpAsyncOp::verify() {
-  const auto modifier = getModifier();
-  const auto byteSize = getSize();
-  const bool isCG = modifier == LoadCacheModifierKind::CG;
-  const bool isCA = modifier == LoadCacheModifierKind::CA;
-
-  if (!isCG && !isCA)
-    return emitError("Only CG and CA cache modifiers are supported.");
-  if (!(byteSize == 4 || byteSize == 8 || byteSize == 16))
-    return emitError("expected byte size to be either 4, 8 or 16.");
-  if (isCG && byteSize != 16)
-    return emitError("CG cache modifier is only support for 16 bytes copy.");
-  return success();
-}
-
-// Maps an operand's element type to the matching PTX type (`NVVM::MMATypes`).
-//
-//   f64                      -> f64
-//   f16 or vector<2xf16>     -> f16
-//   f32                      -> f32 for an accumulator, tf32 otherwise
-//   any integer              -> s32 for an accumulator, not inferrable otherwise
-//   LLVM struct              -> whatever its first member maps to
-//
-// Returns std::nullopt when no PTX type can be inferred, which includes an
-// empty struct and every type not listed above.
-std::optional<mlir::NVVM::MMATypes>
-MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
-  if (operandElType.isF64())
-    return NVVM::MMATypes::f64;
-
-  auto f16x2Ty =
-      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
-  if (operandElType.isF16() || operandElType == f16x2Ty)
-    return NVVM::MMATypes::f16;
-
-  if (operandElType.isF32())
-    return isAccumulator ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32;
-
-  if (llvm::isa<IntegerType>(operandElType)) {
-    // The width of an integer multiplicand cannot be told from its type.
-    if (!isAccumulator)
-      return std::nullopt;
-    return NVVM::MMATypes::s32;
-  }
-
-  auto structTy = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType);
-  if (!structTy || structTy.getBody().empty())
-    return std::nullopt;
-  return inferOperandMMAType(structTy.getBody()[0], isAccumulator);
-}
-
-// True for the two 4-bit integer PTX types (u4, s4).
-static bool isInt4PtxType(MMATypes type) {
-  switch (type) {
-  case MMATypes::u4:
-  case MMATypes::s4:
-    return true;
-  default:
-    return false;
-  }
-}
-
-// True for the two 8-bit integer PTX types (u8, s8).
-static bool isInt8PtxType(MMATypes type) {
-  switch (type) {
-  case MMATypes::u8:
-  case MMATypes::s8:
-    return true;
-  default:
-    return false;
-  }
-}
-
-// True for every integer-like PTX type: b1, s32 and the 4-/8-bit integers.
-static bool isIntegerPtxType(MMATypes type) {
-  if (type == MMATypes::b1 || type == MMATypes::s32)
-    return true;
-  return isInt4PtxType(type) || isInt8PtxType(type);
-}
-
-// PTX type of the accumulator, inferred from the type of the first operand in
-// ODS operand group 2.
-MMATypes MmaOp::accumPtxType() {
-  Type accumulatorTy = getODSOperands(2).getTypes().front();
-  std::optional<mlir::NVVM::MMATypes> inferred =
-      inferOperandMMAType(accumulatorTy, /*isAccum=*/true);
-  assert(inferred.has_value() &&
-         "accumulator PTX type should always be inferrable");
-  return inferred.value();
-}
-
-// PTX type of the result, inferred from the result type as an accumulator.
-MMATypes MmaOp::resultPtxType() {
-  Type resultTy = getResult().getType();
-  std::optional<mlir::NVVM::MMATypes> inferred =
-      inferOperandMMAType(resultTy, /*isAccum=*/true);
-  assert(inferred.has_value() && "result PTX type should always be inferrable");
-  return inferred.value();
-}
-
-// Prints the op in its custom assembly form:
-//
-//   A[...] B[...] C[...] attr-dict
-//     : (type(A[0]), type(B[0]), type(C[0])) -> type(res)
-//
-// The operand-segment-size attribute is never printed. The PTX type attribute
-// of a multiplicand (A or B) is omitted whenever `inferOperandMMAType` can
-// derive a type from that multiplicand's own register type.
-void MmaOp::print(OpAsmPrinter &p) {
-  struct OperandFragment {
-    StringRef operandName;
-    StringRef ptxTypeAttr;
-    SmallVector<Value, 4> regs;
-    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
-        : operandName(name), ptxTypeAttr(ptxTypeName) {}
-  };
-
-  std::array<OperandFragment, 3> frags{
-      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
-      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
-      OperandFragment("C", "")};
-  SmallVector<StringRef, 4> ignoreAttrNames{
-      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
-
-  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
-    auto &frag = frags[fragIdx];
-    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
-    for (auto operandIdx = varOperandSpec.first;
-         operandIdx < varOperandSpec.first + varOperandSpec.second;
-         operandIdx++)
-      frag.regs.push_back(this->getOperand(operandIdx));
-
-    // C has no PTX type attribute, and an empty segment offers no type to
-    // infer from.
-    if (frag.ptxTypeAttr.empty() || frag.regs.empty())
-      continue;
-
-    // Decide per fragment, from the fragment's own first register. The type
-    // used to be recorded only for operand index 0, so B's attribute was
-    // elided based on A's register type.
-    std::optional<MMATypes> inferredType =
-        inferOperandMMAType(frag.regs.front().getType(), /*isAccum=*/false);
-    if (inferredType)
-      ignoreAttrNames.push_back(frag.ptxTypeAttr);
-  }
-
-  for (const auto &frag : frags) {
-    p << " " << frag.operandName;
-    p << "[";
-    p.printOperands(frag.regs);
-    p << "] ";
-  }
-
-  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
-
-  // Print the types of the operands and result. One type is printed per
-  // segment, taken from its first register; like the code before this change,
-  // this assumes that no segment is empty.
-  p << " : " << "(";
-  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
-                                             frags[1].regs[0].getType(),
-                                             frags[2].regs[0].getType()},
-                        p);
-  p << ")";
-  p.printArrowTypeList(TypeRange{this->getRes().getType()});
-}
-
-// Populates `result` for an MmaOp.
-//
-//   shape                 must hold exactly (m, n, k).
-//   multiplicandPtxTypes  explicit PTX types for A and B; when absent each one
-//                         is inferred from the first register of its segment
-//                         and the attribute is skipped if nothing can be
-//                         inferred.
-//   multiplicandLayouts   explicit layouts for A and B; defaults to (row, col).
-//   intOverflow, b1Op     attached only when provided.
-//
-// The operand segment sizes are recorded from the sizes of the three ranges.
-void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
-                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
-                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
-                  std::optional<MMAIntOverflow> intOverflow,
-                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
-                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
-
-  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
-  MLIRContext *ctx = builder.getContext();
-  result.addAttribute(
-      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
-
-  for (ValueRange segment : {operandA, operandB, operandC})
-    result.addOperands(segment);
-
-  // Explicit PTX types win; otherwise fall back to inference.
-  std::optional<MMATypes> ptxTypeA;
-  std::optional<MMATypes> ptxTypeB;
-  if (multiplicandPtxTypes) {
-    ptxTypeA = (*multiplicandPtxTypes)[0];
-    ptxTypeB = (*multiplicandPtxTypes)[1];
-  } else {
-    ptxTypeA = inferOperandMMAType(operandA[0].getType(), false);
-    ptxTypeB = inferOperandMMAType(operandB[0].getType(), false);
-  }
-  if (ptxTypeA)
-    result.addAttribute("multiplicandAPtxType",
-                        MMATypesAttr::get(ctx, *ptxTypeA));
-  if (ptxTypeB)
-    result.addAttribute("multiplicandBPtxType",
-                        MMATypesAttr::get(ctx, *ptxTypeB));
-
-  const std::array<MMALayout, 2> layouts = multiplicandLayouts.value_or(
-      std::array<MMALayout, 2>{MMALayout::row, MMALayout::col});
-  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, layouts[0]));
-  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, layouts[1]));
-
-  if (intOverflow)
-    result.addAttribute("intOverflowBehavior",
-                        MMAIntOverflowAttr::get(ctx, *intOverflow));
-  if (b1Op)
-    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
-
-  result.addTypes(resultType);
-
-  const auto sizeOf = [](ValueRange segment) {
-    return static_cast<int32_t>(segment.size());
-  };
-  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr({sizeOf(operandA),
-                                                    sizeOf(operandB),
-                                                    sizeOf(operandC)}));
-}
-
-// <operation> :=
-//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
-//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
-//     `->` type($res)
-//
-// Every register of a segment gets the single type listed for that segment.
-// When `multiplicandAPtxType` / `multiplicandBPtxType` is not written
-// explicitly it is inferred from the segment's register type; parsing fails if
-// it is neither given nor inferrable. The operand segment sizes are recorded
-// from the number of registers parsed for A, B and C.
-ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
-  struct OperandFragment {
-    std::optional<MMATypes> elemtype;
-    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
-    SmallVector<Type> regTypes;
-  };
-
-  Builder &builder = parser.getBuilder();
-  const std::array<StringRef, 3> segmentNames{"A", "B", "C"};
-  std::array<OperandFragment, 3> frags;
-
-  NamedAttrList namedAttributes;
-
-  // Parse the operand segments: keyword followed by an optionally
-  // square-delimited operand list.
-  for (unsigned idx = 0; idx < frags.size(); idx++) {
-    if (parser.parseKeyword(segmentNames[idx]) ||
-        parser.parseOperandList(frags[idx].regs,
-                                OpAsmParser::Delimiter::OptionalSquare))
-      return failure();
-  }
-
-  if (parser.parseOptionalAttrDict(namedAttributes).failed())
-    return failure();
-
-  // Parse the type specification and resolve operands.
-  SmallVector<Type, 3> operandTypes;
-  if (failed(parser.parseColon()))
-    return failure();
-  if (failed(parser.parseLParen()))
-    return failure();
-  if (failed(parser.parseTypeList(operandTypes)))
-    return failure();
-  // A missing `)` is an error in its own right. Without this return the type
-  // count check below became the body of the failed-paren branch and was
-  // skipped whenever the paren parsed correctly.
-  if (failed(parser.parseRParen()))
-    return failure();
-  if (operandTypes.size() != 3)
-    return parser.emitError(
-        parser.getNameLoc(),
-        "expected one type for each operand segment but got " +
-            Twine(operandTypes.size()) + " types");
-  for (const auto &iter : llvm::enumerate(operandTypes)) {
-    auto &frag = frags[iter.index()];
-    // The register type is read below, so an empty segment must be rejected
-    // rather than indexed.
-    if (frag.regs.empty())
-      return parser.emitError(parser.getNameLoc(),
-                              "expected at least one operand in segment " +
-                                  segmentNames[iter.index()]);
-    frag.regTypes.resize(frag.regs.size(), iter.value());
-    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
-                                      parser.getNameLoc(), result.operands)))
-      return failure();
-    // Only C (index 2) is an accumulator; A and B are multiplicands. This
-    // matches how `print` and `build` call inferOperandMMAType.
-    frag.elemtype =
-        inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() >= 2);
-  }
-
-  Type resultType;
-  if (parser.parseArrow() || parser.parseType(resultType))
-    return failure();
-
-  std::array<StringRef, 2> names{"multiplicandAPtxType",
-                                 "multiplicandBPtxType"};
-  for (unsigned idx = 0; idx < names.size(); idx++) {
-    const auto &frag = frags[idx];
-    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
-    if (!frag.elemtype.has_value() && !attr.has_value()) {
-      return parser.emitError(
-          parser.getNameLoc(),
-          "attribute " + names[idx] +
-              " is not provided explicitly and cannot be inferred");
-    }
-    // An explicitly written attribute always takes precedence over inference.
-    if (!attr.has_value())
-      result.addAttribute(
-          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
-  }
-
-  result.addTypes(resultType);
-  if (!namedAttributes.empty())
-    result.addAttributes(namedAttributes);
-  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr({
-                          static_cast<int32_t>(frags[0].regs.size()),
-                          static_cast<int32_t>(frags[1].regs.size()),
-                          static_cast<int32_t>(frags[2].regs.size()),
-                      }));
-  return success();
-}
-
-LogicalResult MmaOp::verify() {
-  MLIRContext *context = getContext();
-  auto f16Ty = Float16Type::get(context);
-  auto i32Ty = IntegerType::get(context, 32);
-  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
-  auto f32Ty = Float32Type::get(context);
-  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
-      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
-
-  auto s32x4StructTy =
-      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
-  auto f32x8StructTy =
-      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
-  auto f16x2x2StructTy =
-      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
-  auto f32x4StructTy =
-      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
-  auto s32x2StructTy =
-      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
-
-  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
-                                  getShapeAttr().getK()};
-
-  // These variables define the set of allowed data types for matrices A, B, C,
-  // and result.
-  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
-  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
-  AllowedShapes allowedShapes;
-  AllowedTypes expectedA;
-  AllowedTypes expectedB;
-  AllowedTypes expectedC;
-  SmallVector<Type> expectedResult;
-
-  // When M = 16, we just need to calculate the number of 8xk tiles, where
-  // k is a factor that depends on the data type.
-  if (mmaShape[0] == 16) {
-    int64_t kFactor;
-    Type multiplicandFragType;
-    switch (*getMultiplicandAPtxType()) {
-    case MMATypes::tf32:
-      kFactor = 4;
-      multiplicandFragType = i32Ty;
-      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
-          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
-      break;
-    case MMATypes::f16:
-    case MMATypes::bf16:
-      kFactor = 8;
-      multiplicandFragType = f16x2Ty;
-      expectedResult.push_back(f16x2x2StructTy);
-      expectedResult.push_back(f32x4StructTy);
-      break;
-    case MMATypes::s4:
-    case MMATypes::u4:
-      kFactor = 32;
-      break;
-    case MMATypes::b1:
-      kFactor = 128;
-      break;
-    case MMATypes::s8:
-    case MMATypes::u8:
-      kFactor = 16;
-      break;
-    default:
-      return emitError("invalid shape or multiplicand type: " +
-                       stringifyEnum(getMultiplicandAPtxType().value()));
-    }
-
-    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
-      expectedResult.push_back(s32x4StructTy);
-      expectedC.emplace_back(4, i32Ty);
-      multiplicandFragType = i32Ty;
-    } else {
-      expectedC.emplace_back(2, f16x2Ty);
-      expectedC.emplace_back(4, f32Ty);
-    }
-
-    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
-    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
-    expectedA.emplace_back(unitA, multiplicandFragType);
-    expectedB.emplace_back(unitB, multiplicandFragType);
-    allowedShapes.push_back({16, 8, kFactor});
-    allowedShapes.push_back({16, 8, kFactor * 2});
-  }
-
-  // In the M=8 case, there is only 1 possible case per data type.
-  if (mmaShape[0] == 8) {
-    if (*getMultiplicandAPtxType() == MMATypes::f16) {
-      expectedA.emplace_back(2, f16x2Ty);
-      expectedB.emplace_back(2, f16x2Ty);
-      expectedResult.push_back(f16x2x4StructTy);
-      expectedResult.push_back(f32x8StructTy);
-      expectedC.emplace_back(4, f16x2Ty);
-      expectedC.emplace_back(8, f32Ty);
-      allowedShapes.push_back({8, 8, 4});
-    }
-    if (*getMultiplicandAPtxType() == MMATypes::f64) {
-      Type f64Ty = Float64Type::get(context);
-      expectedA.emplace_back(1, f64Ty);
-      expectedB.emplace_back(1, f64Ty);
-      expectedC.emplace_back(2, f64Ty);
-      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
-      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
-          context, SmallVector<Type>(2, f64Ty)));
-      allowedShapes.push_back({8, 8, 4});
-    }
-    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
-      expectedA.push_back({i32Ty});
-      expectedB.push_back({i32Ty});
-      expectedC.push_back({i32Ty, i32Ty});
-      expectedResult.push_back(s32x2StructTy);
-      if (isInt4PtxType(getMultiplicandAPtxType().value()))
-        allowedShapes.push_back({8, 8, 32});
-      if (isInt8PtxType(getMultiplicandAPtxType().value()))
-        allowedShapes.push_back({8, 8, 16});
-      if (getMultiplicandAPtxType().value() == MMATypes::b1)
-        allowedShapes.push_back({8, 8, 128});
-    }
-  }
-
-  std::string errorMessage;
-  llvm::raw_string_ostream errorStream(errorMessage);
-
-  // Check that we matched an existing shape/dtype combination.
-  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
-      !llvm::is_contained(allowedShapes, mmaShape)) {
-    errorStream << "unimplemented variant for MMA shape <";
-    llvm::interleaveComma(mmaShape, errorStream);
-    errorStream << ">";
-    return emitOpError(errorMessage);
-  }
-
-  // Verify the operand types for segments of A, B, and C operands.
-  std::array<StringRef, 3> operandNames{"A", "B", "C"};
-  for (const auto &iter : llvm::enumerate(
-           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
-    auto spec = this->getODSOperandIndexAndLength(iter.index());
-    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
-                                      operand_type_begin() + spec.first +
-                                          spec.second);
-    bool match = llvm::is_contained(iter.value(), operandTySeg);
-
-    if (!match) {
-      errorStream << "Could not match types for the "
-                  << operandNames[iter.index()]
-                  << " operands; expected one of ";
-      for (const auto &x : iter.value()) {
-        errorStream << x.size() << "x" << x[0] << " ";
-      }
-      errorStream << "but got ";
-      llvm::interleaveComma(operandTySeg, errorStream);
-      return emitOpError(errorStream.str());
-    }
-  }
-
-  // Check the result type
-  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
-        return expectedResultType == getResult().getType();
-      })) {
-    errorStream
-        << "Could not match allowed types for the result; expected one of ";
-    llvm::interleaveComma(expectedResult, errorStream);
-    errorStream << " but got " << getResult().getType();
-    return emitOpError(errorStream.str());
-  }
-
-  // Ensure that binary MMA variants have a b1 MMA operation defined.
-  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
-    return emitOpError("op requires " + getB1OpAttrName().strref() +
-                       " attribute");
-  }
-
-  // Ensure int4/int8 MMA variants specify the accum overflow behavior
-  // attribute.
-  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
-      isInt8PtxType(*getMultiplicandAPtxType())) {
-    if (!getIntOverflowBehavior())
-      return emitOpError("op requires " +
-                         getIntOverflowBehaviorAttrName().strref() +
-                         " attribute");
-  }
-
-  return success();
-}
-
-LogicalResult ShflOp::verify() {
-  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
-    return success();
-  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
-  auto elementType = (type && type.getBody().size() == 2)
-                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
-                         : nullptr;
-  if (!elementType || elementType.getWidth() != 1)
-    return emitError("expected return type to be a two-element struct with "
-                     "i1 as the second element");
-  return success();
-}
-
-std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
-                                                   NVVM::MMAFrag frag, int nRow,
-                                                   int nCol,
-                                                   MLIRContext *context) {
-  unsigned numberElements = 0;
-  Type elementType;
-  OpBuilder builder(context);
-  Type f16x2 = VectorType::get(2, builder.getF16Type());
-  if (type == NVVM::MMATypes::f16) {
-    elementType = f16x2;
-    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
-      numberElements = 8;
-    else
-      numberElements = 4;
-  } else if (type == NVVM::MMATypes::f32) {
-    elementType = builder.getF32Type();
-    numberElements = 8;
-  } else if (type == NVVM::MMATypes::tf32) {
-    elementType = builder.getI32Type();
-    numberElements = 4;
-  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
-    elementType = builder.getI32Type();
-    int parallelSize = 0;
-    if (frag == NVVM::MMAFrag::a)
-      parallelSize = nRow;
-    if (frag == NVVM::MMAFrag::b)
-      parallelSize = nCol;
-
-    // m == 16 && n == 16 && k == 16
-    if (parallelSize == 16)
-      numberElements = 2;
-    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
-    else if (parallelSize == 8)
-      numberElements = 1;
-    else if (parallelSize == 32)
-      numberElements = 4;
-  } else if (type == NVVM::MMATypes::s32) {
-    elementType = builder.getI32Type();
-    numberElements = 8;
-  }
-  assert(numberElements != 0 && elementType != nullptr);
-  return std::make_pair(elementType, numberElements);
-}
-
-static std::pair<mlir::Type, unsigned>
-inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
-                    int k, MLIRContext *context) {
-  int nRow, nCol;
-  if (frag == NVVM::MMAFrag::a) {
-    nRow = m;
-    nCol = k;
-  } else if (frag == NVVM::MMAFrag::b) {
-    nRow = k;
-    nCol = n;
-  } else {
-    nRow = m;
-    nCol = n;
-  }
-  assert(nRow && nCol);
-  return inferMMAType(type, frag, nRow, nCol, context);
-}
-
-LogicalResult NVVM::WMMALoadOp::verify() {
-  unsigned addressSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
-      addressSpace != NVVM::kSharedMemorySpace)
-    return emitOpError("expected source pointer in memory "
-                       "space 0, 1, 3");
-
-  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
-                                       getEltype(), getFrag()) == 0)
-    return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
-      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
-  Type dstType = LLVM::LLVMStructType::getLiteral(
-      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
-  if (getType() != dstType)
-    return emitOpError("expected destination type is a structure of ")
-           << typeInfo.second << " elements of type " << typeInfo.first;
-  return success();
-}
-
-LogicalResult NVVM::WMMAStoreOp::verify() {
-  unsigned addressSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
-      addressSpace != NVVM::kSharedMemorySpace)
-    return emitOpError("expected operands to be a source pointer in memory "
-                       "space 0, 1, 3");
-
-  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
-                                        getEltype()) == 0)
-    return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
-      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
-  if (getArgs().size() != typeInfo.second)
-    return emitOpError() << "expected " << typeInfo.second << " data operands";
-  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
-        return operands.getType() != typeInfo.first;
-      }))
-    return emitOpError() << "expected data operands of type " << typeInfo.first;
-  return success();
-}
-
-LogicalResult NVVM::WMMAMmaOp::verify() {
-  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
-                                      getLayoutB(), getEltypeA(),
-                                      getEltypeB()) == 0)
-    return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
-      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
-  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
-      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
-  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
-      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
-  SmallVector<Type, 32> arguments;
-  arguments.append(typeInfoA.second, typeInfoA.first);
-  arguments.append(typeInfoB.second, typeInfoB.first);
-  arguments.append(typeInfoC.second, typeInfoC.first);
-  unsigned numArgs = arguments.size();
-  if (getArgs().size() != numArgs)
-    return emitOpError() << "expected " << numArgs << " arguments";
-  for (unsigned i = 0; i < numArgs; i++) {
-    if (getArgs()[i].getType() != arguments[i])
-      return emitOpError() << "expected argument " << i << " to be of type "
-                           << arguments[i];
-  }
-  Type dstType = LLVM::LLVMStructType::getLiteral(
-      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
-  if (getType() != dstType)
-    return emitOpError("expected destination type is a structure of ")
-           << typeInfoC.second << " elements of type " << typeInfoC.first;
-  return success();
-}
-
-LogicalResult NVVM::LdMatrixOp::verify() {
-  unsigned addressSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != NVVM::kSharedMemorySpace)
-    return emitOpError("expected source pointer in memory space 3");
-
-  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
-    return emitOpError("expected num attribute to be 1, 2 or 4");
-
-  Type i32 = IntegerType::get(getContext(), 32);
-  if (getNum() == 1 && getType() != i32)
-    return emitOpError("expected destination type is i32");
-  if (getNum() == 2 || getNum() == 4) {
-    Type dstType = LLVM::LLVMStructType::getLiteral(
-        getContext(), SmallVector<Type>(getNum(), i32));
-    if (getType() != dstType)
-      return emitOpError("expected destination type is a structure of ")
-             << getNum() << " elements of type i32";
-  }
-  return success();
-}
-
-LogicalResult NVVM::StMatrixOp::verify() {
-  unsigned addressSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  if (addressSpace != NVVM::kSharedMemorySpace)
-    return emitOpError("expected source pointer in memory space 3");
-
-  int numMatrix = getSources().size();
-  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
-    return emitOpError("expected num attribute to be 1, 2 or 4");
-
-  return success();
-}
-
-FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
-  if (typeA == NVVM::WGMMATypes::tf32)
-    return 8;
-  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
-    return 16;
-  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
-    return 32;
-  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
-    return 32;
-  if (typeA == NVVM::WGMMATypes::b1)
-    return 256;
-  return failure();
-}
-
-LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
-                                     NVVM::WGMMATypes typeA,
-                                     NVVM::WGMMATypes typeB) {
-  switch (typeA) {
-  case NVVM::WGMMATypes::f16:
-    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
-        typeB == NVVM::WGMMATypes::f16)
-      return success();
-    break;
-  case NVVM::WGMMATypes::tf32:
-    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
-      return success();
-    break;
-  case NVVM::WGMMATypes::u8:
-  case NVVM::WGMMATypes::s8:
-    if (typeD == NVVM::WGMMATypes::s32 &&
-        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
-      return success();
-    break;
-  case NVVM::WGMMATypes::b1:
-    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
-      return success();
-    break;
-  case NVVM::WGMMATypes::bf16:
-    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
-        typeB == NVVM::WGMMATypes::bf16)
-      return success();
-    break;
-  case NVVM::WGMMATypes::e4m3:
-  case NVVM::WGMMATypes::e5m2:
-    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
-        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
-      return success();
-    break;
-  case WGMMATypes::f32:
-  case WGMMATypes::s32:
-    llvm_unreachable("unsupported input types");
-    break;
-  }
-  return failure();
-}
-
-LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
-  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
-                               72,  80,  88,  96,  104, 112, 120, 128,
-                               136, 144, 152, 160, 168, 176, 184, 192,
-                               200, 208, 216, 224, 232, 240, 248, 256};
-  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
-                                    80,  96,  112, 128, 144, 160,
-                                    176, 192, 208, 224, 240, 256};
-  switch (typeA) {
-  case WGMMATypes::f16:
-  case WGMMATypes::tf32:
-  case WGMMATypes::bf16:
-  case WGMMATypes::e4m3:
-  case WGMMATypes::e5m2:
-    if (llvm::is_contained(allowedN, sizeN))
-      return success();
-    break;
-  case WGMMATypes::u8:
-  case WGMMATypes::s8:
-  case WGMMATypes::b1:
-    if (llvm::is_contained(allowedNshort, sizeN))
-      return success();
-    break;
-  case WGMMATypes::f32:
-  case WGMMATypes::s32:
-    llvm_unreachable("unsupported input types");
-    break;
-  }
-  return failure();
-}
-
-LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
-  Value outValue = getResults();
-  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
-  if (!stype)
-    return emitOpError() << "expected results to be struct";
-  int outputSize = stype.getBody().size();
-  WGMMATypes typeD = getTypeD();
-  WGMMATypes typeA = getTypeA();
-  WGMMATypes typeB = getTypeB();
-
-  for (Type t : stype.getBody()) {
-    if (t != stype.getBody().front())
-      return emitOpError()
-             << "all elements in struct must be same type but there is " << t;
-  }
-
-  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
-      typeD != WGMMATypes::s32) {
-    return emitOpError() << "does not support the given output type "
-                         << NVVM::stringifyWGMMATypes(typeD);
-  }
-  if (typeD == WGMMATypes::s32 &&
-      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
-    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
-  }
-
-  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
-    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
-                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
-                         << NVVM::stringifyWGMMATypes(typeB)
-                         << ", it is not supported.";
-  }
-
-  // Check M
-  if (getShape().getM() != 64)
-    return emitOpError() << "shape 'm' must be 64";
-
-  // Check K
-  FailureOr<int> allowedK = getAllowedSizeK(typeA);
-  if (failed(allowedK) || allowedK.value() != getShape().getK())
-    return emitOpError() << "shape 'k' must be " << allowedK.value()
-                         << " for input type "
-                         << NVVM::stringifyWGMMATypes(typeA);
-
-  // Check N
-  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
-    return emitOpError() << "has input type "
-                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
-                         << getShape().getN() << ", it is not supported.";
-  }
-
-  // Check transpose (only available for f16/bf16)
-  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
-      (getLayoutA() == mlir::NVVM::MMALayout::col ||
-       getLayoutB() == mlir::NVVM::MMALayout::row)) {
-    return emitOpError()
-           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
-           << " and layout_b = " << stringifyMMALayout(getLayoutB())
-           << " for input types " << stringifyWGMMATypes(typeA) << " and "
-           << stringifyWGMMATypes(typeB)
-           << " requires transpose. However, this is only supported for: "
-           << stringifyMMATypes(MMATypes::f16) << " and "
-           << stringifyMMATypes(MMATypes::bf16);
-  }
-
-  // Check result registers
-  int expectedOutput = 0;
-  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
-    expectedOutput = getShape().getN() / 2;
-  if (typeD == WGMMATypes::f16)
-    expectedOutput = getShape().getN() / 4;
-  if (outputSize != expectedOutput) {
-    return emitOpError() << "results " << expectedOutput
-                         << ", however output struct has " << outputSize
-                         << " elements";
-  }
-  // Check satfinite (only available for s32 accumulator)
-  if (typeD != WGMMATypes::s32 &&
-      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
-          NVVM::MMAIntOverflow::satfinite) {
-    return emitOpError()
-           << " `satfinite` can be only used with s32 accumulator, however "
-              "the current accumulator is "
-           << NVVM::stringifyWGMMATypes(typeD);
-  }
-
-  return success();
-}
-
-std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
-
-  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
-  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
-
-  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
-
-  int expectedOutputRegisters = 0;
-  if (getTypeD() == WGMMATypes::f16)
-    expectedOutputRegisters = getShape().getN() / 4;
-  else
-    expectedOutputRegisters = getShape().getN() / 2;
-
-  std::string ptx;
-  llvm::raw_string_ostream ss(ptx);
-
-  ss << "{\n"
-        ".reg .pred p;\n"
-        "setp.ne.b32 p, $"
-     << ((expectedOutputRegisters * 2) + 2)
-     << ", 0;\n"
-        "wgmma.mma_async.sync.aligned.m"
-     << m << "n" << n << "k" << k << "." << outputTypeName << "."
-     << stringifyWGMMATypes(getTypeA()) << "."
-     << stringifyWGMMATypes(getTypeB());
-  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
-      NVVM::MMAIntOverflow::satfinite)
-    ss << ".satfinite";
-  ss << " {";
-  int regCnt = 0;
-  for (; regCnt < expectedOutputRegisters; ++regCnt) {
-    ss << "$" << regCnt;
-    if (regCnt != expectedOutputRegisters - 1)
-      ss << ", ";
-  }
-
-  ss << "},";
-  // Need to map read/write registers correctly.
-  regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
-  if (getTypeD() != WGMMATypes::s32) {
-    ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
-  }
-  // Don't add transpose parameters unless needed.
-  if (isF16) {
-    ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
-  }
-  ss << ";\n"
-     << "}\n";
-  ss.flush();
-  return ptx;
-}
-
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
-    RewriterBase &rewriter,
-    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
-        &asmValues) {
-  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
-  if (getResults())
-    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
-  if (getInouts())
-    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
-  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
-  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
-  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
-                       mlir::NVVM::PTXRegisterMod::Read});
-  if (getTypeD() != WGMMATypes::s32) {
-    asmValues.push_back(
-        {makeConstantI32(rewriter,
-                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
-         mlir::NVVM::PTXRegisterMod::Read});
-    asmValues.push_back(
-        {makeConstantI32(rewriter,
-                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
-         mlir::NVVM::PTXRegisterMod::Read});
-  }
-  if (isF16) {
-    asmValues.push_back(
-        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
-         mlir::NVVM::PTXRegisterMod::Read});
-    asmValues.push_back(
-        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
-         mlir::NVVM::PTXRegisterMod::Read});
-  }
-}
-LogicalResult NVVM::FenceProxyOp::verify() {
-  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
-    return emitOpError() << "async_shared fence requires space attribute";
-  }
-  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
-    return emitOpError() << "only async_shared fence can have space attribute";
-  }
-  return success();
-}
-
-LogicalResult NVVM::SetMaxRegisterOp::verify() {
-  if (getRegCount() % 8)
-    return emitOpError("new register size must be multiple of 8");
-  if (getRegCount() < 24 || getRegCount() > 256)
-    return emitOpError("new register size must be in between 24 to 256");
-  return success();
-}
-
-LogicalResult NVVM::BarrierOp::verify() {
-  if (getNumberOfThreads() && !getBarrierId())
-    return emitOpError(
-        "barrier id is missing, it should be set between 0 to 15");
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// NVVMDialect initialization, type parsing, and registration.
-//===----------------------------------------------------------------------===//
-
-// TODO: This should be the llvm.nvvm dialect once this is supported.
-void NVVMDialect::initialize() {
-  addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
-      >();
-  addAttributes<
-#define GET_ATTRDEF_LIST
-#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
-      >();
-
-  // Support unknown operations because not all NVVM operations are
-  // registered.
-  allowUnknownOperations();
-  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
-  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
-}
-
-LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
-                                                    NamedAttribute attr) {
-  StringAttr attrName = attr.getName();
-  // Kernel function attribute should be attached to functions.
-  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
-    if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
-                             << "' attribute attached to unexpected op";
-    }
-  }
-  // If maxntid and reqntid exist, it must be an array with max 3 dim
-  if (attrName == NVVMDialect::getMaxntidAttrName() ||
-      attrName == NVVMDialect::getReqntidAttrName()) {
-    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
-    if (!values || values.empty() || values.size() > 3)
-      return op->emitError()
-             << "'" << attrName
-             << "' attribute must be integer array with maximum 3 index";
-  }
-  // If minctasm and maxnreg exist, it must be an integer attribute
-  if (attrName == NVVMDialect::getMinctasmAttrName() ||
-      attrName == NVVMDialect::getMaxnregAttrName()) {
-    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
-      return op->emitError()
-             << "'" << attrName << "' attribute must be integer constant";
-  }
-
-  return success();
-}
-
-LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
-                                                    unsigned regionIndex,
-                                                    unsigned argIndex,
-                                                    NamedAttribute argAttr) {
-  auto funcOp = dyn_cast<FunctionOpInterface>(op);
-  if (!funcOp)
-    return success();
-
-  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
-  StringAttr attrName = argAttr.getName();
-  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
-    if (!isKernel) {
-      return op->emitError()
-             << "'" << attrName
-             << "' attribute must be present only on kernel arguments";
-    }
-    if (!isa<UnitAttr>(argAttr.getValue()))
-      return op->emitError() << "'" << attrName << "' must be a unit attribute";
-    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
-      return op->emitError()
-             << "'" << attrName
-             << "' attribute requires the argument to also have attribute '"
-             << LLVM::LLVMDialect::getByValAttrName() << "'";
-    }
-  }
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// NVVM target attribute.
-//===----------------------------------------------------------------------===//
-LogicalResult
-NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                       int optLevel, StringRef triple, StringRef chip,
-                       StringRef features, DictionaryAttr flags,
-                       ArrayAttr files) {
-  if (optLevel < 0 || optLevel > 3) {
-    emitError() << "The optimization level must be a number between 0 and 3.";
-    return failure();
-  }
-  if (triple.empty()) {
-    emitError() << "The target triple cannot be empty.";
-    return failure();
-  }
-  if (chip.empty()) {
-    emitError() << "The target chip cannot be empty.";
-    return failure();
-  }
-  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
-        return attr && mlir::isa<StringAttr>(attr);
-      })) {
-    emitError() << "All the elements in the `link` array must be strings.";
-    return failure();
-  }
-  return success();
-}
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
-
-#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"

>From 7c9053ca23bbf94e10c7174e8c8c4ae178e5c0e7 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:03:29 +0800
Subject: [PATCH 3/7] Update NVVMDialect.cpp

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 1834 ++++++++++++--------
 1 file changed, 1141 insertions(+), 693 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 375e2951a037c..036a9a15af838 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1,694 +1,1142 @@
-// RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s
-
-// Same below, but using the `ConvertToLLVMPatternInterface` entry point
-// and the generic `convert-to-llvm` pass.
-// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: @init_mbarrier
-llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b" 
-  nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1 
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" 
-  nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
-  llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
-  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
-  //CHECK:  llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
-  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 
-  llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" 
-  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
-  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 
-  llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_try_wait_shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred       P1;
-  // CHECK-SAME: LAB_WAIT: 
-  // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
-  // CHECK-SAME: @P1 bra.uni DONE;
-  // CHECK-SAME: bra.uni     LAB_WAIT;
-  // CHECK-SAME: DONE:
-  // CHECK-SAME: }",
-  // CHECK-SAME: "r,r,r"
-   nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
-  llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_try_wait
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred       P1;
-  // CHECK-SAME: LAB_WAIT: 
-  // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
-  // CHECK-SAME: @P1 bra.uni DONE;
-  // CHECK-SAME: bra.uni     LAB_WAIT;
-  // CHECK-SAME: DONE:
-  // CHECK-SAME: }",
-  // CHECK-SAME: "l,r,r"
-  nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
-  llvm.return
-}
-
-// CHECK-LABEL: @async_cp
-func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
-  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-  nvvm.cp.async.shared.global %dst, %src, 16, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache =  cg : !llvm.ptr<3>, !llvm.ptr<1>
-  nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg : !llvm.ptr<3>, !llvm.ptr<1>
-  return
-}
-
-// CHECK-LABEL: @async_cp_zfill
-func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", 
-  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
-  nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", 
-  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
-  nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
-  return
-}
-
-// CHECK-LABEL: @cp_async_mbarrier_arrive
-func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
-  // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}}
-  nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
-  // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true}
-  nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
-  // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}}
-  nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3>
-  // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} {noinc = true}
-  nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3>
-  llvm.return
-}
-
-// CHECK-LABEL: @tma_load_3d_all
-func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_4d_all
-func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_5d_all
-func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_3d
-func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_4d
-func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_5d
-func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_multicast1d
-func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_multicast2d
-func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask  predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_multicast3d
-func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask  predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_multicast4d
-func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask: !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_load_multicast5d
-func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask predicate=%p  : !llvm.ptr<3>, !llvm.ptr
-  return
-}
-
-// CHECK-LABEL: @tma_store_1d
-func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
-  return
-}
-
-// CHECK-LABEL: @tma_store_2d
-func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
-  return
-}
-
-// CHECK-LABEL: @tma_store_3d
-func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
-  return
-}
-
-// CHECK-LABEL: @tma_store_4d
-func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
-  return
-}
-
-// CHECK-LABEL: @tma_store_5d
-func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
-  // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
-
-  // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
-  return
-}
-
-// CHECK-LABEL: @wgmma_execute
-func.func @wgmma_execute() {  
-  nvvm.wgmma.fence.aligned
-  nvvm.wgmma.commit.group.sync.aligned
-  nvvm.wgmma.wait.group.sync.aligned 0
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
-  // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
-  
-
-  nvvm.wgmma.fence.aligned
-  nvvm.wgmma.commit.group.sync.aligned
-  nvvm.wgmma.wait.group.sync.aligned 5
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
-  // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
-  return
-}
-
-
-// -----
-
-!mat64f32 = !llvm.struct<(
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_f16_f16(
-// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
-func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{  
-  // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
-  // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
-  // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
-  // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  
-  // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{
-  // CHECK-SAME: reg .pred p;
-  // CHECK-SAME: setp.ne.b32 p, $34, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 
-  // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35,  $36, $37,  $38;\0A}\0A", 
-  // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" 
-  // CHECK-SAME: %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] 
-  // CHECK-SAME: : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32) 
-  // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-  // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
-  // CHECK: %[[DESCa:.+]] = llvm.add %[[ARG0]], %[[C2]] : i64
-  // CHECK: %[[DESCb:.+]] = llvm.add %[[ARG1]], %[[C2]] : i64
-  // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>  
-  // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> 
-  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{
-    // CHECK-SAME: .reg .pred p;
-    // CHECK-SAME: setp.ne.b32 p, $34, 0;
-    // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 
-    // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35,  $36, $37,  $38;\0A}\0A", 
-    // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n" 
-    // CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} 
-  %result = llvm.mlir.undef : !mat64f32
-  %result1 = nvvm.wgmma.mma_async 
-      %descA, %descB, %result,
-      #nvvm.shape<m = 64, n = 32, k = 16>, 
-      D [<f32>, #nvvm.wgmma_scale_out<zero>],
-      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
-      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
-      :!mat64f32 -> !mat64f32
-  %c2 = arith.constant 2 : i64
-  %descAnext = arith.addi %descA, %c2 : i64
-  %descBnext = arith.addi %descB, %c2 : i64
-  %result2 = nvvm.wgmma.mma_async 
-      %descAnext, %descBnext, %result1,
-      #nvvm.shape<m = 64, n = 32, k = 16>, 
-      D [<f32>, #nvvm.wgmma_scale_out<zero>],
-      A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], 
-      B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
-      : !mat64f32 -> !mat64f32
-  return %result2 : !mat64f32
-}
-
-// -----
-
-!mat16i32 = !llvm.struct<(i32, i32, i32, i32)>
-
-// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite(
-// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
-func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{  
-  %result = llvm.mlir.undef : !mat16i32
-// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
-// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
-// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
-// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
-// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
-// CHECK: %[[RES_2:.*]] =  llvm.inline_asm has_side_effects asm_dialect = att 
-// CHECK-SAME: "{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite 
-// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" 
-// CHECK-SAME: %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : 
-// CHECK-SAME: (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
-// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
-// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
-// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
-// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
-// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
-// CHECK-SAME: "{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite 
-// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", 
-// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" 
-// CHECK-SAME: %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
-// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
-// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
-// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
-// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
-// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
-// CHECK-SAME:"{
-// CHECK-SAME:.reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
-// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n" 
-// CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result, 
-      #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
-      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
-      : !mat16i32 -> !mat16i32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, 
-      #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
-      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
-      : !mat16i32 -> !mat16i32
-  %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, 
-      #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
-      A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
-      : !mat16i32 -> !mat16i32
-  return %result3 : !mat16i32
-}
-
-// CHECK-LABEL: @wgmma_s32_u8_u8(
-  // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
-func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {  
-// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
-// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
-// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
-// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
-// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
-// CHECK: %[[RES_2:.*]] =  llvm.inline_asm has_side_effects asm_dialect = att 
-// CHECK-SAME: "{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
-// CHECK-SAME: }\0A",
-// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] : 
-// CHECK-SAME:(i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
-// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
-// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
-// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
-// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
-// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
-// CHECK-SAME:"{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
-// CHECK-SAME: }\0A",
-// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
-// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
-// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
-// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
-// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
-// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att 
-// CHECK-SAME:"{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
-// CHECK-SAME:}\0A", 
-// CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}} 
-  %result = llvm.mlir.undef : !mat16i32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
-      #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [<s32>, #nvvm.wgmma_scale_out<one>],
-      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
-      : !mat16i32 -> !mat16i32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
-      #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [<s32>, #nvvm.wgmma_scale_out<one>],
-      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
-      : !mat16i32 -> !mat16i32
-  %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
-      #nvvm.shape<m = 64, n = 8, k = 32>, 
-      D [<s32>, #nvvm.wgmma_scale_out<one>],
-      A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
-      : !mat16i32 -> !mat16i32
-  return %result3 : !mat16i32
-}
-
-// -----
-
-!mat32f32 = !llvm.struct<(
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_tf32_tf32
-func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {  
-  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME:"{
-  // CHECK-SAME: .reg .pred p;
-  // CHECK-SAME: setp.ne.b32 p, $66, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
-  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred p;
-  // CHECK-SAME: setp.ne.b32 p, $66, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
-  %result = llvm.mlir.undef : !mat32f32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
-      #nvvm.shape<m = 64, n = 64, k = 8>, 
-      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
-       : !mat32f32 -> !mat32f32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
-      #nvvm.shape<m = 64, n = 64, k = 8>, 
-      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
-      : !mat32f32 -> !mat32f32
-  return %result2 : !mat32f32
-}
-
-
-// -----
-
-!mat32f32 = !llvm.struct<(
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_e4m3_e4m3
-func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {  
-  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
-  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
-  %result = llvm.mlir.undef : !mat32f32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
-      #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
-       : !mat32f32 -> !mat32f32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
-      #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
-      : !mat32f32 -> !mat32f32
-  return %result2 : !mat32f32
-}
-
-// -----
-
-!mat32f32 = !llvm.struct<(
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32, 
-  f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_e5m2_e4m3
-func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {  
-  // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
-  // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
-  // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67,  $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
-  %result = llvm.mlir.undef : !mat32f32
-  %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
-      #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
-       : !mat32f32 -> !mat32f32
-  %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
-      #nvvm.shape<m = 64, n = 64, k = 32>, 
-      D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
-      A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
-      : !mat32f32 -> !mat32f32
-  return %result2 : !mat32f32
-}
-
-// -----
-
-func.func @elect_one_leader_sync() {  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
-  // CHECK-SAME: .reg .u32 rx;
-  // CHECK-SAME: .reg .pred px;
-  // CHECK-SAME: mov.pred $0, 0;
-  // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
-  // CHECK-SAME: @px mov.pred $0, 1;
-  // CHECK-SAME: "=b"  : () -> i1
-  %cnd = nvvm.elect.sync -> i1 
-  return 
-}
-
-// -----
-
-// CHECK-LABEL: @stmatrix(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, 
-// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
-llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
-  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
-  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
-  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
-  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
-  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
-  llvm.return 
-}
-
-// -----
-
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
-  nvvm.prefetch.tensormap %desc : !llvm.ptr
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
-  nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
-  llvm.return
-}
-
-// -----
-
-func.func @set_max_register() {
-  // CHECK: nvvm.setmaxregister increase 232
-  nvvm.setmaxregister increase 232
-
-  // CHECK: nvvm.setmaxregister decrease 40
-  nvvm.setmaxregister decrease 40
-  func.return
-}
-
-// -----
-
-func.func @cp_async_bulk_commit() {
-  // CHECK: nvvm.cp.async.bulk.commit.group
-  nvvm.cp.async.bulk.commit.group
-  func.return
-}
-
-// -----
-
-func.func @cp_async_bulk_wait_group() {
-  // CHECK: nvvm.cp.async.bulk.wait_group 1
-  // CHECK: nvvm.cp.async.bulk.wait_group 0
-  // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
-  // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
-  nvvm.cp.async.bulk.wait_group 1
-  nvvm.cp.async.bulk.wait_group 0
-  nvvm.cp.async.bulk.wait_group 5 {read}
-  nvvm.cp.async.bulk.wait_group 0 {read}
-  func.return
-}
-
-// -----
-
-func.func @fence_mbarrier_init() {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;"
-  nvvm.fence.mbarrier.init
-  func.return 
-}
-// -----
-
-func.func @fence_proxy() {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", ""  : () -> ()
-  nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>}
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", ""  : () -> ()
-  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>}
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", ""  : () -> ()
-  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>}
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", ""  : () -> ()
-  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>}
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", ""  : () -> ()
-  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
-  func.return
-}
-
-// -----
-
-// CHECK-LABEL: @llvm_nvvm_barrier_arrive
-// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
-llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0, $0;", "r" %[[numberOfThreads]] : (i32) -> ()
-  nvvm.barrier.arrive number_of_threads = %numberOfThreads
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
-  nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
-  llvm.return
+//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the types and operation details for the NVVM IR dialect in
+// MLIR, and the LLVM IR dialect.  It also registers the dialect.
+//
+// The NVVM dialect only contains GPU specific additions on top of the general
+// LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <optional>
+#include <string>
+
+using namespace mlir;
+using namespace NVVM;
+
+#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for NVVM ops
+//===----------------------------------------------------------------------===//
+
+/// Prints the shared short form used by NVVM intrinsic ops: a space, the
+/// operand list and, only when the op has results, ` : ` followed by the
+/// result types.
+static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
+  p << " " << op->getOperands();
+  // Nothing follows the operands when the op produces no results.
+  if (op->getNumResults() == 0)
+    return;
+  p << " : " << op->getResultTypes();
 }
+
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+//
+// Parses the operand list, an optional attribute dictionary and the trailing
+// `: type`, records that type as the single result type, and resolves the two
+// operands as i32 and i1 respectively.
+ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
+  Type resultType;
+  if (parser.parseOperandList(operands) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(resultType) ||
+      parser.addTypeToList(resultType, result.types))
+    return failure();
+
+  // The operand types are fixed by the op, not spelled in the assembly.
+  MLIRContext *ctx = parser.getContext();
+  auto maskTy = IntegerType::get(ctx, 32);
+  auto predTy = IntegerType::get(ctx, 1);
+  return parser.resolveOperands(operands, {maskTy, predTy},
+                                parser.getNameLoc(), result.operands);
+}
+
+/// Prints the op by delegating to the file-local printNVVMIntrinsicOp helper.
+void VoteBallotOp::print(OpAsmPrinter &p) {
+  printNVVMIntrinsicOp(p, *this);
+}
+
+/// Verifies the coordinate and im2col-offset operand counts:
+///  - there must be between 1 and 5 coordinates;
+///  - when im2col offsets are present, there must be at least 3 coordinates
+///    and exactly two more coordinates than offsets.
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  const auto numCoordinates = getCoordinates().size();
+  if (numCoordinates == 0 || numCoordinates > 5)
+    return emitError("expects coordinates between 1 to 5 dimension");
+
+  // Without im2col offsets there is nothing further to check.
+  const auto numIm2colOffsets = getIm2colOffsets().size();
+  if (numIm2colOffsets == 0)
+    return success();
+
+  // Check for im2col mode
+  if (numCoordinates < 3)
+    return emitError(
+        "to use im2col mode, the tensor has to be at least 3-dimensional");
+  if (numCoordinates != numIm2colOffsets + 2)
+    return emitError(
+        "im2col offsets must be 2 less than number of coordinates");
+  return success();
+}
+
+/// Verifies that the op carries at most 5 coordinates.
+LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
+  if (getCoordinates().size() <= 5)
+    return success();
+  return emitError("Maximum 5 coordinates and dimension is supported.");
+}
+
+/// Verifies the cache modifier / copy size combination:
+///  - the modifier must be CG or CA;
+///  - the size must be 4, 8 or 16;
+///  - CG is only accepted together with a size of 16.
+LogicalResult CpAsyncOp::verify() {
+  const auto modifier = getModifier();
+  const bool isCG = modifier == LoadCacheModifierKind::CG;
+  const bool isCA = modifier == LoadCacheModifierKind::CA;
+  if (!isCG && !isCA)
+    return emitError("Only CG and CA cache modifiers are supported.");
+
+  const auto size = getSize();
+  const bool isSupportedSize = size == 4 || size == 8 || size == 16;
+  if (!isSupportedSize)
+    return emitError("expected byte size to be either 4, 8 or 16.");
+
+  if (isCG && size != 16)
+    return emitError("CG cache modifier is only support for 16 bytes copy.");
+  return success();
+}
+
+// Maps the element type of an MMA operand to the PTX type
+// (`NVVM::MMATypes`) it corresponds to, or std::nullopt when no type can be
+// inferred. `isAccumulator` selects between the accumulator and multiplicand
+// interpretation where the two differ (f32 vs. tf32, and integer operands).
+std::optional<mlir::NVVM::MMATypes>
+MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
+  // <2 x f16>, accepted below alongside scalar f16.
+  auto packedHalfTy =
+      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+
+  // A struct operand is classified by its first member; an empty struct has
+  // nothing to infer from.
+  if (auto aggregate = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
+    ArrayRef<Type> members = aggregate.getBody();
+    if (members.empty())
+      return std::nullopt;
+    return inferOperandMMAType(members.front(), isAccumulator);
+  }
+
+  // Integer operands are only inferrable as accumulators.
+  if (llvm::isa<IntegerType>(operandElType)) {
+    if (!isAccumulator)
+      return std::nullopt;
+    return NVVM::MMATypes::s32;
+  }
+
+  if (operandElType.isF64())
+    return NVVM::MMATypes::f64;
+  // f32 is an f32 accumulator, but a tf32 multiplicand.
+  if (operandElType.isF32())
+    return isAccumulator ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32;
+  if (operandElType.isF16() || operandElType == packedHalfTy)
+    return NVVM::MMATypes::f16;
+
+  return std::nullopt;
+}
+
+/// Returns true for the 4-bit integer PTX types (s4 and u4).
+static bool isInt4PtxType(MMATypes type) {
+  switch (type) {
+  case MMATypes::s4:
+  case MMATypes::u4:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns true for the 8-bit integer PTX types (s8 and u8).
+static bool isInt8PtxType(MMATypes type) {
+  switch (type) {
+  case MMATypes::s8:
+  case MMATypes::u8:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns true for every PTX type treated as integer here: the 4-bit and
+/// 8-bit integer types, b1, and s32.
+static bool isIntegerPtxType(MMATypes type) {
+  if (isInt4PtxType(type))
+    return true;
+  if (isInt8PtxType(type))
+    return true;
+  return type == MMATypes::b1 || type == MMATypes::s32;
+}
+
+/// Returns the PTX type inferred from the type of the first operand in ODS
+/// operand group 2, interpreted as an accumulator. Asserts that inference
+/// succeeds.
+MMATypes MmaOp::accumPtxType() {
+  Type firstAccumTy = getODSOperands(2).getTypes().front();
+  std::optional<mlir::NVVM::MMATypes> inferred =
+      inferOperandMMAType(firstAccumTy, /*isAccumulator=*/true);
+  assert(inferred.has_value() &&
+         "accumulator PTX type should always be inferrable");
+  return inferred.value();
+}
+
+/// Returns the PTX type inferred from the op's result type, interpreted as an
+/// accumulator. Asserts that inference succeeds.
+MMATypes MmaOp::resultPtxType() {
+  Type resultTy = getResult().getType();
+  std::optional<mlir::NVVM::MMATypes> inferred =
+      inferOperandMMAType(resultTy, /*isAccumulator=*/true);
+  assert(inferred.has_value() && "result PTX type should always be inferrable");
+  return inferred.value();
+}
+
+// Prints the custom assembly form of the op:
+//   ` A[<regs>]  B[<regs>]  C[<regs>] ` attr-dict ` : (tA, tB, tC)` followed
+//   by the arrow type list holding the result type,
+// where tA/tB/tC are the types of the first register of each operand segment.
+// The operand segment size attribute is never printed, and a multiplicand PTX
+// type attribute is elided whenever inferOperandMMAType yields a value for the
+// type examined in the loop below.
+void MmaOp::print(OpAsmPrinter &p) {
+  SmallVector<Type, 4> regTypes;
+  // One printed operand segment: its keyword, the name of the PTX type
+  // attribute that may be elided for it, and its registers.
+  struct OperandFragment {
+    StringRef operandName;
+    StringRef ptxTypeAttr;
+    SmallVector<Value, 4> regs;
+    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+        : operandName(name), ptxTypeAttr(ptxTypeName) {}
+  };
+
+  // The C segment has no PTX type attribute, hence the empty attribute name.
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+  SmallVector<StringRef, 4> ignoreAttrNames{
+      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
+
+  // Collect the registers of each segment and decide which PTX type
+  // attributes can be dropped from the printed attribute dictionary.
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(this->getOperand(operandIdx));
+      // NOTE(review): this condition only holds for the op's very first
+      // operand, so regTypes receives a single entry and `regTypes.back()`
+      // below refers to that same type for every fragment. If the intent was
+      // "first operand of each segment", this should compare against
+      // varOperandSpec.first -- confirm before changing, as it alters which
+      // attributes are elided.
+      if (operandIdx == 0) {
+        regTypes.push_back(this->getOperand(operandIdx).getType());
+      }
+    }
+    // NOTE(review): regTypes.back() presumes operand 0 exists, i.e. that the
+    // op has at least one operand -- TODO confirm this is guaranteed.
+    std::optional<MMATypes> inferredType =
+        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+    if (inferredType)
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+
+  // Emits ` <name>[<regs>] ` for one segment.
+  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
+    p << " " << frag.operandName;
+    p << "[";
+    p.printOperands(frag.regs);
+    p << "] ";
+  };
+
+  for (const auto &frag : frags) {
+    printMmaOperand(frag);
+  }
+
+  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
+
+  // Print the types of the operands and result.
+  // Only the first register's type is printed per segment; regs[0] presumes
+  // every segment is non-empty.
+  p << " : " << "(";
+  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
+                                             frags[1].regs[0].getType(),
+                                             frags[2].regs[0].getType()},
+                        p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{this->getRes().getType()});
+}
+
+// Populates `result` for an MmaOp.
+//
+//  - `resultType` becomes the single result type.
+//  - `operandA`, `operandB`, `operandC` are appended as operands in that
+//    order; their sizes are recorded in the operand segment size attribute.
+//  - `shape` must hold exactly three entries (m, n, k).
+//  - `b1Op` and `intOverflow` are attached as attributes only when provided.
+//  - `multiplicandPtxTypes`, when provided, sets the A/B PTX type attributes;
+//    otherwise each is inferred from the first register of the corresponding
+//    operand list and left unset if inference fails.
+//  - `multiplicandLayouts`, when provided, sets layoutA/layoutB; otherwise
+//    they default to row (A) and col (B).
+void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
+                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
+                  std::optional<MMAIntOverflow> intOverflow,
+                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
+
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+  MLIRContext *ctx = builder.getContext();
+  result.addAttribute(
+      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+  // Operand order here must match the segment sizes recorded at the end.
+  result.addOperands(operandA);
+  result.addOperands(operandB);
+  result.addOperands(operandC);
+
+  if (multiplicandPtxTypes) {
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+  } else {
+    // Infer from the first register of each multiplicand; indexing [0]
+    // presumes operandA/operandB are non-empty -- TODO confirm with callers.
+    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
+      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
+      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+  }
+
+  if (multiplicandLayouts) {
+    result.addAttribute("layoutA",
+                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
+    result.addAttribute("layoutB",
+                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
+  } else {
+    // Default layouts: A is row, B is col.
+    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
+    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
+  }
+
+  if (intOverflow.has_value())
+    result.addAttribute("intOverflowBehavior",
+                        MMAIntOverflowAttr::get(ctx, *intOverflow));
+  if (b1Op.has_value())
+    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
+
+  result.addTypes(resultType);
+  result.addAttribute(
+      MmaOp::getOperandSegmentSizeAttr(),
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+                                    static_cast<int32_t>(operandB.size()),
+                                    static_cast<int32_t>(operandC.size())}));
+}
+
+// <operation> :=
+//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
+//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
+//     `->` type($res)
+//
+// Parses the custom assembly form above into `result`. Exactly three operand
+// types must be listed, one per segment; every register of a segment is
+// resolved with that segment's type. The multiplicand PTX type attributes are
+// taken from the attribute dictionary when present and otherwise inferred
+// from the operand types; parsing fails if neither is possible. The operand
+// segment size attribute is always added from the parsed register counts.
+ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Per-segment parse state. Index 3 of `frags` is used for the result and
+  // only ever has `elemtype` set.
+  struct OperandFragment {
+    std::optional<MMATypes> elemtype;
+    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+    SmallVector<Type> regTypes;
+  };
+
+  Builder &builder = parser.getBuilder();
+  std::array<OperandFragment, 4> frags;
+
+  NamedAttrList namedAttributes;
+
+  // A helper to parse the operand segments.
+  auto parseMmaOperand = [&](StringRef operandName,
+                             OperandFragment &frag) -> LogicalResult {
+    if (parser.parseKeyword(operandName).failed())
+      return failure();
+    if (parser
+            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+            .failed())
+      return failure();
+    return success();
+  };
+
+  // Parse the operand segments.
+  if (parseMmaOperand("A", frags[0]).failed())
+    return failure();
+  if (parseMmaOperand("B", frags[1]).failed())
+    return failure();
+  if (parseMmaOperand("C", frags[2]).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse the type specification and resolve operands.
+  SmallVector<Type, 3> operandTypes;
+  if (failed(parser.parseColon()))
+    return failure();
+  if (failed(parser.parseLParen()))
+    return failure();
+  if (failed(parser.parseTypeList(operandTypes)))
+    return failure();
+  if (failed(parser.parseRParen()))
+    return failure();
+  // The count must be checked unconditionally: the loop below indexes `frags`
+  // by type position, and every segment needs a type to resolve against.
+  if (operandTypes.size() != 3)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expected one type for each operand segment but got " +
+            Twine(operandTypes.size()) + " types");
+  for (const auto &iter : llvm::enumerate(operandTypes)) {
+    auto &frag = frags[iter.index()];
+    frag.regTypes.resize(frag.regs.size(), iter.value());
+    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+                                      parser.getNameLoc(), result.operands)))
+      return failure();
+    // NOTE(review): this passes isAccumulator=true for the A and B segments
+    // (indices 0, 1) and false for C (index 2), which looks inverted relative
+    // to the parameter name. Left as-is because it determines which
+    // multiplicand types get inferred -- confirm before changing. Also,
+    // regTypes[0] presumes the segment lists at least one register.
+    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
+                                        /*isAccumulator=*/iter.index() < 2);
+  }
+
+  Type resultType;
+  if (parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);
+
+  // An explicitly written multiplicand PTX type attribute wins; otherwise
+  // fall back to the inferred type, and fail when there is neither.
+  std::array<StringRef, 2> names{"multiplicandAPtxType",
+                                 "multiplicandBPtxType"};
+  for (unsigned idx = 0; idx < names.size(); idx++) {
+    const auto &frag = frags[idx];
+    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+    if (!frag.elemtype.has_value() && !attr.has_value()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "attribute " + names[idx] +
+              " is not provided explicitly and cannot be inferred");
+    }
+    if (!attr.has_value())
+      result.addAttribute(
+          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+  }
+
+  result.addTypes(resultType);
+  if (!namedAttributes.empty())
+    result.addAttributes(namedAttributes);
+  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(frags[0].regs.size()),
+                          static_cast<int32_t>(frags[1].regs.size()),
+                          static_cast<int32_t>(frags[2].regs.size()),
+                      }));
+  return success();
+}
+
+LogicalResult MmaOp::verify() {
+  MLIRContext *context = getContext();
+  auto f16Ty = Float16Type::get(context);
+  auto i32Ty = IntegerType::get(context, 32);
+  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+  auto f32Ty = Float32Type::get(context);
+  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+  auto s32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+  auto f32x8StructTy =
+      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+  auto f16x2x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+  auto f32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+  auto s32x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
+                                  getShapeAttr().getK()};
+
+  // These variables define the set of allowed data types for matrices A, B, C,
+  // and result.
+  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+  AllowedShapes allowedShapes;
+  AllowedTypes expectedA;
+  AllowedTypes expectedB;
+  AllowedTypes expectedC;
+  SmallVector<Type> expectedResult;
+
+  // When M = 16, we just need to calculate the number of 8xk tiles, where
+  // k is a factor that depends on the data type.
+  if (mmaShape[0] == 16) {
+    int64_t kFactor;
+    Type multiplicandFragType;
+    switch (*getMultiplicandAPtxType()) {
+    case MMATypes::tf32:
+      kFactor = 4;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
+          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
+      break;
+    case MMATypes::f16:
+    case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = f16x2Ty;
+      expectedResult.push_back(f16x2x2StructTy);
+      expectedResult.push_back(f32x4StructTy);
+      break;
+    case MMATypes::s4:
+    case MMATypes::u4:
+      kFactor = 32;
+      break;
+    case MMATypes::b1:
+      kFactor = 128;
+      break;
+    case MMATypes::s8:
+    case MMATypes::u8:
+      kFactor = 16;
+      break;
+    default:
+      return emitError("invalid shape or multiplicand type: " +
+                       stringifyEnum(getMultiplicandAPtxType().value()));
+    }
+
+    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+      expectedResult.push_back(s32x4StructTy);
+      expectedC.emplace_back(4, i32Ty);
+      multiplicandFragType = i32Ty;
+    } else {
+      expectedC.emplace_back(2, f16x2Ty);
+      expectedC.emplace_back(4, f32Ty);
+    }
+
+    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
+    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+    expectedA.emplace_back(unitA, multiplicandFragType);
+    expectedB.emplace_back(unitB, multiplicandFragType);
+    allowedShapes.push_back({16, 8, kFactor});
+    allowedShapes.push_back({16, 8, kFactor * 2});
+  }
+
+  // In the M=8 case, there is only 1 possible case per data type.
+  if (mmaShape[0] == 8) {
+    if (*getMultiplicandAPtxType() == MMATypes::f16) {
+      expectedA.emplace_back(2, f16x2Ty);
+      expectedB.emplace_back(2, f16x2Ty);
+      expectedResult.push_back(f16x2x4StructTy);
+      expectedResult.push_back(f32x8StructTy);
+      expectedC.emplace_back(4, f16x2Ty);
+      expectedC.emplace_back(8, f32Ty);
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (*getMultiplicandAPtxType() == MMATypes::f64) {
+      Type f64Ty = Float64Type::get(context);
+      expectedA.emplace_back(1, f64Ty);
+      expectedB.emplace_back(1, f64Ty);
+      expectedC.emplace_back(2, f64Ty);
+      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+          context, SmallVector<Type>(2, f64Ty)));
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
+      expectedA.push_back({i32Ty});
+      expectedB.push_back({i32Ty});
+      expectedC.push_back({i32Ty, i32Ty});
+      expectedResult.push_back(s32x2StructTy);
+      if (isInt4PtxType(getMultiplicandAPtxType().value()))
+        allowedShapes.push_back({8, 8, 32});
+      if (isInt8PtxType(getMultiplicandAPtxType().value()))
+        allowedShapes.push_back({8, 8, 16});
+      if (getMultiplicandAPtxType().value() == MMATypes::b1)
+        allowedShapes.push_back({8, 8, 128});
+    }
+  }
+
+  std::string errorMessage;
+  llvm::raw_string_ostream errorStream(errorMessage);
+
+  // Check that we matched an existing shape/dtype combination.
+  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+      !llvm::is_contained(allowedShapes, mmaShape)) {
+    errorStream << "unimplemented variant for MMA shape <";
+    llvm::interleaveComma(mmaShape, errorStream);
+    errorStream << ">";
+    return emitOpError(errorMessage);
+  }
+
+  // Verify the operand types for segments of A, B, and C operands.
+  std::array<StringRef, 3> operandNames{"A", "B", "C"};
+  for (const auto &iter : llvm::enumerate(
+           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+    auto spec = this->getODSOperandIndexAndLength(iter.index());
+    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+                                      operand_type_begin() + spec.first +
+                                          spec.second);
+    bool match = llvm::is_contained(iter.value(), operandTySeg);
+
+    if (!match) {
+      errorStream << "Could not match types for the "
+                  << operandNames[iter.index()]
+                  << " operands; expected one of ";
+      for (const auto &x : iter.value()) {
+        errorStream << x.size() << "x" << x[0] << " ";
+      }
+      errorStream << "but got ";
+      llvm::interleaveComma(operandTySeg, errorStream);
+      return emitOpError(errorStream.str());
+    }
+  }
+
+  // Check the result type
+  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+        return expectedResultType == getResult().getType();
+      })) {
+    errorStream
+        << "Could not match allowed types for the result; expected one of ";
+    llvm::interleaveComma(expectedResult, errorStream);
+    errorStream << " but got " << getResult().getType();
+    return emitOpError(errorStream.str());
+  }
+
+  // Ensure that binary MMA variants have a b1 MMA operation defined.
+  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
+    return emitOpError("op requires " + getB1OpAttrName().strref() +
+                       " attribute");
+  }
+
+  // Ensure int4/int8 MMA variants specify the accum overflow behavior
+  // attribute.
+  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
+      isInt8PtxType(*getMultiplicandAPtxType())) {
+    if (!getIntOverflowBehavior())
+      return emitOpError("op requires " +
+                         getIntOverflowBehaviorAttrName().strref() +
+                         " attribute");
+  }
+
+  return success();
+}
+
+/// Verifies the result type of a shuffle that carries the
+/// `return_value_and_is_valid` unit attribute: it must be a two-element LLVM
+/// struct whose second element is `i1`. Ops without the attribute are accepted
+/// without further checks.
+LogicalResult ShflOp::verify() {
+  if ((*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) {
+    auto structTy = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+    // Stays null unless the result is a struct with exactly two members and
+    // the second member is an integer type.
+    IntegerType validityTy;
+    if (structTy && structTy.getBody().size() == 2)
+      validityTy = llvm::dyn_cast<IntegerType>(structTy.getBody()[1]);
+    if (!validityTy || validityTy.getWidth() != 1)
+      return emitError("expected return type to be a two-element struct with "
+                       "i1 as the second element");
+  }
+  return success();
+}
+
+/// Returns the (element type, element count) pair describing the register
+/// fragment for the given MMA element type and fragment kind. `nRow`/`nCol`
+/// are only consulted for s8/u8, where the count depends on the row count of
+/// an A fragment or the column count of a B fragment. Asserts if no
+/// combination matched.
+std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
+                                                   NVVM::MMAFrag frag, int nRow,
+                                                   int nCol,
+                                                   MLIRContext *context) {
+  OpBuilder builder(context);
+  Type f16x2 = VectorType::get(2, builder.getF16Type());
+  Type elementType;
+  unsigned numberElements = 0;
+
+  switch (type) {
+  case NVVM::MMATypes::f16: {
+    elementType = f16x2;
+    const bool isMultiplicand =
+        frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b;
+    numberElements = isMultiplicand ? 8 : 4;
+    break;
+  }
+  case NVVM::MMATypes::f32:
+    elementType = builder.getF32Type();
+    numberElements = 8;
+    break;
+  case NVVM::MMATypes::tf32:
+    elementType = builder.getI32Type();
+    numberElements = 4;
+    break;
+  case NVVM::MMATypes::s8:
+  case NVVM::MMATypes::u8: {
+    elementType = builder.getI32Type();
+    // Pick the dimension that is not K for this fragment; any other fragment
+    // kind leaves it at 0 and therefore trips the assertion below.
+    int parallelSize = 0;
+    if (frag == NVVM::MMAFrag::a)
+      parallelSize = nRow;
+    else if (frag == NVVM::MMAFrag::b)
+      parallelSize = nCol;
+
+    switch (parallelSize) {
+    case 16: // m == 16 && n == 16 && k == 16
+      numberElements = 2;
+      break;
+    case 8: // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
+      numberElements = 1;
+      break;
+    case 32:
+      numberElements = 4;
+      break;
+    default:
+      break;
+    }
+    break;
+  }
+  case NVVM::MMATypes::s32:
+    elementType = builder.getI32Type();
+    numberElements = 8;
+    break;
+  default:
+    break;
+  }
+
+  assert(numberElements != 0 && elementType != nullptr);
+  return {elementType, numberElements};
+}
+
+/// Translates an (m, n, k) shape into the row/column extent of the requested
+/// fragment (A is m x k, B is k x n, anything else is m x n) and forwards to
+/// inferMMAType.
+static std::pair<mlir::Type, unsigned>
+inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
+                    int k, MLIRContext *context) {
+  // Start from the m x n extent and swap in k for the multiplicands.
+  int nRow = m;
+  int nCol = n;
+  if (frag == NVVM::MMAFrag::a)
+    nCol = k;
+  else if (frag == NVVM::MMAFrag::b)
+    nRow = k;
+  assert(nRow && nCol);
+  return inferMMAType(type, frag, nRow, nCol, context);
+}
+
+/// Verifies a WMMA load: the source pointer must live in address space 0,
+/// global or shared; the attribute combination must map to an intrinsic; and
+/// the result must be a literal struct of the inferred fragment elements.
+LogicalResult NVVM::WMMALoadOp::verify() {
+  const unsigned srcSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  const bool isSupportedSpace = srcSpace == 0 ||
+                                srcSpace == NVVM::kGlobalMemorySpace ||
+                                srcSpace == NVVM::kSharedMemorySpace;
+  if (!isSupportedSpace)
+    return emitOpError("expected source pointer in memory "
+                       "space 0, 1, 3");
+
+  // An intrinsic ID of 0 means no intrinsic exists for these attributes.
+  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
+                                       getEltype(), getFrag()) == 0)
+    return emitOpError() << "invalid attribute combination";
+
+  const auto [fragElemTy, fragElemCount] = inferMMATypeFromMNK(
+      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+  Type expectedTy = LLVM::LLVMStructType::getLiteral(
+      getContext(), SmallVector<Type, 8>(fragElemCount, fragElemTy));
+  if (getType() == expectedTy)
+    return success();
+  return emitOpError("expected destination type is a structure of ")
+         << fragElemCount << " elements of type " << fragElemTy;
+}
+
+/// Verifies a WMMA store: the pointer must live in address space 0, global or
+/// shared; the attribute combination must map to an intrinsic; and the data
+/// operands must match the inferred C-fragment element type and count.
+LogicalResult NVVM::WMMAStoreOp::verify() {
+  const unsigned ptrSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  const bool isSupportedSpace = ptrSpace == 0 ||
+                                ptrSpace == NVVM::kGlobalMemorySpace ||
+                                ptrSpace == NVVM::kSharedMemorySpace;
+  if (!isSupportedSpace)
+    return emitOpError("expected operands to be a source pointer in memory "
+                       "space 0, 1, 3");
+
+  // An intrinsic ID of 0 means no intrinsic exists for these attributes.
+  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
+                                        getEltype()) == 0)
+    return emitOpError() << "invalid attribute combination";
+
+  const auto [accElemTy, accElemCount] = inferMMATypeFromMNK(
+      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
+  if (getArgs().size() != accElemCount)
+    return emitOpError() << "expected " << accElemCount << " data operands";
+  for (Value dataOperand : getArgs()) {
+    if (dataOperand.getType() != accElemTy)
+      return emitOpError() << "expected data operands of type " << accElemTy;
+  }
+  return success();
+}
+
+/// Verifies a WMMA multiply-accumulate: the attribute combination must map to
+/// an intrinsic, the operands must be the A, B and C fragments (in that order)
+/// with the inferred element types, and the result must be a literal struct
+/// matching the C fragment.
+LogicalResult NVVM::WMMAMmaOp::verify() {
+  // An intrinsic ID of 0 means no intrinsic exists for these attributes.
+  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
+                                      getLayoutB(), getEltypeA(),
+                                      getEltypeB()) == 0)
+    return emitOpError() << "invalid attribute combination";
+
+  auto fragmentInfo = [this](NVVM::MMATypes eltype, NVVM::MMAFrag frag) {
+    return inferMMATypeFromMNK(eltype, frag, getM(), getN(), getK(),
+                               getContext());
+  };
+  // Both multiplicand fragments are inferred from eltypeA; eltypeB drives the
+  // C fragment.
+  std::pair<Type, unsigned> typeInfoA =
+      fragmentInfo(getEltypeA(), NVVM::MMAFrag::a);
+  std::pair<Type, unsigned> typeInfoB =
+      fragmentInfo(getEltypeA(), NVVM::MMAFrag::b);
+  std::pair<Type, unsigned> typeInfoC =
+      fragmentInfo(getEltypeB(), NVVM::MMAFrag::c);
+
+  // Flattened list of the operand types expected, in A, B, C order.
+  SmallVector<Type, 32> expectedArgTypes;
+  for (const auto &info : {typeInfoA, typeInfoB, typeInfoC})
+    expectedArgTypes.append(info.second, info.first);
+
+  const unsigned expectedArgCount = expectedArgTypes.size();
+  if (getArgs().size() != expectedArgCount)
+    return emitOpError() << "expected " << expectedArgCount << " arguments";
+  for (unsigned argIdx = 0; argIdx != expectedArgCount; ++argIdx) {
+    Type expectedTy = expectedArgTypes[argIdx];
+    if (getArgs()[argIdx].getType() != expectedTy)
+      return emitOpError() << "expected argument " << argIdx
+                           << " to be of type " << expectedTy;
+  }
+
+  Type expectedResultTy = LLVM::LLVMStructType::getLiteral(
+      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
+  if (getType() == expectedResultTy)
+    return success();
+  return emitOpError("expected destination type is a structure of ")
+         << typeInfoC.second << " elements of type " << typeInfoC.first;
+}
+
+/// Verifies an ldmatrix: the pointer must be in the shared memory space, `num`
+/// must be 1, 2 or 4, and the result must be a bare i32 (num == 1) or a
+/// literal struct of `num` i32 values.
+LogicalResult NVVM::LdMatrixOp::verify() {
+  const unsigned srcSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  if (srcSpace != NVVM::kSharedMemorySpace)
+    return emitOpError("expected source pointer in memory space 3");
+
+  const auto num = getNum();
+  switch (num) {
+  case 1:
+  case 2:
+  case 4:
+    break;
+  default:
+    return emitOpError("expected num attribute to be 1, 2 or 4");
+  }
+
+  Type i32 = IntegerType::get(getContext(), 32);
+  if (num == 1) {
+    if (getType() != i32)
+      return emitOpError("expected destination type is i32");
+    return success();
+  }
+
+  // num is 2 or 4 here: the result is a struct with one i32 per matrix.
+  Type expectedTy = LLVM::LLVMStructType::getLiteral(
+      getContext(), SmallVector<Type>(num, i32));
+  if (getType() != expectedTy)
+    return emitOpError("expected destination type is a structure of ")
+           << num << " elements of type i32";
+  return success();
+}
+
+/// Verifies an stmatrix: the pointer must be in the shared memory space and
+/// the op must carry 1, 2 or 4 source operands.
+LogicalResult NVVM::StMatrixOp::verify() {
+  const unsigned ptrSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  if (ptrSpace != NVVM::kSharedMemorySpace)
+    return emitOpError("expected source pointer in memory space 3");
+
+  const int numMatrix = getSources().size();
+  switch (numMatrix) {
+  case 1:
+  case 2:
+  case 4:
+    return success();
+  default:
+    return emitOpError("expected num attribute to be 1, 2 or 4");
+  }
+}
+
+/// Returns the single K extent accepted for a WGMMA with the given A input
+/// type, or failure when the type has no entry.
+FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
+  switch (typeA) {
+  case NVVM::WGMMATypes::tf32:
+    return 8;
+  case NVVM::WGMMATypes::f16:
+  case NVVM::WGMMATypes::bf16:
+    return 16;
+  case NVVM::WGMMATypes::s8:
+  case NVVM::WGMMATypes::u8:
+  case NVVM::WGMMATypes::e4m3:
+  case NVVM::WGMMATypes::e5m2:
+    return 32;
+  case NVVM::WGMMATypes::b1:
+    return 256;
+  default:
+    return failure();
+  }
+}
+
+/// Returns success when the (accumulator, A, B) element-type triple is one of
+/// the combinations accepted below, and failure otherwise.
+///
+/// `typeA` comes straight from an op attribute, so it can hold any
+/// WGMMATypes value, including the accumulator-only `f32`/`s32`. Those are
+/// reported as failure so the calling verifier emits a diagnostic instead of
+/// aborting on invalid IR.
+LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
+                                     NVVM::WGMMATypes typeA,
+                                     NVVM::WGMMATypes typeB) {
+  switch (typeA) {
+  case NVVM::WGMMATypes::f16:
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        typeB == NVVM::WGMMATypes::f16)
+      return success();
+    break;
+  case NVVM::WGMMATypes::tf32:
+    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
+      return success();
+    break;
+  case NVVM::WGMMATypes::u8:
+  case NVVM::WGMMATypes::s8:
+    if (typeD == NVVM::WGMMATypes::s32 &&
+        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
+      return success();
+    break;
+  case NVVM::WGMMATypes::b1:
+    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
+      return success();
+    break;
+  case NVVM::WGMMATypes::bf16:
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        typeB == NVVM::WGMMATypes::bf16)
+      return success();
+    break;
+  case NVVM::WGMMATypes::e4m3:
+  case NVVM::WGMMATypes::e5m2:
+    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
+      return success();
+    break;
+  case WGMMATypes::f32:
+  case WGMMATypes::s32:
+    // Not valid as an input type; fall through to the failure result.
+    break;
+  }
+  return failure();
+}
+
+/// Returns success when `sizeN` is an accepted N extent for the given A input
+/// type. f16/tf32/bf16/e4m3/e5m2 use the full list; u8/s8/b1 use the shorter
+/// list. `f32`/`s32` are not expected here and hit llvm_unreachable.
+LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
+  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
+                               72,  80,  88,  96,  104, 112, 120, 128,
+                               136, 144, 152, 160, 168, 176, 184, 192,
+                               200, 208, 216, 224, 232, 240, 248, 256};
+  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
+                                    80,  96,  112, 128, 144, 160,
+                                    176, 192, 208, 224, 240, 256};
+
+  // Select the list that applies to this input type, then do one lookup.
+  const SmallVector<int> *candidates = nullptr;
+  switch (typeA) {
+  case WGMMATypes::f16:
+  case WGMMATypes::tf32:
+  case WGMMATypes::bf16:
+  case WGMMATypes::e4m3:
+  case WGMMATypes::e5m2:
+    candidates = &allowedN;
+    break;
+  case WGMMATypes::u8:
+  case WGMMATypes::s8:
+  case WGMMATypes::b1:
+    candidates = &allowedNshort;
+    break;
+  case WGMMATypes::f32:
+  case WGMMATypes::s32:
+    llvm_unreachable("unsupported input types");
+    break;
+  }
+  return success(candidates && llvm::is_contained(*candidates, sizeN));
+}
+
+/// Verifies `nvvm.wgmma.mma_async`. In order, it checks that:
+///  - the result is an LLVM struct whose members all share one type;
+///  - the accumulator type is f32, f16 or s32, and s32 is not combined with a
+///    negated input scale;
+///  - the (D, A, B) type triple is an accepted combination;
+///  - M is 64, K matches the value required by the A type, and N is accepted
+///    for the A type;
+///  - transposed layouts are only requested for f16/bf16 inputs;
+///  - the struct has N/2 members (f32/s32) or N/4 members (f16);
+///  - `satfinite` is only used with an s32 accumulator.
+LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
+  Value outValue = getResults();
+  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+  if (!stype)
+    return emitOpError() << "expected results to be struct";
+  int outputSize = stype.getBody().size();
+  WGMMATypes typeD = getTypeD();
+  WGMMATypes typeA = getTypeA();
+  WGMMATypes typeB = getTypeB();
+
+  for (Type t : stype.getBody()) {
+    if (t != stype.getBody().front())
+      return emitOpError()
+             << "all elements in struct must be same type but there is " << t;
+  }
+
+  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+      typeD != WGMMATypes::s32) {
+    return emitOpError() << "does not support the given output type "
+                         << NVVM::stringifyWGMMATypes(typeD);
+  }
+  if (typeD == WGMMATypes::s32 &&
+      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
+    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+  }
+
+  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
+    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
+                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
+                         << NVVM::stringifyWGMMATypes(typeB)
+                         << ", it is not supported.";
+  }
+
+  // Check M
+  if (getShape().getM() != 64)
+    return emitOpError() << "shape 'm' must be 64";
+
+  // Check K
+  FailureOr<int> allowedK = getAllowedSizeK(typeA);
+  // A failed lookup carries no value, so report it separately instead of
+  // reading the value out of the failed FailureOr while building the message.
+  if (failed(allowedK))
+    return emitOpError() << "does not support the given input type "
+                         << NVVM::stringifyWGMMATypes(typeA);
+  if (*allowedK != getShape().getK())
+    return emitOpError() << "shape 'k' must be " << *allowedK
+                         << " for input type "
+                         << NVVM::stringifyWGMMATypes(typeA);
+
+  // Check N
+  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
+    return emitOpError() << "has input type "
+                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
+                         << getShape().getN() << ", it is not supported.";
+  }
+
+  // NOTE: the untransposed WGMMA layouts are A = row and B = col, so the
+  // layout_b comparison below should test `row`; a later patch in this series
+  // makes exactly that change.
+  // Check transpose (only available for f16/bf16)
+  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
+      (getLayoutA() == mlir::NVVM::MMALayout::col ||
+       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+    return emitOpError()
+           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
+           << " and layout_b = " << stringifyMMALayout(getLayoutB())
+           << " for input types " << stringifyWGMMATypes(typeA) << " and "
+           << stringifyWGMMATypes(typeB)
+           << " requires transpose. However, this is only supported for: "
+           << stringifyMMATypes(MMATypes::f16) << " and "
+           << stringifyMMATypes(MMATypes::bf16);
+  }
+
+  // Check result registers
+  int expectedOutput = 0;
+  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
+    expectedOutput = getShape().getN() / 2;
+  if (typeD == WGMMATypes::f16)
+    expectedOutput = getShape().getN() / 4;
+  if (outputSize != expectedOutput) {
+    return emitOpError() << "results " << expectedOutput
+                         << ", however output struct has " << outputSize
+                         << " elements";
+  }
+  // Check satfinite (only available for s32 accumulator)
+  if (typeD != WGMMATypes::s32 &&
+      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+          NVVM::MMAIntOverflow::satfinite) {
+    return emitOpError()
+           << " `satfinite` can be only used with s32 accumulator, however "
+              "the current accumulator is "
+           << NVVM::stringifyWGMMATypes(typeD);
+  }
+
+  return success();
+}
+
+/// Builds the inline PTX text for this op: a predicate is set from one
+/// operand, followed by a `wgmma.mma_async.sync.aligned` instruction whose
+/// shape, types and optional `.satfinite` suffix come from the op attributes.
+/// Operands are referenced by `$<index>` placeholders.
+std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
+  const int m = getShape().getM();
+  const int n = getShape().getN();
+  const int k = getShape().getK();
+  const bool hasF16Inputs =
+      getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
+  const bool saturates =
+      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+      NVVM::MMAIntOverflow::satfinite;
+
+  StringRef accumTypeName = stringifyWGMMATypes(getTypeD());
+
+  // An f16 accumulator uses N/4 output registers; every other type uses N/2.
+  const int numOutputRegs =
+      getShape().getN() / (getTypeD() == WGMMATypes::f16 ? 4 : 2);
+
+  std::string ptx;
+  llvm::raw_string_ostream os(ptx);
+
+  os << "{\n"
+        ".reg .pred p;\n"
+        "setp.ne.b32 p, $"
+     << ((numOutputRegs * 2) + 2)
+     << ", 0;\n"
+        "wgmma.mma_async.sync.aligned.m"
+     << m << "n" << n << "k" << k << "." << accumTypeName << "."
+     << stringifyWGMMATypes(getTypeA()) << "."
+     << stringifyWGMMATypes(getTypeB());
+  if (saturates)
+    os << ".satfinite";
+
+  // Comma-separated list of the output register placeholders.
+  os << " {";
+  int reg = 0;
+  while (reg < numOutputRegs) {
+    if (reg != 0)
+      os << ", ";
+    os << "$" << reg;
+    ++reg;
+  }
+  os << "},";
+
+  // Need to map read/write registers correctly: the remaining operands are
+  // numbered starting at twice the count of registers listed above.
+  const int inputBase = reg * 2;
+  os << " $" << inputBase << "," << " $" << (inputBase + 1) << "," << " p";
+  if (getTypeD() != WGMMATypes::s32)
+    os << ", $" << (inputBase + 3) << ",  $" << (inputBase + 4);
+  // Don't add transpose parameters unless needed.
+  if (hasF16Inputs)
+    os << ", $" << (inputBase + 5) << ",  $" << (inputBase + 6);
+  os << ";\n"
+     << "}\n";
+  os.flush();
+  return ptx;
+}
+
+/// Appends the inline-asm operands for this op to `asmValues`, in order: the
+/// results (write) and inouts (read/write) when present, the two descriptors,
+/// the scale-D constant, the scale-A/scale-B constants unless the accumulator
+/// is s32, and the two layout constants when the A type is f16/bf16.
+void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  using RegMod = mlir::NVVM::PTXRegisterMod;
+
+  // Materializes an i32 constant and registers it as a read-only operand.
+  auto pushReadConstant = [&](int value) {
+    asmValues.push_back({makeConstantI32(rewriter, value), RegMod::Read});
+  };
+
+  const bool hasF16Inputs =
+      getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
+
+  if (getResults())
+    asmValues.push_back({getResults(), RegMod::Write});
+  if (getInouts())
+    asmValues.push_back({getInouts(), RegMod::ReadWrite});
+  asmValues.push_back({getDescriptorA(), RegMod::Read});
+  asmValues.push_back({getDescriptorB(), RegMod::Read});
+  pushReadConstant(static_cast<int>(getScaleD()));
+
+  // Input scales are passed as +1 / -1 and are omitted for s32 accumulators.
+  if (getTypeD() != WGMMATypes::s32) {
+    pushReadConstant(getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1);
+    pushReadConstant(getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1);
+  }
+
+  // Layout operands: A is passed as its enum value, B as one minus its enum
+  // value.
+  if (hasF16Inputs) {
+    pushReadConstant(static_cast<int>(getLayoutA()));
+    pushReadConstant(1 - static_cast<int>(getLayoutB()));
+  }
+}
+/// Verifies that the `space` attribute is present exactly when the fence kind
+/// is `async_shared`.
+LogicalResult NVVM::FenceProxyOp::verify() {
+  const bool isAsyncShared = getKind() == NVVM::ProxyKind::async_shared;
+  const bool hasSpace = getSpace().has_value();
+
+  if (isAsyncShared && !hasSpace)
+    return emitOpError() << "async_shared fence requires space attribute";
+  if (!isAsyncShared && hasSpace)
+    return emitOpError() << "only async_shared fence can have space attribute";
+  return success();
+}
+
+/// Verifies that the requested register count is a multiple of 8 and lies in
+/// the inclusive range [24, 256].
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+  const auto regCount = getRegCount();
+  if (regCount % 8 != 0)
+    return emitOpError("new register size must be multiple of 8");
+  const bool inRange = regCount >= 24 && regCount <= 256;
+  if (!inRange)
+    return emitOpError("new register size must be in between 24 to 256");
+  return success();
+}
+
+/// Verifies that a barrier specifying a thread count also specifies a barrier
+/// id.
+LogicalResult NVVM::BarrierOp::verify() {
+  // Nothing to check unless a thread count is given without a barrier id.
+  if (!getNumberOfThreads() || getBarrierId())
+    return success();
+  return emitOpError(
+      "barrier id is missing, it should be set between 0 to 15");
+}
+
+//===----------------------------------------------------------------------===//
+// NVVMDialect initialization, type parsing, and registration.
+//===----------------------------------------------------------------------===//
+
+// TODO: This should be the llvm.nvvm dialect once this is supported.
+/// Registers the dialect's operations and attributes, permits unknown
+/// operations, and declares the promised interfaces listed at the end.
+void NVVMDialect::initialize() {
+  // The operation list is expanded from the GET_OP_LIST section of
+  // NVVMOps.cpp.inc.
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
+      >();
+  // The attribute list is expanded from the GET_ATTRDEF_LIST section of
+  // NVVMOpsAttributes.cpp.inc.
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
+      >();
+
+  // Support unknown operations because not all NVVM operations are
+  // registered.
+  allowUnknownOperations();
+  // Promised interfaces: ConvertToLLVMPatternInterface for the dialect and
+  // gpu::TargetAttrInterface for NVVMTargetAttr.
+  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
+  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
+}
+
+/// Verifies NVVM dialect attributes attached to an arbitrary operation:
+///  - the kernel attribute may only appear on `LLVM::LLVMFuncOp`;
+///  - maxntid / reqntid must be a DenseI32ArrayAttr with 1 to 3 entries;
+///  - minctasm / maxnreg must be an IntegerAttr.
+/// Any other attribute name is accepted.
+LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
+                                                    NamedAttribute attr) {
+  StringAttr attrName = attr.getName();
+
+  // Kernel function attribute should be attached to functions.
+  const bool isKernelAttr = attrName == NVVMDialect::getKernelFuncAttrName();
+  if (isKernelAttr && !isa<LLVM::LLVMFuncOp>(op))
+    return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
+                           << "' attribute attached to unexpected op";
+
+  // If maxntid and reqntid exist, it must be an array with max 3 dim
+  const bool isThreadDimAttr =
+      attrName == NVVMDialect::getMaxntidAttrName() ||
+      attrName == NVVMDialect::getReqntidAttrName();
+  if (isThreadDimAttr) {
+    auto dims = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
+    const bool wellFormed = dims && !dims.empty() && dims.size() <= 3;
+    if (!wellFormed)
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute must be integer array with maximum 3 index";
+  }
+
+  // If minctasm and maxnreg exist, it must be an integer attribute
+  const bool isIntegerLimitAttr =
+      attrName == NVVMDialect::getMinctasmAttrName() ||
+      attrName == NVVMDialect::getMaxnregAttrName();
+  if (isIntegerLimitAttr && !llvm::isa<IntegerAttr>(attr.getValue()))
+    return op->emitError()
+           << "'" << attrName << "' attribute must be integer constant";
+
+  return success();
+}
+
+/// Verifies NVVM dialect attributes attached to region arguments. Only the
+/// grid-constant attribute on function-like ops is checked: the function must
+/// carry the kernel attribute, the value must be a unit attribute, and the
+/// same argument must also carry the LLVM `byval` attribute.
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+                                                    unsigned regionIndex,
+                                                    unsigned argIndex,
+                                                    NamedAttribute argAttr) {
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+
+  StringAttr attrName = argAttr.getName();
+  // Every attribute other than grid-constant is accepted as-is.
+  if (!(attrName == NVVM::NVVMDialect::getGridConstantAttrName()))
+    return success();
+
+  if (!op->hasAttr(NVVMDialect::getKernelFuncAttrName()))
+    return op->emitError()
+           << "'" << attrName
+           << "' attribute must be present only on kernel arguments";
+
+  if (!isa<UnitAttr>(argAttr.getValue()))
+    return op->emitError() << "'" << attrName << "' must be a unit attribute";
+
+  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+    return op->emitError()
+           << "'" << attrName
+           << "' attribute requires the argument to also have attribute '"
+           << LLVM::LLVMDialect::getByValAttrName() << "'";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM target attribute.
+//===----------------------------------------------------------------------===//
+/// Verifies the NVVM target attribute parameters: the optimization level must
+/// be within [0, 3], the triple and chip must be non-empty, and every element
+/// of `files` (when the array is present) must be a non-null StringAttr.
+LogicalResult
+NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                       int optLevel, StringRef triple, StringRef chip,
+                       StringRef features, DictionaryAttr flags,
+                       ArrayAttr files) {
+  const bool optLevelInRange = optLevel >= 0 && optLevel <= 3;
+  if (!optLevelInRange) {
+    emitError() << "The optimization level must be a number between 0 and 3.";
+    return failure();
+  }
+  if (triple.empty()) {
+    emitError() << "The target triple cannot be empty.";
+    return failure();
+  }
+  if (chip.empty()) {
+    emitError() << "The target chip cannot be empty.";
+    return failure();
+  }
+
+  // Reject the array as soon as one entry is null or not a string.
+  auto isNotString = [](::mlir::Attribute entry) {
+    return !entry || !mlir::isa<StringAttr>(entry);
+  };
+  if (files && llvm::any_of(files, isNotString)) {
+    emitError() << "All the elements in the `link` array must be strings.";
+    return failure();
+  }
+  return success();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"

>From 740ba25c18140a87f5311725daf9ae74663dc067 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:08:00 +0800
Subject: [PATCH 4/7] change check cases when ab cannot be transposed in wgmma

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 036a9a15af838..48f44165ccc58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -880,7 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   // Check transpose (only available for f16/bf16)
   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
-       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+       getLayoutB() == mlir::NVVM::MMALayout::row)) {
     return emitOpError()
            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
            << " and layout_b = " << stringifyMMALayout(getLayoutB())

>From b4a2937afd11ee53332f36febed83bc1306c58dc Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:10:39 +0800
Subject: [PATCH 5/7] change some test cases to match the previous check change

---
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e..375e2951a037c 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -397,19 +397,19 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2, 
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
       A [<s8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   return %result3 : !mat16i32
 }
@@ -458,19 +458,19 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
       #nvvm.shape<m = 64, n = 8, k = 32>, 
       D [<s32>, #nvvm.wgmma_scale_out<one>],
       A [<u8>, #nvvm.wgmma_scale_in<one>, <row>], 
-      B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+      B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
       : !mat16i32 -> !mat16i32
   return %result3 : !mat16i32
 }
@@ -500,13 +500,13 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
       #nvvm.shape<m = 64, n = 64, k = 8>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 8>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
@@ -533,13 +533,13 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
@@ -565,13 +565,13 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
        : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
       #nvvm.shape<m = 64, n = 64, k = 32>, 
       D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
       A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>], 
-      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+      B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
       : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }

>From f6404c653d97863b973775460334a7f9e2b651ca Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 21:16:37 +0800
Subject: [PATCH 6/7] add comment

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 48f44165ccc58..ceeb1168eb13e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -878,6 +878,9 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   }
 
   // Check transpose (only available for f16/bf16)
+  // Matrices A should be stored in row-major and B in column-major.
+  // Only f16/bf16 matrices can be stored in either column-major or row-major 
+  // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
        getLayoutB() == mlir::NVVM::MMALayout::row)) {

>From 319ec0edf21b483379b3490f8b2129b8555c6719 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 22:33:28 +0800
Subject: [PATCH 7/7] format code

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ceeb1168eb13e..4d1896551101e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -879,7 +879,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
 
   // Check transpose (only available for f16/bf16)
   // Matrices A should be stored in row-major and B in column-major.
-  // Only f16/bf16 matrices can be stored in either column-major or row-major 
+  // Only f16/bf16 matrices can be stored in either column-major or row-major
   // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code.
   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||



More information about the Mlir-commits mailing list