[Mlir-commits] [mlir] fix a bug when checking layout that cannot be transposed (PR #97538)
bangyu shen
llvmlistbot at llvm.org
Wed Jul 3 02:08:10 PDT 2024
https://github.com/shubaoyu2 updated https://github.com/llvm/llvm-project/pull/97538
>From f2f072920b322857ed3ad623aac30cfbc0e27265 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 16:36:56 +0800
Subject: [PATCH 1/4] fix a bug when checking layout that cannot be transposed
The WGMMA instruction expects row-major A and col-major B by default, so the transposed layouts are col-major A and row-major B. Since data types other than f16/bf16 cannot use a transposed layout, the verifier must reject col-major for A and row-major for B (not col-major for B, as it did before).
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 036a9a15af838..48f44165ccc58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -880,7 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
// Check transpose (only available for f16/bf16)
if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
(getLayoutA() == mlir::NVVM::MMALayout::col ||
- getLayoutB() == mlir::NVVM::MMALayout::col)) {
+ getLayoutB() == mlir::NVVM::MMALayout::row)) {
return emitOpError()
<< "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
<< " and layout_b = " << stringifyMMALayout(getLayoutB())
>From aec1049febb1893a5cdb57ff7c2981001dbe355d Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:01:35 +0800
Subject: [PATCH 2/4] ch
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 1834 ++++++++------------
1 file changed, 693 insertions(+), 1141 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 48f44165ccc58..375e2951a037c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1,1142 +1,694 @@
-//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the types and operation details for the NVVM IR dialect in
-// MLIR, and the LLVM IR dialect. It also registers the dialect.
-//
-// The NVVM dialect only contains GPU specific additions on top of the general
-// LLVM dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-
-#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
-#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Types.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/AsmParser/Parser.h"
-#include "llvm/IR/Attributes.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/Type.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include <cassert>
-#include <optional>
-#include <string>
-
-using namespace mlir;
-using namespace NVVM;
-
-#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
-#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-// Printing/parsing for NVVM ops
-//===----------------------------------------------------------------------===//
-
-static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
- p << " " << op->getOperands();
- if (op->getNumResults() > 0)
- p << " : " << op->getResultTypes();
+// RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s
+
+// Same below, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @init_mbarrier
+llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b"
+ nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
+ nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
+ llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
+ nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
+ nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r"
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1
+ llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait_shared
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{
+ // CHECK-SAME: .reg .pred P1;
+ // CHECK-SAME: LAB_WAIT:
+ // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
+ // CHECK-SAME: @P1 bra.uni DONE;
+ // CHECK-SAME: bra.uni LAB_WAIT;
+ // CHECK-SAME: DONE:
+ // CHECK-SAME: }",
+ // CHECK-SAME: "r,r,r"
+ nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+ llvm.return
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{
+ // CHECK-SAME: .reg .pred P1;
+ // CHECK-SAME: LAB_WAIT:
+ // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
+ // CHECK-SAME: @P1 bra.uni DONE;
+ // CHECK-SAME: bra.uni LAB_WAIT;
+ // CHECK-SAME: DONE:
+ // CHECK-SAME: }",
+ // CHECK-SAME: "l,r,r"
+ nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
+ llvm.return
+}
+
+// CHECK-LABEL: @async_cp
+func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
+ // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
+ nvvm.cp.async.shared.global %dst, %src, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
+ // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
+ nvvm.cp.async.shared.global %dst, %src, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
+ return
+}
+
+// CHECK-LABEL: @async_cp_zfill
+func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A",
+ // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+ nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A",
+ // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+ nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+ return
+}
+
+// CHECK-LABEL: @cp_async_mbarrier_arrive
+func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
+ // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}}
+ nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
+ // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true}
+ nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
+ // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}}
+ nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3>
+ // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} {noinc = true}
+ nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3>
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_load_3d_all
+func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_4d_all
+func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_5d_all
+func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_1d
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_2d
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_3d
+func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_4d
+func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_5d
+func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast1d
+func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast2d
+func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast3d
+func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast4d
+func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask: !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_load_multicast5d
+func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
+ return
+}
+
+// CHECK-LABEL: @tma_store_1d
+func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_store_2d
+func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_store_3d
+func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_store_4d
+func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @tma_store_5d
+func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
+ // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
+
+ // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
+ return
+}
+
+// CHECK-LABEL: @wgmma_execute
+func.func @wgmma_execute() {
+ nvvm.wgmma.fence.aligned
+ nvvm.wgmma.commit.group.sync.aligned
+ nvvm.wgmma.wait.group.sync.aligned 0
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
+ // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
+
+
+ nvvm.wgmma.fence.aligned
+ nvvm.wgmma.commit.group.sync.aligned
+ nvvm.wgmma.wait.group.sync.aligned 5
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
+ // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
+ return
+}
+
+
+// -----
+
+!mat64f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @wgmma_f32_f16_f16(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
+ // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+ // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{
+  // CHECK-SAME: .reg .pred p;
+ // CHECK-SAME: setp.ne.b32 p, $34, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16
+ // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35, $36, $37, $38;\0A}\0A",
+ // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n"
+ // CHECK-SAME: %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]]
+ // CHECK-SAME: : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32)
+ // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[DESCa:.+]] = llvm.add %[[ARG0]], %[[C2]] : i64
+ // CHECK: %[[DESCb:.+]] = llvm.add %[[ARG1]], %[[C2]] : i64
+ // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{
+ // CHECK-SAME: .reg .pred p;
+ // CHECK-SAME: setp.ne.b32 p, $34, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16
+ // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35, $36, $37, $38;\0A}\0A",
+ // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n"
+ // CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+ %result = llvm.mlir.undef : !mat64f32
+ %result1 = nvvm.wgmma.mma_async
+ %descA, %descB, %result,
+ #nvvm.shape<m = 64, n = 32, k = 16>,
+ D [<f32>, #nvvm.wgmma_scale_out<zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ :!mat64f32 -> !mat64f32
+ %c2 = arith.constant 2 : i64
+ %descAnext = arith.addi %descA, %c2 : i64
+ %descBnext = arith.addi %descB, %c2 : i64
+ %result2 = nvvm.wgmma.mma_async
+ %descAnext, %descBnext, %result1,
+ #nvvm.shape<m = 64, n = 32, k = 16>,
+ D [<f32>, #nvvm.wgmma_scale_out<zero>],
+ A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
+ B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
+ : !mat64f32 -> !mat64f32
+ return %result2 : !mat64f32
+}
+
+// -----
+
+!mat16i32 = !llvm.struct<(i32, i32, i32, i32)>
+
+// Three chained m64n8k32 wgmma ops computing s32 += s8 x s8 with the
+// satfinite modifier. Each op is expected to lower to one inline-asm
+// `wgmma.mma_async ... .satfinite` block. Its "0,1,2,3" constraints tie the
+// four i32 inputs (extracted from the previous result struct) to the four
+// outputs, so the accumulator threads through the chain. The trailing `n`
+// operand ($10) is the i32 constant 1 that sets predicate p; presumably this
+// comes from wgmma_scale_out<one> (TODO confirm).
+// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite(
+// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
+ %result = llvm.mlir.undef : !mat16i32
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK-SAME: "{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
+// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n"
+// CHECK-SAME: %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] :
+// CHECK-SAME: (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK-SAME: "{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
+// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A",
+// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n"
+// CHECK-SAME: %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK-SAME:"{
+// CHECK-SAME:.reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
+// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n"
+// CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
+ // Each op consumes the previous op's result as its accumulator.
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
+ : !mat16i32 -> !mat16i32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
+ : !mat16i32 -> !mat16i32
+ %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
+ A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
+ : !mat16i32 -> !mat16i32
+ return %result3 : !mat16i32
+}
+
+// Same chain of three m64n8k32 ops as the s8 test, but u8 x u8 without the
+// satfinite modifier. The emitted PTX is expected to be
+// `wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8` with no suffix.
+// CHECK-LABEL: @wgmma_s32_u8_u8(
+ // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
+func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
+// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
+// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
+// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
+// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
+// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
+// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK-SAME: "{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
+// CHECK-SAME: }\0A",
+// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] :
+// CHECK-SAME:(i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
+// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
+// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
+// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
+// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK-SAME:"{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
+// CHECK-SAME: }\0A",
+// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
+// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
+// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
+// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
+// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
+// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
+// CHECK-SAME:"{
+// CHECK-SAME: .reg .pred p;
+// CHECK-SAME: setp.ne.b32 p, $10, 0;
+// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
+// CHECK-SAME:}\0A",
+// CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
+ %result = llvm.mlir.undef : !mat16i32
+ // Each op consumes the previous op's result as its accumulator.
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [<s32>, #nvvm.wgmma_scale_out<one>],
+ A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
+ : !mat16i32 -> !mat16i32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [<s32>, #nvvm.wgmma_scale_out<one>],
+ A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
+ : !mat16i32 -> !mat16i32
+ %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
+ #nvvm.shape<m = 64, n = 8, k = 32>,
+ D [<s32>, #nvvm.wgmma_scale_out<one>],
+ A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
+ B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
+ : !mat16i32 -> !mat16i32
+ return %result3 : !mat16i32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// Two chained m64n64k8 wgmma ops computing f32 += tf32 x tf32. The second op
+// takes the first op's result as its accumulator.
+// Each op lowers to one inline-asm `wgmma.mma_async` with:
+//   - 32 `=f` outputs tied to 32 inputs;
+//   - two `l` descriptors ($64, $65);
+//   - the `n` operand $66 that sets predicate p;
+//   - two more `n` immediates ($67, $68), presumably the A/B input scales
+//     (TODO confirm).
+// Only f16/bf16 may use transposed layouts, so A is row-major and B is
+// column-major here.
+// CHECK-LABEL: @wgmma_f32_tf32_tf32
+func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
+ // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME:"{
+ // CHECK-SAME: .reg .pred p;
+ // CHECK-SAME: setp.ne.b32 p, $66, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{
+ // CHECK-SAME: .reg .pred p;
+ // CHECK-SAME: setp.ne.b32 p, $66, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ %result = llvm.mlir.undef : !mat32f32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+ #nvvm.shape<m = 64, n = 64, k = 8>,
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+ : !mat32f32 -> !mat32f32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+ #nvvm.shape<m = 64, n = 64, k = 8>,
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+ : !mat32f32 -> !mat32f32
+ return %result2 : !mat32f32
+}
+
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// Two chained m64n64k32 wgmma ops computing f32 += e4m3 x e4m3. The operand
+// layout of the emitted inline asm matches the tf32 lowering: 32 tied f32
+// registers, two descriptors, the predicate source $66, and two trailing
+// immediates. As a non-f16/bf16 type, A must be row-major and B
+// column-major.
+// CHECK-LABEL: @wgmma_f32_e4m3_e4m3
+func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
+ // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ %result = llvm.mlir.undef : !mat32f32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+ : !mat32f32 -> !mat32f32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+ : !mat32f32 -> !mat32f32
+ return %result2 : !mat32f32
+}
+
+// -----
+
+!mat32f32 = !llvm.struct<(
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32)>
+
+// Mixed-type variant: A is e5m2 and B is e4m3. The PTX type suffix is
+// expected to list them in that order (`.f32.e5m2.e4m3`). As with the other
+// non-f16/bf16 types, only A row-major / B column-major is used.
+// CHECK-LABEL: @wgmma_f32_e5m2_e4m3
+func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
+ // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
+ // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
+ %result = llvm.mlir.undef : !mat32f32
+ %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+ : !mat32f32 -> !mat32f32
+ %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
+ #nvvm.shape<m = 64, n = 64, k = 32>,
+ D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
+ A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
+ B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
+ : !mat32f32 -> !mat32f32
+ return %result2 : !mat32f32
+}
+
+// -----
+
+// nvvm.elect.sync lowers to inline asm that first clears the i1 result
+// ($0 = 0). It then runs `elect.sync` with a full 0xFFFFFFFF member mask and
+// sets $0 = 1 only under predicate px.
+// The label anchors these directives to this function, consistent with the
+// other tests in this file.
+// CHECK-LABEL: @elect_one_leader_sync
+func.func @elect_one_leader_sync() {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
+  // CHECK-SAME: .reg .u32 rx;
+  // CHECK-SAME: .reg .pred px;
+  // CHECK-SAME: mov.pred $0, 0;
+  // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
+  // CHECK-SAME: @px mov.pred $0, 1;
+  // CHECK-SAME: "=b" : () -> i1
+  %cnd = nvvm.elect.sync -> i1
+  return
+}
+
+// -----
+
+// nvvm.stmatrix with 1, 2 and 4 i32 source registers, in both layouts. Each
+// op lowers to inline asm `stmatrix.sync.aligned.x{1,2,4}.m8n8.shared.b16`.
+// The `<col>` layout adds the `.trans` qualifier; `<row>` does not.
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+ // Row layout: no `.trans` in the emitted instruction.
+ nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+ nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+ nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+ // Col layout: `.trans` variant.
+ nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+ nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+ nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+ llvm.return
+}
+
+// -----
+
+// nvvm.prefetch.tensormap lowers to `prefetch.tensormap [$0]`. With the
+// optional i1 predicate, the instruction is guarded as `@$1 ...`.
+// (Previously named @init_mbarrier_arrive_expect_tx, which did not describe
+// the op under test.)
+// CHECK-LABEL: @prefetch_tensormap
+llvm.func @prefetch_tensormap(%desc : !llvm.ptr, %pred : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
+  nvvm.prefetch.tensormap %desc : !llvm.ptr
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
+  nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
+  llvm.return
+}
+
+// -----
+
+// nvvm.setmaxregister is expected to pass through unchanged for both the
+// increase and decrease forms.
+// CHECK-LABEL: @set_max_register
+func.func @set_max_register() {
+  // CHECK: nvvm.setmaxregister increase 232
+  nvvm.setmaxregister increase 232
+
+  // CHECK: nvvm.setmaxregister decrease 40
+  nvvm.setmaxregister decrease 40
+  func.return
+}
+
+// -----
+
+// nvvm.cp.async.bulk.commit.group is expected to pass through unchanged.
+// CHECK-LABEL: @cp_async_bulk_commit
+func.func @cp_async_bulk_commit() {
+  // CHECK: nvvm.cp.async.bulk.commit.group
+  nvvm.cp.async.bulk.commit.group
+  func.return
+}
+
+// -----
+
+// nvvm.cp.async.bulk.wait_group is expected to pass through unchanged, with
+// and without the {read} unit attribute.
+// CHECK-LABEL: @cp_async_bulk_wait_group
+func.func @cp_async_bulk_wait_group() {
+  // CHECK: nvvm.cp.async.bulk.wait_group 1
+  // CHECK: nvvm.cp.async.bulk.wait_group 0
+  // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+  // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
+  nvvm.cp.async.bulk.wait_group 1
+  nvvm.cp.async.bulk.wait_group 0
+  nvvm.cp.async.bulk.wait_group 5 {read}
+  nvvm.cp.async.bulk.wait_group 0 {read}
+  func.return
+}
+
+// -----
+
+// nvvm.fence.mbarrier.init lowers to `fence.mbarrier_init.release.cluster;`.
+// CHECK-LABEL: @fence_mbarrier_init
+func.func @fence_mbarrier_init() {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;"
+  nvvm.fence.mbarrier.init
+  func.return
+}
+
+// -----
+
+// nvvm.fence.proxy: each proxy kind (and, for async.shared, each shared
+// space) maps to the matching `fence.proxy.*` PTX string.
+// CHECK-LABEL: @fence_proxy
+func.func @fence_proxy() {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", "" : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>}
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", "" : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>}
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", "" : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>}
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", "" : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>}
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", "" : () -> ()
+  nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
+  func.return
+}
+
+// -----
+
+// nvvm.barrier.arrive without an `id` lowers to `bar.arrive 0, $0`: barrier 0
+// is hard-coded and only the thread count is an operand. With an `id`, both
+// the id and the thread count become operands of `bar.arrive $0, $1`.
+// CHECK-LABEL: @llvm_nvvm_barrier_arrive
+// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
+llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0, $0;", "r" %[[numberOfThreads]] : (i32) -> ()
+ nvvm.barrier.arrive number_of_threads = %numberOfThreads
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
+ nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
+ llvm.return
 }
-
-// Custom syntax:
-//   `llvm.nvvm.vote.ballot.sync %mask, %pred` attr-dict `:` result_type
-// The first operand is resolved as i32 and the second as i1; the trailing
-// type is the op's result type.
-ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
-  MLIRContext *ctx = parser.getContext();
-  Type maskType = IntegerType::get(ctx, 32);
-  Type predicateType = IntegerType::get(ctx, 1);
-
-  SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
-  Type resultType;
-  if (parser.parseOperandList(operands) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(resultType) ||
-      parser.addTypeToList(resultType, result.types))
-    return failure();
-
-  return parser.resolveOperands(operands, {maskType, predicateType},
-                                parser.getNameLoc(), result.operands);
-}
-
-// Printing is delegated to the shared printNVVMIntrinsicOp helper.
-void VoteBallotOp::print(OpAsmPrinter &p) {
-  printNVVMIntrinsicOp(p, *this);
-}
-
-// Requires 1..5 coordinates. When im2col offsets are present (im2col mode),
-// the tensor must be at least 3-D and there must be exactly two fewer
-// offsets than coordinates.
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  const size_t numCoords = getCoordinates().size();
-  if (numCoords < 1 || numCoords > 5)
-    return emitError("expects coordinates between 1 to 5 dimension");
-
-  const size_t numIm2colOffsets = getIm2colOffsets().size();
-  if (numIm2colOffsets == 0)
-    return success(); // Tile mode: nothing more to check.
-
-  if (numCoords < 3)
-    return emitError(
-        "to use im2col mode, the tensor has to be at least 3-dimensional");
-  if (numIm2colOffsets + 2 != numCoords)
-    return emitError(
-        "im2col offsets must be 2 less than number of coordinates");
-  return success();
-}
-
-// Rejects more than five coordinates.
-LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
-  constexpr size_t kMaxCoordinates = 5;
-  if (getCoordinates().size() <= kMaxCoordinates)
-    return success();
-  return emitError("Maximum 5 coordinates and dimension is supported.");
-}
-
-// Accepts only the CG and CA cache modifiers and copy sizes of 4, 8 or 16
-// bytes; CG additionally requires a 16-byte copy.
-LogicalResult CpAsyncOp::verify() {
-  const auto modifier = getModifier();
-  const bool isCG = modifier == LoadCacheModifierKind::CG;
-  if (!isCG && modifier != LoadCacheModifierKind::CA)
-    return emitError("Only CG and CA cache modifiers are supported.");
-
-  const auto size = getSize();
-  const bool supportedSize = size == 4 || size == 8 || size == 16;
-  if (!supportedSize)
-    return emitError("expected byte size to be either 4, 8 or 16.");
-
-  if (isCG && size != 16)
-    return emitError("CG cache modifier is only support for 16 bytes copy.");
-  return success();
-}
-
-// Maps the element type of an MMA operand to the PTX type
-// (`NVVM::MMATypes`) it denotes, or std::nullopt if it cannot be inferred.
-// f32 and integer types are ambiguous on their own:
-//   - f32 means f32 for accumulators and tf32 for multiplicands;
-//   - integers mean s32 for accumulators and are not inferrable otherwise.
-// LLVM struct types are classified by their first member.
-std::optional<mlir::NVVM::MMATypes>
-MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
-  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
-    ArrayRef<Type> members = structType.getBody();
-    if (members.empty())
-      return std::nullopt;
-    return inferOperandMMAType(members.front(), isAccumulator);
-  }
-
-  if (operandElType.isF64())
-    return NVVM::MMATypes::f64;
-
-  Type f16x2Type =
-      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
-  if (operandElType.isF16() || operandElType == f16x2Type)
-    return NVVM::MMATypes::f16;
-
-  if (operandElType.isF32())
-    return isAccumulator ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32;
-
-  if (llvm::isa<IntegerType>(operandElType) && isAccumulator)
-    return NVVM::MMATypes::s32;
-
-  return std::nullopt;
-}
-
-// True for the 4-bit integer PTX types (u4, s4).
-static bool isInt4PtxType(MMATypes type) {
-  switch (type) {
-  case MMATypes::u4:
-  case MMATypes::s4:
-    return true;
-  default:
-    return false;
-  }
-}
-
-// True for the 8-bit integer PTX types (u8, s8).
-static bool isInt8PtxType(MMATypes type) {
-  switch (type) {
-  case MMATypes::u8:
-  case MMATypes::s8:
-    return true;
-  default:
-    return false;
-  }
-}
-
-// True for every integer-like PTX type: b1, s32, and the 4/8-bit types.
-static bool isIntegerPtxType(MMATypes type) {
-  if (type == MMATypes::b1 || type == MMATypes::s32)
-    return true;
-  return isInt8PtxType(type) || isInt4PtxType(type);
-}
-
-// PTX type of the accumulator, inferred from the first type of ODS operand
-// group 2 with the accumulator interpretation.
-MMATypes MmaOp::accumPtxType() {
-  Type accumulatorType = getODSOperands(2).getTypes().front();
-  std::optional<mlir::NVVM::MMATypes> inferred =
-      inferOperandMMAType(accumulatorType, /*isAccum=*/true);
-  assert(inferred.has_value() &&
-         "accumulator PTX type should always be inferrable");
-  return inferred.value();
-}
-
-// PTX type of the op's result, inferred with the accumulator interpretation.
-MMATypes MmaOp::resultPtxType() {
-  Type resultType = getResult().getType();
-  std::optional<mlir::NVVM::MMATypes> inferred =
-      inferOperandMMAType(resultType, /*isAccum=*/true);
-  assert(inferred.has_value() && "result PTX type should always be inferrable");
-  return inferred.value();
-}
-
-// Prints the custom form
-//   A[...] B[...] C[...] attr-dict : (typeA, typeB, typeC) -> resultType
-// A multiplicand PTX type attribute is elided only when inferring from that
-// fragment's own register type (as a multiplicand) reproduces the value
-// currently stored on the op.
-void MmaOp::print(OpAsmPrinter &p) {
-  struct OperandFragment {
-    StringRef operandName;
-    StringRef ptxTypeAttr;
-    std::optional<MMATypes> explicitPtxType;
-    SmallVector<Value, 4> regs;
-    explicit OperandFragment(StringRef name, StringRef ptxTypeName,
-                             std::optional<MMATypes> ptxType)
-        : operandName(name), ptxTypeAttr(ptxTypeName),
-          explicitPtxType(ptxType) {}
-  };
-
-  std::array<OperandFragment, 3> frags{
-      OperandFragment("A", getMultiplicandAPtxTypeAttrName(),
-                      getMultiplicandAPtxType()),
-      OperandFragment("B", getMultiplicandBPtxTypeAttrName(),
-                      getMultiplicandBPtxType()),
-      OperandFragment("C", "", std::nullopt)};
-  SmallVector<StringRef, 4> ignoreAttrNames{
-      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
-
-  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
-    auto &frag = frags[fragIdx];
-    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
-    for (auto operandIdx = varOperandSpec.first;
-         operandIdx < varOperandSpec.first + varOperandSpec.second;
-         operandIdx++)
-      frag.regs.push_back(this->getOperand(operandIdx));
-
-    // Only the multiplicands (A, B) carry a PTX type attribute.
-    if (fragIdx >= 2 || frag.regs.empty())
-      continue;
-    // Infer from this fragment's own register type (previously operand #0's
-    // type was reused for every fragment). Elide the attribute only when
-    // inference reproduces its current value; otherwise e.g. an explicit bf16
-    // on f16x2 registers would be dropped and re-parsed as f16.
-    std::optional<MMATypes> inferredType = inferOperandMMAType(
-        frag.regs.front().getType(), /*isAccum=*/false);
-    if (inferredType && inferredType == frag.explicitPtxType)
-      ignoreAttrNames.push_back(frag.ptxTypeAttr);
-  }
-
-  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
-    p << " " << frag.operandName;
-    p << "[";
-    p.printOperands(frag.regs);
-    p << "] ";
-  };
-
-  for (const auto &frag : frags)
-    printMmaOperand(frag);
-
-  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
-
-  // Print the types of the operands and result.
-  p << " : " << "(";
-  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
-                                             frags[1].regs[0].getType(),
-                                             frags[2].regs[0].getType()},
-                        p);
-  p << ")";
-  p.printArrowTypeList(TypeRange{this->getRes().getType()});
-}
-
-// Builds an MmaOp from its three operand segments and an (m, n, k) shape.
-// Multiplicand PTX types default to whatever can be inferred from the first
-// register of A and B. Layouts default to row-major A / column-major B.
-// Optional b1Op / intOverflow attributes are only attached when provided.
-void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
-                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
-                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
-                  std::optional<MMAIntOverflow> intOverflow,
-                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
-                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
-  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
-  MLIRContext *ctx = builder.getContext();
-  result.addAttribute(
-      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
-
-  for (ValueRange segment : {operandA, operandB, operandC})
-    result.addOperands(segment);
-
-  std::optional<MMATypes> ptxTypeA;
-  std::optional<MMATypes> ptxTypeB;
-  if (multiplicandPtxTypes) {
-    ptxTypeA = (*multiplicandPtxTypes)[0];
-    ptxTypeB = (*multiplicandPtxTypes)[1];
-  } else {
-    ptxTypeA = inferOperandMMAType(operandA[0].getType(), false);
-    ptxTypeB = inferOperandMMAType(operandB[0].getType(), false);
-  }
-  if (ptxTypeA)
-    result.addAttribute("multiplicandAPtxType",
-                        MMATypesAttr::get(ctx, *ptxTypeA));
-  if (ptxTypeB)
-    result.addAttribute("multiplicandBPtxType",
-                        MMATypesAttr::get(ctx, *ptxTypeB));
-
-  const std::array<MMALayout, 2> layouts = multiplicandLayouts.value_or(
-      std::array<MMALayout, 2>{MMALayout::row, MMALayout::col});
-  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, layouts[0]));
-  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, layouts[1]));
-
-  if (intOverflow)
-    result.addAttribute("intOverflowBehavior",
-                        MMAIntOverflowAttr::get(ctx, *intOverflow));
-  if (b1Op)
-    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
-
-  result.addTypes(resultType);
-  result.addAttribute(
-      MmaOp::getOperandSegmentSizeAttr(),
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
-                                    static_cast<int32_t>(operandB.size()),
-                                    static_cast<int32_t>(operandC.size())}));
-}
-
-// <operation> :=
-//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
-//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
-//   `->` type($res)
-//
-// `multiplicandAPtxType` / `multiplicandBPtxType` may be omitted when they
-// can be inferred from the corresponding segment type (as a multiplicand);
-// otherwise parsing fails with an error.
-ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
-  struct OperandFragment {
-    std::optional<MMATypes> elemtype;
-    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
-    SmallVector<Type> regTypes;
-  };
-
-  Builder &builder = parser.getBuilder();
-  // frags[0..2] are the A, B and C operand segments; frags[3] only records
-  // the PTX type inferred from the result type.
-  std::array<OperandFragment, 4> frags;
-
-  NamedAttrList namedAttributes;
-
-  // A helper to parse the operand segments.
-  auto parseMmaOperand = [&](StringRef operandName,
-                             OperandFragment &frag) -> LogicalResult {
-    if (parser.parseKeyword(operandName).failed())
-      return failure();
-    if (parser
-            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
-            .failed())
-      return failure();
-    return success();
-  };
-
-  // Parse the operand segments.
-  if (parseMmaOperand("A", frags[0]).failed())
-    return failure();
-  if (parseMmaOperand("B", frags[1]).failed())
-    return failure();
-  if (parseMmaOperand("C", frags[2]).failed())
-    return failure();
-
-  if (parser.parseOptionalAttrDict(namedAttributes).failed())
-    return failure();
-
-  // Parse the type specification and resolve operands.
-  SmallVector<Type, 3> operandTypes;
-  if (failed(parser.parseColon()))
-    return failure();
-  if (failed(parser.parseLParen()))
-    return failure();
-  if (failed(parser.parseTypeList(operandTypes)))
-    return failure();
-  if (failed(parser.parseRParen()))
-    return failure();
-  // Exactly one type per segment. Previously this check only ran when `)`
-  // failed to parse, so e.g. four types reached frags[3] below and indexed
-  // its empty regTypes.
-  if (operandTypes.size() != 3)
-    return parser.emitError(
-        parser.getNameLoc(),
-        "expected one type for each operand segment but got " +
-            Twine(operandTypes.size()) + " types");
-  for (const auto &iter : llvm::enumerate(operandTypes)) {
-    auto &frag = frags[iter.index()];
-    frag.regTypes.resize(frag.regs.size(), iter.value());
-    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
-                                      parser.getNameLoc(), result.operands)))
-      return failure();
-    // Only segment C (index 2) is the accumulator; A and B are
-    // multiplicands. Infer from the declared segment type so that an empty
-    // segment does not index an empty regTypes vector.
-    frag.elemtype =
-        inferOperandMMAType(iter.value(), /*isAccum=*/iter.index() >= 2);
-  }
-
-  Type resultType;
-  if (parser.parseArrow() || parser.parseType(resultType))
-    return failure();
-  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
-
-  std::array<StringRef, 2> names{"multiplicandAPtxType",
-                                 "multiplicandBPtxType"};
-  for (unsigned idx = 0; idx < names.size(); idx++) {
-    const auto &frag = frags[idx];
-    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
-    if (!frag.elemtype.has_value() && !attr.has_value()) {
-      return parser.emitError(
-          parser.getNameLoc(),
-          "attribute " + names[idx] +
-              " is not provided explicitly and cannot be inferred");
-    }
-    if (!attr.has_value())
-      result.addAttribute(
-          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
-  }
-
-  result.addTypes(resultType);
-  if (!namedAttributes.empty())
-    result.addAttributes(namedAttributes);
-  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr({
-                          static_cast<int32_t>(frags[0].regs.size()),
-                          static_cast<int32_t>(frags[1].regs.size()),
-                          static_cast<int32_t>(frags[2].regs.size()),
-                      }));
-  return success();
-}
-
-LogicalResult MmaOp::verify() {
- MLIRContext *context = getContext();
- auto f16Ty = Float16Type::get(context);
- auto i32Ty = IntegerType::get(context, 32);
- auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
- auto f32Ty = Float32Type::get(context);
- auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
- context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
-
- auto s32x4StructTy =
- LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
- auto f32x8StructTy =
- LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
- auto f16x2x2StructTy =
- LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
- auto f32x4StructTy =
- LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
- auto s32x2StructTy =
- LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
-
- std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
- getShapeAttr().getK()};
-
- // These variables define the set of allowed data types for matrices A, B, C,
- // and result.
- using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
- using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
- AllowedShapes allowedShapes;
- AllowedTypes expectedA;
- AllowedTypes expectedB;
- AllowedTypes expectedC;
- SmallVector<Type> expectedResult;
-
- // When M = 16, we just need to calculate the number of 8xk tiles, where
- // k is a factor that depends on the data type.
- if (mmaShape[0] == 16) {
- int64_t kFactor;
- Type multiplicandFragType;
- switch (*getMultiplicandAPtxType()) {
- case MMATypes::tf32:
- kFactor = 4;
- multiplicandFragType = i32Ty;
- expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
- context, {f32Ty, f32Ty, f32Ty, f32Ty}));
- break;
- case MMATypes::f16:
- case MMATypes::bf16:
- kFactor = 8;
- multiplicandFragType = f16x2Ty;
- expectedResult.push_back(f16x2x2StructTy);
- expectedResult.push_back(f32x4StructTy);
- break;
- case MMATypes::s4:
- case MMATypes::u4:
- kFactor = 32;
- break;
- case MMATypes::b1:
- kFactor = 128;
- break;
- case MMATypes::s8:
- case MMATypes::u8:
- kFactor = 16;
- break;
- default:
- return emitError("invalid shape or multiplicand type: " +
- stringifyEnum(getMultiplicandAPtxType().value()));
- }
-
- if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
- expectedResult.push_back(s32x4StructTy);
- expectedC.emplace_back(4, i32Ty);
- multiplicandFragType = i32Ty;
- } else {
- expectedC.emplace_back(2, f16x2Ty);
- expectedC.emplace_back(4, f32Ty);
- }
-
- int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
- int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
- expectedA.emplace_back(unitA, multiplicandFragType);
- expectedB.emplace_back(unitB, multiplicandFragType);
- allowedShapes.push_back({16, 8, kFactor});
- allowedShapes.push_back({16, 8, kFactor * 2});
- }
-
- // In the M=8 case, there is only 1 possible case per data type.
- if (mmaShape[0] == 8) {
- if (*getMultiplicandAPtxType() == MMATypes::f16) {
- expectedA.emplace_back(2, f16x2Ty);
- expectedB.emplace_back(2, f16x2Ty);
- expectedResult.push_back(f16x2x4StructTy);
- expectedResult.push_back(f32x8StructTy);
- expectedC.emplace_back(4, f16x2Ty);
- expectedC.emplace_back(8, f32Ty);
- allowedShapes.push_back({8, 8, 4});
- }
- if (*getMultiplicandAPtxType() == MMATypes::f64) {
- Type f64Ty = Float64Type::get(context);
- expectedA.emplace_back(1, f64Ty);
- expectedB.emplace_back(1, f64Ty);
- expectedC.emplace_back(2, f64Ty);
- // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
- expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
- context, SmallVector<Type>(2, f64Ty)));
- allowedShapes.push_back({8, 8, 4});
- }
- if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
- expectedA.push_back({i32Ty});
- expectedB.push_back({i32Ty});
- expectedC.push_back({i32Ty, i32Ty});
- expectedResult.push_back(s32x2StructTy);
- if (isInt4PtxType(getMultiplicandAPtxType().value()))
- allowedShapes.push_back({8, 8, 32});
- if (isInt8PtxType(getMultiplicandAPtxType().value()))
- allowedShapes.push_back({8, 8, 16});
- if (getMultiplicandAPtxType().value() == MMATypes::b1)
- allowedShapes.push_back({8, 8, 128});
- }
- }
-
- std::string errorMessage;
- llvm::raw_string_ostream errorStream(errorMessage);
-
- // Check that we matched an existing shape/dtype combination.
- if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
- !llvm::is_contained(allowedShapes, mmaShape)) {
- errorStream << "unimplemented variant for MMA shape <";
- llvm::interleaveComma(mmaShape, errorStream);
- errorStream << ">";
- return emitOpError(errorMessage);
- }
-
- // Verify the operand types for segments of A, B, and C operands.
- std::array<StringRef, 3> operandNames{"A", "B", "C"};
- for (const auto &iter : llvm::enumerate(
- SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
- auto spec = this->getODSOperandIndexAndLength(iter.index());
- SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
- operand_type_begin() + spec.first +
- spec.second);
- bool match = llvm::is_contained(iter.value(), operandTySeg);
-
- if (!match) {
- errorStream << "Could not match types for the "
- << operandNames[iter.index()]
- << " operands; expected one of ";
- for (const auto &x : iter.value()) {
- errorStream << x.size() << "x" << x[0] << " ";
- }
- errorStream << "but got ";
- llvm::interleaveComma(operandTySeg, errorStream);
- return emitOpError(errorStream.str());
- }
- }
-
- // Check the result type
- if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
- return expectedResultType == getResult().getType();
- })) {
- errorStream
- << "Could not match allowed types for the result; expected one of ";
- llvm::interleaveComma(expectedResult, errorStream);
- errorStream << " but got " << getResult().getType();
- return emitOpError(errorStream.str());
- }
-
- // Ensure that binary MMA variants have a b1 MMA operation defined.
- if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
- return emitOpError("op requires " + getB1OpAttrName().strref() +
- " attribute");
- }
-
- // Ensure int4/int8 MMA variants specify the accum overflow behavior
- // attribute.
- if (isInt4PtxType(*getMultiplicandAPtxType()) ||
- isInt8PtxType(*getMultiplicandAPtxType())) {
- if (!getIntOverflowBehavior())
- return emitOpError("op requires " +
- getIntOverflowBehaviorAttrName().strref() +
- " attribute");
- }
-
- return success();
-}
-
-/// Verifies the optional "return_value_and_is_valid" form of the shuffle:
-/// when that unit attribute is present, the result must be a two-element LLVM
-/// struct whose second member is an i1.
-LogicalResult ShflOp::verify() {
-  // Without the attribute there is nothing further to verify.
-  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
-    return success();
-
-  IntegerType validFlagTy;
-  if (auto structTy = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
-      structTy && structTy.getBody().size() == 2)
-    validFlagTy = llvm::dyn_cast<IntegerType>(structTy.getBody()[1]);
-
-  if (validFlagTy && validFlagTy.getWidth() == 1)
-    return success();
-  return emitError("expected return type to be a two-element struct with "
-                   "i1 as the second element");
-}
-
-/// Returns the LLVM element type and element count of a WMMA fragment for
-/// PTX element type `type` and fragment kind `frag`. `nRow`/`nCol` are the
-/// fragment dimensions; they only influence the s8/u8 case. Asserts when the
-/// combination yields no element type or a zero count.
-std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
-                                                   NVVM::MMAFrag frag, int nRow,
-                                                   int nCol,
-                                                   MLIRContext *context) {
-  OpBuilder builder(context);
-  Type f16x2 = VectorType::get(2, builder.getF16Type());
-  Type elementType;
-  unsigned numberElements = 0;
-
-  switch (type) {
-  case NVVM::MMATypes::f16:
-    // f16 data is carried as vector<2xf16>: 8 of them for A/B fragments,
-    // 4 for any other fragment.
-    elementType = f16x2;
-    numberElements =
-        (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) ? 8 : 4;
-    break;
-  case NVVM::MMATypes::f32:
-    elementType = builder.getF32Type();
-    numberElements = 8;
-    break;
-  case NVVM::MMATypes::tf32:
-    elementType = builder.getI32Type();
-    numberElements = 4;
-    break;
-  case NVVM::MMATypes::s8:
-  case NVVM::MMATypes::u8: {
-    elementType = builder.getI32Type();
-    // The count is keyed on the rows of A or the columns of B; for any other
-    // fragment it stays zero and the assert below fires.
-    int parallelSize = 0;
-    if (frag == NVVM::MMAFrag::a)
-      parallelSize = nRow;
-    else if (frag == NVVM::MMAFrag::b)
-      parallelSize = nCol;
-
-    switch (parallelSize) {
-    case 8: // m8n32k16 (A) or m32n8k16 (B)
-      numberElements = 1;
-      break;
-    case 16: // m16n16k16
-      numberElements = 2;
-      break;
-    case 32: // m32n8k16 (A) or m8n32k16 (B)
-      numberElements = 4;
-      break;
-    default:
-      break;
-    }
-    break;
-  }
-  case NVVM::MMATypes::s32:
-    elementType = builder.getI32Type();
-    numberElements = 8;
-    break;
-  default:
-    break;
-  }
-
-  assert(numberElements != 0 && elementType != nullptr);
-  return std::make_pair(elementType, numberElements);
-}
-
-/// Maps a WMMA m/n/k shape to the dimensions of fragment `frag` (A is m x k,
-/// B is k x n, any other fragment is m x n) and forwards to inferMMAType.
-static std::pair<mlir::Type, unsigned>
-inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
-                    int k, MLIRContext *context) {
-  int nRow = m;
-  int nCol = n;
-  if (frag == NVVM::MMAFrag::a)
-    nCol = k;
-  else if (frag == NVVM::MMAFrag::b)
-    nRow = k;
-  assert(nRow && nCol);
-  return inferMMAType(type, frag, nRow, nCol, context);
-}
-
-/// Verifies a WMMA fragment load: the pointer must be in address space 0, the
-/// global space or the shared space; the attributes must select an
-/// intrinsic; and the result must be a struct of the inferred fragment type.
-LogicalResult NVVM::WMMALoadOp::verify() {
-  unsigned srcSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  bool isSupportedSpace = srcSpace == 0 ||
-                          srcSpace == NVVM::kGlobalMemorySpace ||
-                          srcSpace == NVVM::kSharedMemorySpace;
-  if (!isSupportedSpace)
-    return emitOpError("expected source pointer in memory "
-                       "space 0, 1, 3");
-
-  // An intrinsic ID of 0 means no intrinsic matches these attributes.
-  if (!NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
-                                        getEltype(), getFrag()))
-    return emitOpError() << "invalid attribute combination";
-
-  auto [fragEltTy, fragNumElts] = inferMMATypeFromMNK(
-      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
-  Type expectedTy = LLVM::LLVMStructType::getLiteral(
-      getContext(), SmallVector<Type, 8>(fragNumElts, fragEltTy));
-  if (getType() == expectedTy)
-    return success();
-  return emitOpError("expected destination type is a structure of ")
-         << fragNumElts << " elements of type " << fragEltTy;
-}
-
-/// Verifies a WMMA fragment store: the pointer address space, that the
-/// attributes select an intrinsic, and that the data operands match the
-/// inferred C fragment in both count and type.
-LogicalResult NVVM::WMMAStoreOp::verify() {
-  unsigned ptrSpace =
-      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
-  bool isSupportedSpace = ptrSpace == 0 ||
-                          ptrSpace == NVVM::kGlobalMemorySpace ||
-                          ptrSpace == NVVM::kSharedMemorySpace;
-  if (!isSupportedSpace)
-    return emitOpError("expected operands to be a source pointer in memory "
-                       "space 0, 1, 3");
-
-  if (!NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
-                                         getEltype()))
-    return emitOpError() << "invalid attribute combination";
-
-  std::pair<Type, unsigned> fragInfo = inferMMATypeFromMNK(
-      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
-  Type dataTy = fragInfo.first;
-  unsigned numData = fragInfo.second;
-  if (getArgs().size() != numData)
-    return emitOpError() << "expected " << numData << " data operands";
-
-  bool allMatch = llvm::all_of(getArgs(), [dataTy](Value operand) {
-    return operand.getType() == dataTy;
-  });
-  if (!allMatch)
-    return emitOpError() << "expected data operands of type " << dataTy;
-  return success();
-}
-
-/// Verifies a WMMA multiply-accumulate.
-/// - The attributes must select an intrinsic.
-/// - The flattened argument list must be the A, B and C fragments in that
-///   order.
-/// - The result must be a struct shaped like the C fragment.
-/// A and B fragment types are inferred from eltypeA; C from eltypeB.
-LogicalResult NVVM::WMMAMmaOp::verify() {
-  if (!NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
-                                       getLayoutB(), getEltypeA(),
-                                       getEltypeB()))
-    return emitOpError() << "invalid attribute combination";
-
-  auto fragmentInfo = [&](NVVM::MMATypes eltype, NVVM::MMAFrag frag) {
-    return inferMMATypeFromMNK(eltype, frag, getM(), getN(), getK(),
-                               getContext());
-  };
-  auto [aTy, aCount] = fragmentInfo(getEltypeA(), NVVM::MMAFrag::a);
-  auto [bTy, bCount] = fragmentInfo(getEltypeA(), NVVM::MMAFrag::b);
-  auto [cTy, cCount] = fragmentInfo(getEltypeB(), NVVM::MMAFrag::c);
-
-  SmallVector<Type, 32> expectedArgTys;
-  expectedArgTys.append(aCount, aTy);
-  expectedArgTys.append(bCount, bTy);
-  expectedArgTys.append(cCount, cTy);
-
-  unsigned numArgs = expectedArgTys.size();
-  if (getArgs().size() != numArgs)
-    return emitOpError() << "expected " << numArgs << " arguments";
-  for (unsigned idx = 0; idx != numArgs; ++idx) {
-    Type expectedTy = expectedArgTys[idx];
-    if (getArgs()[idx].getType() != expectedTy)
-      return emitOpError() << "expected argument " << idx << " to be of type "
-                           << expectedTy;
-  }
-
-  Type expectedResultTy = LLVM::LLVMStructType::getLiteral(
-      getContext(), SmallVector<Type, 8>(cCount, cTy));
-  if (getType() == expectedResultTy)
-    return success();
-  return emitOpError("expected destination type is a structure of ")
-         << cCount << " elements of type " << cTy;
-}
-
-/// Verifies LdMatrixOp:
-/// - The pointer must be in the shared space ("memory space 3").
-/// - `num` must be 1, 2 or 4.
-/// - The result is a bare i32 when num == 1, otherwise a struct of `num` i32s.
-LogicalResult NVVM::LdMatrixOp::verify() {
-  auto srcPtrTy = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
-  if (srcPtrTy.getAddressSpace() != NVVM::kSharedMemorySpace)
-    return emitOpError("expected source pointer in memory space 3");
-
-  auto num = getNum();
-  if (num != 1 && num != 2 && num != 4)
-    return emitOpError("expected num attribute to be 1, 2 or 4");
-
-  Type i32 = IntegerType::get(getContext(), 32);
-  if (num == 1) {
-    if (getType() != i32)
-      return emitOpError("expected destination type is i32");
-    return success();
-  }
-
-  // num is 2 or 4 here.
-  Type dstType = LLVM::LLVMStructType::getLiteral(
-      getContext(), SmallVector<Type>(num, i32));
-  if (getType() != dstType)
-    return emitOpError("expected destination type is a structure of ")
-           << num << " elements of type i32";
-  return success();
-}
-
-/// Verifies StMatrixOp: the pointer must be in the shared space
-/// ("memory space 3") and exactly 1, 2 or 4 source values must be given.
-LogicalResult NVVM::StMatrixOp::verify() {
-  auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
-  if (ptrTy.getAddressSpace() != NVVM::kSharedMemorySpace)
-    return emitOpError("expected source pointer in memory space 3");
-
-  int numSources = getSources().size();
-  switch (numSources) {
-  case 1:
-  case 2:
-  case 4:
-    return success();
-  default:
-    return emitOpError("expected num attribute to be 1, 2 or 4");
-  }
-}
-
-/// Returns the single K extent accepted for wgmma input type `typeA`, or
-/// failure() for types that have none (f32, s32).
-FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
-  switch (typeA) {
-  case NVVM::WGMMATypes::tf32:
-    return 8;
-  case NVVM::WGMMATypes::f16:
-  case NVVM::WGMMATypes::bf16:
-    return 16;
-  case NVVM::WGMMATypes::s8:
-  case NVVM::WGMMATypes::u8:
-  case NVVM::WGMMATypes::e4m3:
-  case NVVM::WGMMATypes::e5m2:
-    return 32;
-  case NVVM::WGMMATypes::b1:
-    return 256;
-  case NVVM::WGMMATypes::f32:
-  case NVVM::WGMMATypes::s32:
-    break;
-  }
-  return failure();
-}
-
-/// Checks whether `typeD += typeA * typeB` is a data-type combination that
-/// wgmma.mma_async supports. Returns failure() for every unsupported
-/// combination, including `typeA` values (f32, s32) that this helper never
-/// accepts as input types.
-LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
-                                     NVVM::WGMMATypes typeA,
-                                     NVVM::WGMMATypes typeB) {
-  switch (typeA) {
-  case NVVM::WGMMATypes::f16:
-    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
-        typeB == NVVM::WGMMATypes::f16)
-      return success();
-    break;
-  case NVVM::WGMMATypes::tf32:
-    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
-      return success();
-    break;
-  case NVVM::WGMMATypes::u8:
-  case NVVM::WGMMATypes::s8:
-    if (typeD == NVVM::WGMMATypes::s32 &&
-        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
-      return success();
-    break;
-  case NVVM::WGMMATypes::b1:
-    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
-      return success();
-    break;
-  case NVVM::WGMMATypes::bf16:
-    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
-        typeB == NVVM::WGMMATypes::bf16)
-      return success();
-    break;
-  case NVVM::WGMMATypes::e4m3:
-  case NVVM::WGMMATypes::e5m2:
-    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
-        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
-      return success();
-    break;
-  case NVVM::WGMMATypes::f32:
-  case NVVM::WGMMATypes::s32:
-    // The op verifier passes the user-provided typeA attribute straight in
-    // without checking it first, so these values are reachable from
-    // malformed IR. Reject them so the caller can emit a diagnostic;
-    // llvm_unreachable would abort (assert builds) or be UB (release builds).
-    break;
-  }
-  return failure();
-}
-
-/// Checks whether `sizeN` is a valid N extent for wgmma.mma_async with input
-/// type `typeA`.
-/// - f16/tf32/bf16/e4m3/e5m2 accept every multiple of 8 in [8, 256].
-/// - u8/s8/b1 accept the shorter list below.
-/// - Any other `typeA` (f32, s32) is reported as unsupported rather than
-///   treated as unreachable.
-LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
-  // Constant tables: avoids building (and heap-allocating) a container on
-  // every call.
-  static constexpr int allowedN[] = {
-      8,   16,  24,  32,  40,  48,  56,  64,  72,  80,  88,
-      96,  104, 112, 120, 128, 136, 144, 152, 160, 168, 176,
-      184, 192, 200, 208, 216, 224, 232, 240, 248, 256};
-  static constexpr int allowedNshort[] = {8,   16,  24,  32,  48,  64,
-                                          80,  96,  112, 128, 144, 160,
-                                          176, 192, 208, 224, 240, 256};
-  switch (typeA) {
-  case WGMMATypes::f16:
-  case WGMMATypes::tf32:
-  case WGMMATypes::bf16:
-  case WGMMATypes::e4m3:
-  case WGMMATypes::e5m2:
-    if (llvm::is_contained(allowedN, sizeN))
-      return success();
-    break;
-  case WGMMATypes::u8:
-  case WGMMATypes::s8:
-  case WGMMATypes::b1:
-    if (llvm::is_contained(allowedNshort, sizeN))
-      return success();
-    break;
-  case WGMMATypes::f32:
-  case WGMMATypes::s32:
-    // Not valid input types; report failure instead of crashing so any
-    // caller that has not pre-validated typeA still gets a diagnostic.
-    break;
-  }
-  return failure();
-}
-
-/// Verifies wgmma.mma_async:
-/// - The result is a homogeneous LLVM struct.
-/// - The D/A/B type combination and the m/n/k shape are supported.
-/// - The requested layouts are expressible for the input type.
-/// - The result struct holds the expected number of registers.
-/// - satfinite is only used with an s32 accumulator.
-LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
-  Value outValue = getResults();
-  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
-  if (!stype)
-    return emitOpError() << "expected results to be struct";
-  int outputSize = stype.getBody().size();
-  WGMMATypes typeD = getTypeD();
-  WGMMATypes typeA = getTypeA();
-  WGMMATypes typeB = getTypeB();
-
-  for (Type t : stype.getBody()) {
-    if (t != stype.getBody().front())
-      return emitOpError()
-             << "all elements in struct must be same type but there is " << t;
-  }
-
-  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
-      typeD != WGMMATypes::s32) {
-    return emitOpError() << "does not support the given output type "
-                         << NVVM::stringifyWGMMATypes(typeD);
-  }
-  if (typeD == WGMMATypes::s32 &&
-      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
-    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
-  }
-
-  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
-    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
-                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
-                         << NVVM::stringifyWGMMATypes(typeB)
-                         << ", it is not supported.";
-  }
-
-  // Check M
-  if (getShape().getM() != 64)
-    return emitOpError() << "shape 'm' must be 64";
-
-  // Check K. getAllowedSizeK() fails for input types without a valid K, and
-  // a failed FailureOr holds no value. Report that case on its own instead of
-  // dereferencing the empty result while building the diagnostic.
-  FailureOr<int> allowedK = getAllowedSizeK(typeA);
-  if (failed(allowedK))
-    return emitOpError() << "does not support the given input type "
-                         << NVVM::stringifyWGMMATypes(typeA);
-  if (*allowedK != getShape().getK())
-    return emitOpError() << "shape 'k' must be " << *allowedK
-                         << " for input type "
-                         << NVVM::stringifyWGMMATypes(typeA);
-
-  // Check N
-  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
-    return emitOpError() << "has input type "
-                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
-                         << getShape().getN() << ", it is not supported.";
-  }
-
-  // Check transpose (only available for f16/bf16). The non-transposed
-  // layouts are row-major A and column-major B, so a column-major A or a
-  // row-major B needs a transpose.
-  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
-      (getLayoutA() == mlir::NVVM::MMALayout::col ||
-       getLayoutB() == mlir::NVVM::MMALayout::row)) {
-    return emitOpError()
-           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
-           << " and layout_b = " << stringifyMMALayout(getLayoutB())
-           << " for input types " << stringifyWGMMATypes(typeA) << " and "
-           << stringifyWGMMATypes(typeB)
-           << " requires transpose. However, this is only supported for: "
-           << stringifyMMATypes(MMATypes::f16) << " and "
-           << stringifyMMATypes(MMATypes::bf16);
-  }
-
-  // Check result registers: N/2 for f32/s32 results, N/4 for f16 results.
-  int expectedOutput = 0;
-  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
-    expectedOutput = getShape().getN() / 2;
-  if (typeD == WGMMATypes::f16)
-    expectedOutput = getShape().getN() / 4;
-  if (outputSize != expectedOutput) {
-    return emitOpError() << "results " << expectedOutput
-                         << ", however output struct has " << outputSize
-                         << " elements";
-  }
-  // Check satfinite (only available for s32 accumulator)
-  if (typeD != WGMMATypes::s32 &&
-      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
-          NVVM::MMAIntOverflow::satfinite) {
-    return emitOpError()
-           << " `satfinite` can be only used with s32 accumulator, however "
-              "the current accumulator is "
-           << NVVM::stringifyWGMMATypes(typeD);
-  }
-
-  return success();
-}
-
-/// Builds the inline-PTX text for this op.
-/// - $0 .. $R-1 name the result registers, where R = N/4 for f16 results and
-///   N/2 otherwise.
-/// - $2R and $2R+1 are the descriptors.
-/// - $2R+2 is compared against zero to form predicate `p`.
-/// - $2R+3/$2R+4 (skipped for s32 results) and $2R+5/$2R+6 (only for f16/bf16
-///   inputs) are appended last.
-/// This numbering appears to mirror the operand order built by getAsmValues();
-/// keep the two in sync.
-std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
-  int m = getShape().getM();
-  int n = getShape().getN();
-  int k = getShape().getK();
-  WGMMATypes typeD = getTypeD();
-  WGMMATypes typeA = getTypeA();
-  bool hasTransposeOperands =
-      typeA == WGMMATypes::f16 || typeA == WGMMATypes::bf16;
-  int numOutputRegs = typeD == WGMMATypes::f16 ? n / 4 : n / 2;
-
-  std::string ptx;
-  llvm::raw_string_ostream ss(ptx);
-
-  ss << "{\n"
-     << ".reg .pred p;\n"
-     << "setp.ne.b32 p, $" << (numOutputRegs * 2 + 2) << ", 0;\n"
-     << "wgmma.mma_async.sync.aligned.m" << m << "n" << n << "k" << k << "."
-     << stringifyWGMMATypes(typeD) << "." << stringifyWGMMATypes(typeA) << "."
-     << stringifyWGMMATypes(getTypeB());
-  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
-      NVVM::MMAIntOverflow::satfinite)
-    ss << ".satfinite";
-
-  ss << " {";
-  int reg = 0;
-  for (; reg < numOutputRegs; ++reg)
-    ss << (reg == 0 ? "$" : ", $") << reg;
-  ss << "},";
-
-  // Skip past the result registers and the matching read-write registers.
-  int descA = reg * 2;
-  ss << " $" << descA << ", $" << (descA + 1) << ", p";
-  if (typeD != WGMMATypes::s32)
-    ss << ", $" << (descA + 3) << ", $" << (descA + 4);
-  if (hasTransposeOperands)
-    ss << ", $" << (descA + 5) << ", $" << (descA + 6);
-  ss << ";\n}\n";
-  ss.flush();
-  return ptx;
-}
-
-/// Collects the inline-asm operands in the order getPtx() numbers them.
-/// 1. Results (write) and inouts (read-write), when present.
-/// 2. Both descriptors.
-/// 3. The scale-d constant.
-/// 4. For non-s32 results, the +-1 scale-a/scale-b constants.
-/// 5. For f16/bf16 inputs, the two layout-derived transpose constants.
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
-    RewriterBase &rewriter,
-    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
-        &asmValues) {
-  using RegMod = mlir::NVVM::PTXRegisterMod;
-  auto addRead = [&asmValues](Value value) {
-    asmValues.push_back({value, RegMod::Read});
-  };
-  // Materializes an i32 constant and records it as a read operand.
-  auto addConstant = [&](int value) {
-    addRead(makeConstantI32(rewriter, value));
-  };
-
-  if (Value results = getResults())
-    asmValues.push_back({results, RegMod::Write});
-  if (Value inouts = getInouts())
-    asmValues.push_back({inouts, RegMod::ReadWrite});
-  addRead(getDescriptorA());
-  addRead(getDescriptorB());
-  addConstant(static_cast<int>(getScaleD()));
-
-  if (getTypeD() != WGMMATypes::s32) {
-    addConstant(getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1);
-    addConstant(getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1);
-  }
-
-  WGMMATypes typeA = getTypeA();
-  if (typeA == WGMMATypes::f16 || typeA == WGMMATypes::bf16) {
-    addConstant(static_cast<int>(getLayoutA()));
-    addConstant(1 - static_cast<int>(getLayoutB()));
-  }
-}
-/// Verifies FenceProxyOp: the space attribute is required for the
-/// async_shared kind and forbidden for every other kind.
-LogicalResult NVVM::FenceProxyOp::verify() {
-  bool isAsyncShared = getKind() == NVVM::ProxyKind::async_shared;
-  bool hasSpace = getSpace().has_value();
-  if (isAsyncShared && !hasSpace)
-    return emitOpError() << "async_shared fence requires space attribute";
-  if (!isAsyncShared && hasSpace)
-    return emitOpError() << "only async_shared fence can have space attribute";
-  return success();
-}
-
-/// Verifies SetMaxRegisterOp: the register count must be a multiple of 8
-/// within [24, 256].
-LogicalResult NVVM::SetMaxRegisterOp::verify() {
-  auto regCount = getRegCount();
-  if (regCount % 8 != 0)
-    return emitOpError("new register size must be multiple of 8");
-  bool inRange = regCount >= 24 && regCount <= 256;
-  if (!inRange)
-    return emitOpError("new register size must be in between 24 to 256");
-  return success();
-}
-
-/// Verifies BarrierOp: a thread-count operand may only be supplied together
-/// with an explicit barrier id.
-LogicalResult NVVM::BarrierOp::verify() {
-  bool hasThreadCount = static_cast<bool>(getNumberOfThreads());
-  bool hasBarrierId = static_cast<bool>(getBarrierId());
-  if (hasThreadCount && !hasBarrierId)
-    return emitOpError(
-        "barrier id is missing, it should be set between 0 to 15");
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// NVVMDialect initialization, type parsing, and registration.
-//===----------------------------------------------------------------------===//
-
-// TODO: This should be the llvm.nvvm dialect once this is supported.
-/// Registers the NVVM operations and attributes with the dialect, allows
-/// unregistered operations, and declares the promised interfaces below.
-void NVVMDialect::initialize() {
- // The op and attribute lists are spliced in from the TableGen-generated
- // .inc files selected by the GET_*_LIST macros.
- addOperations<
-#define GET_OP_LIST
-#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
- >();
- addAttributes<
-#define GET_ATTRDEF_LIST
-#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
- >();
-
- // Support unknown operations because not all NVVM operations are
- // registered.
- allowUnknownOperations();
- // Promised interfaces: implementations are presumably attached elsewhere
- // (e.g. by a separate registration library) -- confirm at the call sites.
- declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
- declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
-}
-
-/// Verifies NVVM attributes attached to `op`:
-/// - The kernel marker may only sit on an LLVM::LLVMFuncOp.
-/// - maxntid/reqntid must be i32 arrays with 1 to 3 entries.
-/// - minctasm/maxnreg must be IntegerAttrs.
-/// Any other attribute is accepted.
-LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
-                                                    NamedAttribute attr) {
-  StringAttr attrName = attr.getName();
-  Attribute attrValue = attr.getValue();
-
-  if (attrName == NVVMDialect::getKernelFuncAttrName() &&
-      !isa<LLVM::LLVMFuncOp>(op))
-    return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
-                           << "' attribute attached to unexpected op";
-
-  bool isNtidAttr = attrName == NVVMDialect::getMaxntidAttrName() ||
-                    attrName == NVVMDialect::getReqntidAttrName();
-  if (isNtidAttr) {
-    auto dims = llvm::dyn_cast<DenseI32ArrayAttr>(attrValue);
-    if (!dims || dims.empty() || dims.size() > 3)
-      return op->emitError()
-             << "'" << attrName
-             << "' attribute must be integer array with maximum 3 index";
-  }
-
-  bool isIntegerValuedAttr = attrName == NVVMDialect::getMinctasmAttrName() ||
-                             attrName == NVVMDialect::getMaxnregAttrName();
-  if (isIntegerValuedAttr && !llvm::isa<IntegerAttr>(attrValue))
-    return op->emitError()
-           << "'" << attrName << "' attribute must be integer constant";
-
-  return success();
-}
-
-/// Verifies NVVM attributes on function arguments. Only the grid-constant
-/// attribute (getGridConstantAttrName()) is checked. It must be a unit
-/// attribute, on an argument of a function carrying the kernel attribute,
-/// and the argument must also carry the LLVM byval attribute.
-LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
-                                                    unsigned regionIndex,
-                                                    unsigned argIndex,
-                                                    NamedAttribute argAttr) {
-  auto funcOp = dyn_cast<FunctionOpInterface>(op);
-  if (!funcOp)
-    return success();
-
-  StringAttr attrName = argAttr.getName();
-  bool isGridConstant =
-      attrName == NVVM::NVVMDialect::getGridConstantAttrName();
-  if (!isGridConstant)
-    return success();
-
-  if (!op->hasAttr(NVVMDialect::getKernelFuncAttrName()))
-    return op->emitError()
-           << "'" << attrName
-           << "' attribute must be present only on kernel arguments";
-  if (!isa<UnitAttr>(argAttr.getValue()))
-    return op->emitError() << "'" << attrName << "' must be a unit attribute";
-  if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
-    return op->emitError()
-           << "'" << attrName
-           << "' attribute requires the argument to also have attribute '"
-           << LLVM::LLVMDialect::getByValAttrName() << "'";
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// NVVM target attribute.
-//===----------------------------------------------------------------------===//
-/// Verifies NVVM target attribute parameters:
-/// - optLevel is in [0, 3].
-/// - triple and chip are non-empty.
-/// - `files`, when present, holds only non-null StringAttrs.
-LogicalResult
-NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                       int optLevel, StringRef triple, StringRef chip,
-                       StringRef features, DictionaryAttr flags,
-                       ArrayAttr files) {
-  // Reports `message` through the caller-provided diagnostic hook and fails.
-  auto reject = [&emitError](StringRef message) {
-    emitError() << message;
-    return failure();
-  };
-
-  if (optLevel < 0 || optLevel > 3)
-    return reject("The optimization level must be a number between 0 and 3.");
-  if (triple.empty())
-    return reject("The target triple cannot be empty.");
-  if (chip.empty())
-    return reject("The target chip cannot be empty.");
-
-  bool filesAreStrings =
-      !files || llvm::all_of(files, [](::mlir::Attribute attr) {
-        return attr && mlir::isa<StringAttr>(attr);
-      });
-  if (!filesAreStrings)
-    return reject("All the elements in the `link` array must be strings.");
-  return success();
-}
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
-
-#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
>From 7c9053ca23bbf94e10c7174e8c8c4ae178e5c0e7 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:03:29 +0800
Subject: [PATCH 3/4] Update NVVMDialect.cpp
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 1834 ++++++++++++--------
1 file changed, 1141 insertions(+), 693 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 375e2951a037c..036a9a15af838 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1,694 +1,1142 @@
-// RUN: mlir-opt --convert-nvvm-to-llvm --convert-arith-to-llvm --split-input-file %s | FileCheck %s
-
-// Same below, but using the `ConvertToLLVMPatternInterface` entry point
-// and the generic `convert-to-llvm` pass.
-// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: @init_mbarrier
-llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b"
- nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
- nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
- llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
- llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r"
- nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
- nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1
- llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_try_wait_shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred P1;
- // CHECK-SAME: LAB_WAIT:
- // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
- // CHECK-SAME: @P1 bra.uni DONE;
- // CHECK-SAME: bra.uni LAB_WAIT;
- // CHECK-SAME: DONE:
- // CHECK-SAME: }",
- // CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
- llvm.return
-}
-
-// CHECK-LABEL: @init_mbarrier_try_wait
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred P1;
- // CHECK-SAME: LAB_WAIT:
- // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
- // CHECK-SAME: @P1 bra.uni DONE;
- // CHECK-SAME: bra.uni LAB_WAIT;
- // CHECK-SAME: DONE:
- // CHECK-SAME: }",
- // CHECK-SAME: "l,r,r"
- nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
- llvm.return
-}
-
-// CHECK-LABEL: @async_cp
-func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
- // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
- nvvm.cp.async.shared.global %dst, %src, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
- // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
- nvvm.cp.async.shared.global %dst, %src, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
- return
-}
-
-// CHECK-LABEL: @async_cp_zfill
-func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A",
- // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
- nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A",
- // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
- nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
- return
-}
-
-// CHECK-LABEL: @cp_async_mbarrier_arrive
-func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
- // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}}
- nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
- // CHECK: nvvm.cp.async.mbarrier.arrive %{{.*}} {noinc = true}
- nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
- // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}}
- nvvm.cp.async.mbarrier.arrive.shared %bar_shared : !llvm.ptr<3>
- // CHECK: nvvm.cp.async.mbarrier.arrive.shared %{{.*}} {noinc = true}
- nvvm.cp.async.mbarrier.arrive.shared %bar_shared {noinc = true} : !llvm.ptr<3>
- llvm.return
-}
-
-// CHECK-LABEL: @tma_load_3d_all
-func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_4d_all
-func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_5d_all
-func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
- // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_3d
-func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4} ], [$5];", "r,l,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_4d
-func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5} ], [$6];", "r,l,r,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_5d
-func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3,$4,$5,$6} ], [$7];", "r,l,r,r,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_multicast1d
-func.func @tma_load_multicast1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2} ], [$3], $4;", "r,l,r,r,h,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_multicast2d
-func.func @tma_load_multicast2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3} ], [$4], $5;", "r,l,r,r,r,h,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_multicast3d
-func.func @tma_load_multicast3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4} ], [$5], $6;", "r,l,r,r,r,r,h,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_multicast4d
-func.func @tma_load_multicast4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask: !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$8 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5} ], [$6], $7;", "r,l,r,r,r,r,r,h,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_load_multicast5d
-func.func @tma_load_multicast5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %multicastMask : i16, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster [$0], [$1, {$2,$3,$4,$5,$6} ], [$7], $8;", "r,l,r,r,r,r,r,r,h,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box [%crd0,%crd1,%crd2,%crd3,%crd4] multicast_mask = %multicastMask predicate=%p : !llvm.ptr<3>, !llvm.ptr
- return
-}
-
-// CHECK-LABEL: @tma_store_1d
-func.func @tma_store_1d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0] : !llvm.ptr, !llvm.ptr<3>, i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$3 cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [$0, {$2} ], [$1];", "l,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i1
- return
-}
-
-// CHECK-LABEL: @tma_store_2d
-func.func @tma_store_2d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1] : !llvm.ptr, !llvm.ptr<3>, i32, i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$0, {$2, $3} ], [$1];", "l,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i1
- return
-}
-
-// CHECK-LABEL: @tma_store_3d
-func.func @tma_store_3d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [$0, {$2, $3, $4} ], [$1];", "l,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i1
- return
-}
-
-// CHECK-LABEL: @tma_store_4d
-func.func @tma_store_4d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5} ], [$1];", "l,r,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i1
- return
-}
-
-// CHECK-LABEL: @tma_store_5d
-func.func @tma_store_5d(%tmaDescriptor: !llvm.ptr, %src : !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %p : i1) {
- // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
-
- // CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att "@$7 cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [$0, {$2, $3, $4, $5, $6} ], [$1];", "l,r,r,r,r,r,r,b"
- nvvm.cp.async.bulk.tensor.global.shared.cta %tmaDescriptor, %src, box[%crd0,%crd1,%crd2,%crd3,%crd4], predicate=%p : !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32, i1
- return
-}
-
-// CHECK-LABEL: @wgmma_execute
-func.func @wgmma_execute() {
- nvvm.wgmma.fence.aligned
- nvvm.wgmma.commit.group.sync.aligned
- nvvm.wgmma.wait.group.sync.aligned 0
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
-
-
- nvvm.wgmma.fence.aligned
- nvvm.wgmma.commit.group.sync.aligned
- nvvm.wgmma.wait.group.sync.aligned 5
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
- return
-}
-
-
-// -----
-
-!mat64f32 = !llvm.struct<(
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_f16_f16(
-// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
-func.func @wgmma_f32_f16_f16(%descA : i64, %descB : i64) -> !mat64f32{
- // CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
- // CHECK: %[[A1:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[A2:.*]] = llvm.mlir.constant(-1 : i32) : i32
- // CHECK: %[[A3:.*]] = llvm.mlir.constant(-1 : i32) : i32
- // CHECK: %[[A4:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[A5:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[V4:.*]] = llvm.extractvalue %[[RES]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[V11:.*]] = llvm.extractvalue %[[RES]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[V13:.*]] = llvm.extractvalue %[[RES]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[RES1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: reg .pred p;
- // CHECK-SAME: setp.ne.b32 p, $34, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16
- // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35, $36, $37, $38;\0A}\0A",
- // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n"
- // CHECK-SAME: %[[V0]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11]], %{{.*}}, %[[V13]], %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]]
- // CHECK-SAME: : (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, i64, i64, i32, i32, i32, i32, i32)
- // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
- // CHECK: %[[DESCa:.+]] = llvm.add %[[ARG0]], %[[C2]] : i64
- // CHECK: %[[DESCb:.+]] = llvm.add %[[ARG1]], %[[C2]] : i64
- // CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES1]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[V4_2:.*]] = llvm.extractvalue %[[RES1]][4] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[V11_2:.*]] = llvm.extractvalue %[[RES1]][11] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[V13_2:.*]] = llvm.extractvalue %[[RES1]][13] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
- // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred p;
- // CHECK-SAME: setp.ne.b32 p, $34, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16
- // CHECK-SAME: {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15}, $32, $33, p, $35, $36, $37, $38;\0A}\0A",
- // CHECK-SAME: "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,l,l,n,n,n,n,n"
- // CHECK-SAME: %[[V0_2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[V4_2]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[V11_2]], %{{.*}}, %[[V13_2]], %{{.*}}, %{{.*}}, %[[DESCa]], %[[DESCb]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
- %result = llvm.mlir.undef : !mat64f32
- %result1 = nvvm.wgmma.mma_async
- %descA, %descB, %result,
- #nvvm.shape<m = 64, n = 32, k = 16>,
- D [<f32>, #nvvm.wgmma_scale_out<zero>],
- A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
- B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
- :!mat64f32 -> !mat64f32
- %c2 = arith.constant 2 : i64
- %descAnext = arith.addi %descA, %c2 : i64
- %descBnext = arith.addi %descB, %c2 : i64
- %result2 = nvvm.wgmma.mma_async
- %descAnext, %descBnext, %result1,
- #nvvm.shape<m = 64, n = 32, k = 16>,
- D [<f32>, #nvvm.wgmma_scale_out<zero>],
- A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>],
- B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>]
- : !mat64f32 -> !mat64f32
- return %result2 : !mat64f32
-}
-
-// -----
-
-!mat16i32 = !llvm.struct<(i32, i32, i32, i32)>
-
-// CHECK-LABEL: @wgmma_s32_s8_s8_satfinite(
-// CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
-func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
- %result = llvm.mlir.undef : !mat16i32
-// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
-// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
-// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
-// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
-// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
-// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
-// CHECK-SAME: "{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
-// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n"
-// CHECK-SAME: %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] :
-// CHECK-SAME: (i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
-// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
-// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
-// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
-// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
-// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
-// CHECK-SAME: "{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
-// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A",
-// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n"
-// CHECK-SAME: %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
-// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
-// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
-// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
-// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
-// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
-// CHECK-SAME:"{
-// CHECK-SAME:.reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite
-// CHECK-SAME: {$0, $1, $2, $3}, $8, $9, p;\0A}\0A", "=r,=r,=r,=r,0,1,2,3,l,l,n"
-// CHECK-SAME: %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
- %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
- #nvvm.shape<m = 64, n = 8, k = 32>,
- D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
- A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
- : !mat16i32 -> !mat16i32
- %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
- #nvvm.shape<m = 64, n = 8, k = 32>,
- D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
- A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
- : !mat16i32 -> !mat16i32
- %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
- #nvvm.shape<m = 64, n = 8, k = 32>,
- D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
- A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
- : !mat16i32 -> !mat16i32
- return %result3 : !mat16i32
-}
-
-// CHECK-LABEL: @wgmma_s32_u8_u8(
- // CHECK-SAME: %[[ARG0:.+]]: i64, %[[ARG1:.+]]: i64
-func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
-// CHECK: %[[RES:.*]] = llvm.mlir.undef : !llvm.struct
-// CHECK: %[[A1:.*]] = llvm.mlir.constant(1 : i32) : i32
-// CHECK: %[[V0:.*]] = llvm.extractvalue %[[RES]][0]
-// CHECK: %[[V1:.*]] = llvm.extractvalue %[[RES]][1]
-// CHECK: %[[V2:.*]] = llvm.extractvalue %[[RES]][2]
-// CHECK: %[[V3:.*]] = llvm.extractvalue %[[RES]][3]
-// CHECK: %[[RES_2:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
-// CHECK-SAME: "{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
-// CHECK-SAME: }\0A",
-// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0]], %[[V1]], %[[V2]], %[[V3]], %[[ARG0]], %[[ARG1]], %[[A1]] :
-// CHECK-SAME:(i32, i32, i32, i32, i64, i64, i32) -> !llvm.struct<(i32, i32, i32, i32)>
-// CHECK: %[[V0_2:.*]] = llvm.extractvalue %[[RES_2]][0]
-// CHECK: %[[V1_2:.*]] = llvm.extractvalue %[[RES_2]][1]
-// CHECK: %[[V2_2:.*]] = llvm.extractvalue %[[RES_2]][2]
-// CHECK: %[[V3_2:.*]] = llvm.extractvalue %[[RES_2]][3]
-// CHECK: %[[RES_3:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
-// CHECK-SAME:"{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
-// CHECK-SAME: }\0A",
-// CHECK-SAME: "=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_2]], %[[V1_2]], %[[V2_2]], %[[V3_2]], %[[ARG0]], %[[ARG1]], %{{.*}}
-// CHECK: %[[V0_3:.*]] = llvm.extractvalue %[[RES_3]][0]
-// CHECK: %[[V1_3:.*]] = llvm.extractvalue %[[RES_3]][1]
-// CHECK: %[[V2_3:.*]] = llvm.extractvalue %[[RES_3]][2]
-// CHECK: %[[V3_3:.*]] = llvm.extractvalue %[[RES_3]][3]
-// CHECK: %[[RES1:.*]] = llvm.inline_asm has_side_effects asm_dialect = att
-// CHECK-SAME:"{
-// CHECK-SAME: .reg .pred p;
-// CHECK-SAME: setp.ne.b32 p, $10, 0;
-// CHECK-SAME: wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 {$0, $1, $2, $3}, $8, $9, p;
-// CHECK-SAME:}\0A",
-// CHECK-SAME:"=r,=r,=r,=r,0,1,2,3,l,l,n" %[[V0_3]], %[[V1_3]], %[[V2_3]], %[[V3_3]], %[[ARG0]], %[[ARG1]], %{{.*}}
- %result = llvm.mlir.undef : !mat16i32
- %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
- #nvvm.shape<m = 64, n = 8, k = 32>,
- D [<s32>, #nvvm.wgmma_scale_out<one>],
- A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
- : !mat16i32 -> !mat16i32
- %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
- #nvvm.shape<m = 64, n = 8, k = 32>,
- D [<s32>, #nvvm.wgmma_scale_out<one>],
- A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
- : !mat16i32 -> !mat16i32
- %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
- #nvvm.shape<m = 64, n = 8, k = 32>,
- D [<s32>, #nvvm.wgmma_scale_out<one>],
- A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
- B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
- : !mat16i32 -> !mat16i32
- return %result3 : !mat16i32
-}
-
-// -----
-
-!mat32f32 = !llvm.struct<(
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_tf32_tf32
-func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
- // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME:"{
- // CHECK-SAME: .reg .pred p;
- // CHECK-SAME: setp.ne.b32 p, $66, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
- // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred p;
- // CHECK-SAME: setp.ne.b32 p, $66, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
- %result = llvm.mlir.undef : !mat32f32
- %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
- #nvvm.shape<m = 64, n = 64, k = 8>,
- D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
- : !mat32f32 -> !mat32f32
- %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
- #nvvm.shape<m = 64, n = 64, k = 8>,
- D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
- : !mat32f32 -> !mat32f32
- return %result2 : !mat32f32
-}
-
-
-// -----
-
-!mat32f32 = !llvm.struct<(
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_e4m3_e4m3
-func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
- // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
- // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
- %result = llvm.mlir.undef : !mat32f32
- %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
- #nvvm.shape<m = 64, n = 64, k = 32>,
- D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
- : !mat32f32 -> !mat32f32
- %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
- #nvvm.shape<m = 64, n = 64, k = 32>,
- D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
- : !mat32f32 -> !mat32f32
- return %result2 : !mat32f32
-}
-
-// -----
-
-!mat32f32 = !llvm.struct<(
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32)>
-
-// CHECK-LABEL: @wgmma_f32_e5m2_e4m3
-func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
- // CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
- // CHECK: %[[RES_2:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{\0A.reg .pred p;\0Asetp.ne.b32 p, $66, 0;
- // CHECK-SAME: wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 {$0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31}, $64, $65, p, $67, $68;\0A}\0A", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,l,l,n,n,n"
- %result = llvm.mlir.undef : !mat32f32
- %result1 = nvvm.wgmma.mma_async %descA, %descB, %result,
- #nvvm.shape<m = 64, n = 64, k = 32>,
- D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
- : !mat32f32 -> !mat32f32
- %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
- #nvvm.shape<m = 64, n = 64, k = 32>,
- D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
- A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
- B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
- : !mat32f32 -> !mat32f32
- return %result2 : !mat32f32
-}
-
-// -----
-
-func.func @elect_one_leader_sync() {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
- // CHECK-SAME: .reg .u32 rx;
- // CHECK-SAME: .reg .pred px;
- // CHECK-SAME: mov.pred $0, 0;
- // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
- // CHECK-SAME: @px mov.pred $0, 1;
- // CHECK-SAME: "=b" : () -> i1
- %cnd = nvvm.elect.sync -> i1
- return
-}
-
-// -----
-
-// CHECK-LABEL: @stmatrix(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
-// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
-llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
- nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
- nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
- nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
- nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
- nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
- nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
- llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
- nvvm.prefetch.tensormap %desc : !llvm.ptr
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
- nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
- llvm.return
-}
-
-// -----
-
-func.func @set_max_register() {
- // CHECK: nvvm.setmaxregister increase 232
- nvvm.setmaxregister increase 232
-
- // CHECK: nvvm.setmaxregister decrease 40
- nvvm.setmaxregister decrease 40
- func.return
-}
-
-// -----
-
-func.func @cp_async_bulk_commit() {
- // CHECK: nvvm.cp.async.bulk.commit.group
- nvvm.cp.async.bulk.commit.group
- func.return
-}
-
-// -----
-
-func.func @cp_async_bulk_wait_group() {
- // CHECK: nvvm.cp.async.bulk.wait_group 1
- // CHECK: nvvm.cp.async.bulk.wait_group 0
- // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
- // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
- nvvm.cp.async.bulk.wait_group 1
- nvvm.cp.async.bulk.wait_group 0
- nvvm.cp.async.bulk.wait_group 5 {read}
- nvvm.cp.async.bulk.wait_group 0 {read}
- func.return
-}
-
-// -----
-
-func.func @fence_mbarrier_init() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.mbarrier_init.release.cluster;"
- nvvm.fence.mbarrier.init
- func.return
-}
-// -----
-
-func.func @fence_proxy() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.alias;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<alias>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.global;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.global>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cta;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cta>}
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "fence.proxy.async.shared::cluster;", "" : () -> ()
- nvvm.fence.proxy { kind = #nvvm.proxy_kind<async.shared>, space = #nvvm.shared_space<cluster>}
- func.return
-}
-
-// -----
-
-// CHECK-LABEL: @llvm_nvvm_barrier_arrive
-// CHECK-SAME: (%[[barId:.*]]: i32, %[[numberOfThreads:.*]]: i32)
-llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive 0, $0;", "r" %[[numberOfThreads]] : (i32) -> ()
- nvvm.barrier.arrive number_of_threads = %numberOfThreads
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "bar.arrive $0, $1;", "r,r" %[[barId]], %[[numberOfThreads]] : (i32, i32) -> ()
- nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
- llvm.return
+//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the types and operation details for the NVVM IR dialect in
+// MLIR, and the LLVM IR dialect. It also registers the dialect.
+//
+// The NVVM dialect only contains GPU specific additions on top of the general
+// LLVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+#include <optional>
+#include <string>
+
+using namespace mlir;
+using namespace NVVM;
+
+#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
+#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for NVVM ops
+//===----------------------------------------------------------------------===//
+
+/// Prints the intrinsic-style custom form of `op`: a space, the operand list,
+/// and, only when the op has results, ` : ` followed by the result types.
+static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
+  p << " " << op->getOperands();
+  if (op->getNumResults() == 0)
+    return;
+  p << " : " << op->getResultTypes();
 }
+
+// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
+//
+// The two operands are resolved as (i32, i1); the result type is whatever
+// follows the colon.
+ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
+  MLIRContext *ctx = parser.getContext();
+  Type maskTy = IntegerType::get(ctx, 32);
+  Type predTy = IntegerType::get(ctx, 1);
+
+  SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
+  if (parser.parseOperandList(operands))
+    return failure();
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  Type resultTy;
+  if (parser.parseColonType(resultTy))
+    return failure();
+  if (parser.addTypeToList(resultTy, result.types))
+    return failure();
+
+  return parser.resolveOperands(operands, {maskTy, predTy},
+                                parser.getNameLoc(), result.operands);
+}
+
+void VoteBallotOp::print(OpAsmPrinter &p) {
+  // Same intrinsic-style form that VoteBallotOp::parse accepts.
+  printNVVMIntrinsicOp(p, *this);
+}
+
+/// Requires 1 to 5 coordinates. When im2col offsets are present (im2col
+/// mode), also requires at least 3 coordinates and exactly two more
+/// coordinates than offsets.
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  const size_t numCoords = getCoordinates().size();
+  if (numCoords == 0 || numCoords > 5)
+    return emitError("expects coordinates between 1 to 5 dimension");
+
+  // No im2col offsets: only the coordinate-count check above applies.
+  const size_t numIm2colOffsets = getIm2colOffsets().size();
+  if (numIm2colOffsets == 0)
+    return success();
+
+  if (numCoords < 3)
+    return emitError(
+        "to use im2col mode, the tensor has to be at least 3-dimensional");
+  if (numCoords != numIm2colOffsets + 2)
+    return emitError(
+        "im2col offsets must be 2 less than number of coordinates");
+  return success();
+}
+
+/// Rejects more than 5 coordinates; any count up to 5 is accepted.
+LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
+  constexpr size_t kMaxCoordinates = 5;
+  if (getCoordinates().size() <= kMaxCoordinates)
+    return success();
+  return emitError("Maximum 5 coordinates and dimension is supported.");
+}
+
+/// Accepts only the CG and CA cache modifiers and copy sizes of 4, 8 or 16
+/// bytes. The CG modifier is further restricted to 16-byte copies.
+LogicalResult CpAsyncOp::verify() {
+  const LoadCacheModifierKind modifier = getModifier();
+  const bool isCG = modifier == LoadCacheModifierKind::CG;
+  if (!isCG && modifier != LoadCacheModifierKind::CA)
+    return emitError("Only CG and CA cache modifiers are supported.");
+
+  const auto size = getSize();
+  const bool isValidSize = size == 4 || size == 8 || size == 16;
+  if (!isValidSize)
+    return emitError("expected byte size to be either 4, 8 or 16.");
+  if (isCG && size != 16)
+    return emitError("CG cache modifier is only support for 16 bytes copy.");
+  return success();
+}
+
+// Maps an operand's element type to its PTX MMA type (`NVVM::MMATypes`), or
+// std::nullopt when it cannot be inferred:
+//   f64                  -> f64
+//   f16 / vector<2xf16>  -> f16
+//   f32                  -> f32 for accumulators, tf32 otherwise
+//   any integer          -> s32 for accumulators, not inferable otherwise
+//   LLVM struct          -> type of its first member (empty: not inferable)
+std::optional<mlir::NVVM::MMATypes>
+MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
+  Type f16x2Ty =
+      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+
+  if (operandElType.isF64())
+    return NVVM::MMATypes::f64;
+  if (operandElType.isF16() || operandElType == f16x2Ty)
+    return NVVM::MMATypes::f16;
+  if (operandElType.isF32())
+    return isAccumulator ? NVVM::MMATypes::f32 : NVVM::MMATypes::tf32;
+
+  if (llvm::isa<IntegerType>(operandElType)) {
+    // Integer multiplicand registers do not reveal the PTX type.
+    if (!isAccumulator)
+      return std::nullopt;
+    return NVVM::MMATypes::s32;
+  }
+
+  auto structTy = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType);
+  if (!structTy || structTy.getBody().empty())
+    return std::nullopt;
+  return inferOperandMMAType(structTy.getBody().front(), isAccumulator);
+}
+
+/// True for the 4-bit integer PTX types (u4, s4).
+static bool isInt4PtxType(MMATypes type) {
+  return llvm::is_contained({MMATypes::u4, MMATypes::s4}, type);
+}
+
+/// True for the 8-bit integer PTX types (u8, s8).
+static bool isInt8PtxType(MMATypes type) {
+  return llvm::is_contained({MMATypes::u8, MMATypes::s8}, type);
+}
+
+/// True for any integer-valued PTX MMA type: the 4- and 8-bit types, b1, and
+/// s32.
+static bool isIntegerPtxType(MMATypes type) {
+  if (type == MMATypes::b1 || type == MMATypes::s32)
+    return true;
+  return isInt4PtxType(type) || isInt8PtxType(type);
+}
+
+/// PTX type of the accumulator, inferred from the first type in ODS operand
+/// group 2 (the C segment).
+MMATypes MmaOp::accumPtxType() {
+  Type accumTy = getODSOperands(2).getTypes().front();
+  std::optional<MMATypes> inferred =
+      inferOperandMMAType(accumTy, /*isAccumulator=*/true);
+  assert(inferred.has_value() &&
+         "accumulator PTX type should always be inferrable");
+  return inferred.value();
+}
+
+/// PTX type of the result, inferred from the result type as an accumulator.
+MMATypes MmaOp::resultPtxType() {
+  Type resTy = getResult().getType();
+  std::optional<MMATypes> inferred =
+      inferOperandMMAType(resTy, /*isAccumulator=*/true);
+  assert(inferred.has_value() && "result PTX type should always be inferrable");
+  return inferred.value();
+}
+
+/// Prints the custom form
+///   A[<regs>] B[<regs>] C[<regs>] attr-dict : (typeA, typeB, typeC) -> type
+/// The operand-segment-size attribute is always elided. The A/B PTX type
+/// attribute is elided only when the type inferred from that fragment's own
+/// registers equals the stored value, so that MmaOp::parse can reconstruct it
+/// without loss.
+void MmaOp::print(OpAsmPrinter &p) {
+  struct OperandFragment {
+    StringRef operandName;
+    StringRef ptxTypeAttr;
+    SmallVector<Value, 4> regs;
+    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+        : operandName(name), ptxTypeAttr(ptxTypeName) {}
+  };
+
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+  // PTX types currently stored on the op, indexed like frags[0..1].
+  std::array<std::optional<MMATypes>, 2> storedPtxTypes{
+      getMultiplicandAPtxType(), getMultiplicandBPtxType()};
+  SmallVector<StringRef, 4> ignoreAttrNames{
+      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
+
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++)
+      frag.regs.push_back(this->getOperand(operandIdx));
+
+    // Only A and B carry a PTX type attribute. Infer from this fragment's own
+    // first register; the old code reused A's register type for every
+    // fragment.
+    if (fragIdx >= storedPtxTypes.size() || frag.regs.empty())
+      continue;
+    std::optional<MMATypes> inferredType = inferOperandMMAType(
+        frag.regs.front().getType(), /*isAccumulator=*/false);
+    if (inferredType && inferredType == storedPtxTypes[fragIdx])
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+
+  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
+    p << " " << frag.operandName;
+    p << "[";
+    p.printOperands(frag.regs);
+    p << "] ";
+  };
+
+  for (const auto &frag : frags)
+    printMmaOperand(frag);
+
+  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
+
+  // Print one type per operand segment (the first register's type), then the
+  // result type.
+  p << " : (";
+  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
+                                             frags[1].regs[0].getType(),
+                                             frags[2].regs[0].getType()},
+                        p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{this->getRes().getType()});
+}
+
+/// Builds an MMA op from the A, B and C register segments and an (m, n, k)
+/// shape.
+/// - PTX types: explicit `multiplicandPtxTypes` are used when provided.
+///   Otherwise each is inferred from the first A / B register and left unset
+///   if inference fails.
+/// - Layouts: explicit `multiplicandLayouts` are used when provided, else
+///   row-major A and column-major B.
+/// - `intOverflow` and `b1Op` are attached only when present.
+void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
+                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
+                  std::optional<MMAIntOverflow> intOverflow,
+                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+  MLIRContext *ctx = builder.getContext();
+  result.addAttribute(
+      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+  // Operands form one flat list: A, then B, then C. The segment sizes added
+  // at the end describe this split.
+  for (ValueRange segment : {operandA, operandB, operandC})
+    result.addOperands(segment);
+
+  if (multiplicandPtxTypes) {
+    const auto &[ptxTypeA, ptxTypeB] = *multiplicandPtxTypes;
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, ptxTypeA));
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, ptxTypeB));
+  } else {
+    std::optional<MMATypes> inferredA =
+        inferOperandMMAType(operandA[0].getType(), /*isAccumulator=*/false);
+    if (inferredA)
+      result.addAttribute("multiplicandAPtxType",
+                          MMATypesAttr::get(ctx, *inferredA));
+    std::optional<MMATypes> inferredB =
+        inferOperandMMAType(operandB[0].getType(), /*isAccumulator=*/false);
+    if (inferredB)
+      result.addAttribute("multiplicandBPtxType",
+                          MMATypesAttr::get(ctx, *inferredB));
+  }
+
+  const std::array<MMALayout, 2> layouts = multiplicandLayouts.value_or(
+      std::array<MMALayout, 2>{MMALayout::row, MMALayout::col});
+  result.addAttribute("layoutA", MMALayoutAttr::get(ctx, layouts[0]));
+  result.addAttribute("layoutB", MMALayoutAttr::get(ctx, layouts[1]));
+
+  if (intOverflow)
+    result.addAttribute("intOverflowBehavior",
+                        MMAIntOverflowAttr::get(ctx, *intOverflow));
+  if (b1Op)
+    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
+
+  result.addTypes(resultType);
+  const std::array<int32_t, 3> segmentSizes{
+      static_cast<int32_t>(operandA.size()),
+      static_cast<int32_t>(operandB.size()),
+      static_cast<int32_t>(operandC.size())};
+  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+// <operation> :=
+//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
+//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
+//   `->` type($res)
+//
+// Every register in a segment gets the single type listed for that segment.
+// `multiplicandAPtxType` / `multiplicandBPtxType` may be omitted from the
+// attr-dict when they can be inferred from the A / B register type. Otherwise
+// parsing fails with a diagnostic.
+ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
+  struct OperandFragment {
+    std::optional<MMATypes> elemtype;
+    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+    SmallVector<Type> regTypes;
+  };
+
+  Builder &builder = parser.getBuilder();
+  // frags[0..2] are the A, B and C segments; frags[3] only records the
+  // result's inferred PTX type.
+  std::array<OperandFragment, 4> frags;
+
+  NamedAttrList namedAttributes;
+
+  // A helper to parse one `<keyword> [operands]` segment.
+  auto parseMmaOperand = [&](StringRef operandName,
+                             OperandFragment &frag) -> LogicalResult {
+    if (parser.parseKeyword(operandName).failed())
+      return failure();
+    if (parser
+            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+            .failed())
+      return failure();
+    return success();
+  };
+
+  // Parse the operand segments.
+  if (parseMmaOperand("A", frags[0]).failed())
+    return failure();
+  if (parseMmaOperand("B", frags[1]).failed())
+    return failure();
+  if (parseMmaOperand("C", frags[2]).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse the type specification and resolve operands.
+  SmallVector<Type, 3> operandTypes;
+  if (failed(parser.parseColon()))
+    return failure();
+  if (failed(parser.parseLParen()))
+    return failure();
+  if (failed(parser.parseTypeList(operandTypes)))
+    return failure();
+  if (failed(parser.parseRParen()))
+    return failure();
+  // Exactly one type per segment. Checking before the loop also keeps the
+  // frags[] indexing below in bounds.
+  if (operandTypes.size() != 3)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "expected one type for each operand segment but got " +
+            Twine(operandTypes.size()) + " types");
+  for (const auto &iter : llvm::enumerate(operandTypes)) {
+    auto &frag = frags[iter.index()];
+    frag.regTypes.resize(frag.regs.size(), iter.value());
+    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
+                                      parser.getNameLoc(), result.operands)))
+      return failure();
+    // Only C (index 2) is the accumulator, as in MmaOp::print. Infer from the
+    // listed type itself so that an empty segment never indexes regTypes.
+    frag.elemtype =
+        inferOperandMMAType(iter.value(), /*isAccumulator=*/iter.index() >= 2);
+  }
+
+  Type resultType;
+  if (parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);
+
+  // Fill in any multiplicand PTX type not given explicitly, or fail when it
+  // can be neither found nor inferred.
+  std::array<StringRef, 2> names{"multiplicandAPtxType",
+                                 "multiplicandBPtxType"};
+  for (unsigned idx = 0; idx < names.size(); idx++) {
+    const auto &frag = frags[idx];
+    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
+    if (!frag.elemtype.has_value() && !attr.has_value()) {
+      return parser.emitError(
+          parser.getNameLoc(),
+          "attribute " + names[idx] +
+              " is not provided explicitly and cannot be inferred");
+    }
+    if (!attr.has_value())
+      result.addAttribute(
+          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
+  }
+
+  result.addTypes(resultType);
+  if (!namedAttributes.empty())
+    result.addAttributes(namedAttributes);
+  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(frags[0].regs.size()),
+                          static_cast<int32_t>(frags[1].regs.size()),
+                          static_cast<int32_t>(frags[2].regs.size()),
+                      }));
+  return success();
+}
+
+/// Checks an MMA op against the shape and PTX-type combinations implemented
+/// here. The expected A/B/C register types, the allowed result struct types
+/// and the allowed (m, n, k) shapes all depend on M and on the multiplicand A
+/// PTX type:
+///  - M = 16: the A PTX type fixes a k-factor. A/B register counts follow from
+///    the shape, and 16x8xk or 16x8x(2k) are accepted.
+///  - M = 8: f16 and f64 accept 8x8x4. The 4-bit, 8-bit and b1 integer types
+///    accept 8x8x32, 8x8x16 and 8x8x128 respectively.
+/// b1 variants also need `b1Op`, and int4/int8 variants need
+/// `intOverflowBehavior`.
+LogicalResult MmaOp::verify() {
+  MLIRContext *context = getContext();
+  auto f16Ty = Float16Type::get(context);
+  auto i32Ty = IntegerType::get(context, 32);
+  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+  auto f32Ty = Float32Type::get(context);
+  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
+      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+
+  auto s32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
+  auto f32x8StructTy =
+      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+  auto f16x2x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
+  auto f32x4StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
+  auto s32x2StructTy =
+      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
+
+  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
+                                  getShapeAttr().getK()};
+
+  // Everything below dispatches on the A PTX type. MmaOp::build only sets it
+  // when it is given or inferable, so diagnose its absence here instead of
+  // dereferencing an empty optional.
+  if (!getMultiplicandAPtxType())
+    return emitOpError("requires " +
+                       getMultiplicandAPtxTypeAttrName().strref() +
+                       " attribute");
+  const MMATypes multiplicandAType = *getMultiplicandAPtxType();
+
+  // These variables define the set of allowed data types for matrices A, B, C,
+  // and result.
+  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
+  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
+  AllowedShapes allowedShapes;
+  AllowedTypes expectedA;
+  AllowedTypes expectedB;
+  AllowedTypes expectedC;
+  SmallVector<Type> expectedResult;
+
+  // When M = 16, we just need to calculate the number of 8xk tiles, where
+  // k is a factor that depends on the data type.
+  if (mmaShape[0] == 16) {
+    int64_t kFactor;
+    Type multiplicandFragType;
+    switch (multiplicandAType) {
+    case MMATypes::tf32:
+      kFactor = 4;
+      multiplicandFragType = i32Ty;
+      expectedResult.push_back(f32x4StructTy);
+      break;
+    case MMATypes::f16:
+    case MMATypes::bf16:
+      kFactor = 8;
+      multiplicandFragType = f16x2Ty;
+      expectedResult.push_back(f16x2x2StructTy);
+      expectedResult.push_back(f32x4StructTy);
+      break;
+    case MMATypes::s4:
+    case MMATypes::u4:
+      kFactor = 32;
+      break;
+    case MMATypes::b1:
+      kFactor = 128;
+      break;
+    case MMATypes::s8:
+    case MMATypes::u8:
+      kFactor = 16;
+      break;
+    default:
+      return emitError("invalid shape or multiplicand type: " +
+                       stringifyEnum(multiplicandAType));
+    }
+
+    // Integer variants use i32 A/B registers and four i32 accumulators.
+    // Floating-point variants accept 2 x vector<2xf16> or 4 x f32
+    // accumulators.
+    if (isIntegerPtxType(multiplicandAType)) {
+      expectedResult.push_back(s32x4StructTy);
+      expectedC.emplace_back(4, i32Ty);
+      multiplicandFragType = i32Ty;
+    } else {
+      expectedC.emplace_back(2, f16x2Ty);
+      expectedC.emplace_back(4, f32Ty);
+    }
+
+    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
+    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
+    expectedA.emplace_back(unitA, multiplicandFragType);
+    expectedB.emplace_back(unitB, multiplicandFragType);
+    allowedShapes.push_back({16, 8, kFactor});
+    allowedShapes.push_back({16, 8, kFactor * 2});
+  }
+
+  // In the M=8 case, there is only 1 possible case per data type.
+  if (mmaShape[0] == 8) {
+    if (multiplicandAType == MMATypes::f16) {
+      expectedA.emplace_back(2, f16x2Ty);
+      expectedB.emplace_back(2, f16x2Ty);
+      expectedResult.push_back(f16x2x4StructTy);
+      expectedResult.push_back(f32x8StructTy);
+      expectedC.emplace_back(4, f16x2Ty);
+      expectedC.emplace_back(8, f32Ty);
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (multiplicandAType == MMATypes::f64) {
+      Type f64Ty = Float64Type::get(context);
+      expectedA.emplace_back(1, f64Ty);
+      expectedB.emplace_back(1, f64Ty);
+      // The accumulator is accepted only as two scalar f64 registers.
+      expectedC.emplace_back(2, f64Ty);
+      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
+          context, SmallVector<Type>(2, f64Ty)));
+      allowedShapes.push_back({8, 8, 4});
+    }
+    if (isIntegerPtxType(multiplicandAType)) {
+      expectedA.push_back({i32Ty});
+      expectedB.push_back({i32Ty});
+      expectedC.push_back({i32Ty, i32Ty});
+      expectedResult.push_back(s32x2StructTy);
+      if (isInt4PtxType(multiplicandAType))
+        allowedShapes.push_back({8, 8, 32});
+      if (isInt8PtxType(multiplicandAType))
+        allowedShapes.push_back({8, 8, 16});
+      if (multiplicandAType == MMATypes::b1)
+        allowedShapes.push_back({8, 8, 128});
+    }
+  }
+
+  std::string errorMessage;
+  llvm::raw_string_ostream errorStream(errorMessage);
+
+  // Check that we matched an existing shape/dtype combination.
+  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
+      !llvm::is_contained(allowedShapes, mmaShape)) {
+    errorStream << "unimplemented variant for MMA shape <";
+    llvm::interleaveComma(mmaShape, errorStream);
+    errorStream << ">";
+    return emitOpError(errorMessage);
+  }
+
+  // Verify the operand types for segments of A, B, and C operands.
+  std::array<StringRef, 3> operandNames{"A", "B", "C"};
+  for (const auto &iter : llvm::enumerate(
+           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
+    auto spec = this->getODSOperandIndexAndLength(iter.index());
+    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
+                                      operand_type_begin() + spec.first +
+                                          spec.second);
+    bool match = llvm::is_contained(iter.value(), operandTySeg);
+
+    if (!match) {
+      errorStream << "Could not match types for the "
+                  << operandNames[iter.index()]
+                  << " operands; expected one of ";
+      for (const auto &x : iter.value()) {
+        errorStream << x.size() << "x" << x[0] << " ";
+      }
+      errorStream << "but got ";
+      llvm::interleaveComma(operandTySeg, errorStream);
+      return emitOpError(errorStream.str());
+    }
+  }
+
+  // Check the result type
+  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
+        return expectedResultType == getResult().getType();
+      })) {
+    errorStream
+        << "Could not match allowed types for the result; expected one of ";
+    llvm::interleaveComma(expectedResult, errorStream);
+    errorStream << " but got " << getResult().getType();
+    return emitOpError(errorStream.str());
+  }
+
+  // Ensure that binary MMA variants have a b1 MMA operation defined.
+  if (multiplicandAType == MMATypes::b1 && !getB1Op()) {
+    return emitOpError("op requires " + getB1OpAttrName().strref() +
+                       " attribute");
+  }
+
+  // Ensure int4/int8 MMA variants specify the accum overflow behavior
+  // attribute.
+  if (isInt4PtxType(multiplicandAType) || isInt8PtxType(multiplicandAType)) {
+    if (!getIntOverflowBehavior())
+      return emitOpError("op requires " +
+                         getIntOverflowBehaviorAttrName().strref() +
+                         " attribute");
+  }
+
+  return success();
+}
+
+ /// Verifies ShflOp. Only when the `return_value_and_is_valid` unit attribute
+ /// is present does the result carry a validity flag; in that case the result
+ /// must be a two-element LLVM struct whose second member is an i1.
+ LogicalResult ShflOp::verify() {
+   if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
+     return success();
+
+   // Pick out the second struct member, if the result has the right arity.
+   IntegerType flagTy = nullptr;
+   if (auto structTy = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+       structTy && structTy.getBody().size() == 2)
+     flagTy = llvm::dyn_cast<IntegerType>(structTy.getBody()[1]);
+
+   if (flagTy && flagTy.getWidth() == 1)
+     return success();
+   return emitError("expected return type to be a two-element struct with "
+                    "i1 as the second element");
+ }
+
+ /// Infers the per-element type and element count of a WMMA fragment.
+ ///
+ /// \param type    element type of the fragment's matrix.
+ /// \param frag    which fragment (a, b, or the accumulator).
+ /// \param nRow    rows of the fragment's matrix (only used for s8/u8).
+ /// \param nCol    columns of the fragment's matrix (only used for s8/u8).
+ /// \returns (element type, number of elements). Asserts if the combination
+ ///          is not one of the handled cases.
+ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
+                                                    NVVM::MMAFrag frag, int nRow,
+                                                    int nCol,
+                                                    MLIRContext *context) {
+   OpBuilder builder(context);
+   Type eltTy;
+   unsigned count = 0;
+
+   switch (type) {
+   case NVVM::MMATypes::f16:
+     // f16 data is carried as vector<2xf16> elements.
+     eltTy = VectorType::get(2, builder.getF16Type());
+     count = (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) ? 8 : 4;
+     break;
+   case NVVM::MMATypes::f32:
+     eltTy = builder.getF32Type();
+     count = 8;
+     break;
+   case NVVM::MMATypes::tf32:
+     // tf32 data is carried in i32 elements.
+     eltTy = builder.getI32Type();
+     count = 4;
+     break;
+   case NVVM::MMATypes::s8:
+   case NVVM::MMATypes::u8: {
+     eltTy = builder.getI32Type();
+     // The element count depends on the fragment's non-K dimension:
+     // rows for A, columns for B (0 for any other fragment).
+     int parallelDim = 0;
+     if (frag == NVVM::MMAFrag::a)
+       parallelDim = nRow;
+     else if (frag == NVVM::MMAFrag::b)
+       parallelDim = nCol;
+     switch (parallelDim) {
+     case 8: // m8n32k16 (A) or m32n8k16 (B)
+       count = 1;
+       break;
+     case 16: // m16n16k16
+       count = 2;
+       break;
+     case 32: // m32n8k16 (A) or m8n32k16 (B)
+       count = 4;
+       break;
+     default:
+       break;
+     }
+     break;
+   }
+   case NVVM::MMATypes::s32:
+     eltTy = builder.getI32Type();
+     count = 8;
+     break;
+   default:
+     break;
+   }
+
+   assert(count != 0 && eltTy != nullptr);
+   return std::make_pair(eltTy, count);
+ }
+
+ /// Maps an (m, n, k) MMA shape to the dimensions of one fragment and infers
+ /// its element type/count. A is m x k, B is k x n, and any other fragment
+ /// (the accumulator) is m x n.
+ static std::pair<mlir::Type, unsigned>
+ inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
+                     int k, MLIRContext *context) {
+   int rows = (frag == NVVM::MMAFrag::b) ? k : m;
+   int cols = (frag == NVVM::MMAFrag::a) ? k : n;
+   assert(rows && cols);
+   return inferMMAType(type, frag, rows, cols, context);
+ }
+
+ /// Verifies WMMALoadOp:
+ ///  - the source pointer is in address space 0, the global space, or the
+ ///    shared space;
+ ///  - (m, n, k, layout, eltype, frag) maps to an intrinsic;
+ ///  - the result is the literal struct of fragment elements implied by them.
+ LogicalResult NVVM::WMMALoadOp::verify() {
+   auto srcPtrTy = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
+   unsigned srcSpace = srcPtrTy.getAddressSpace();
+   bool supportedSpace = srcSpace == 0 ||
+                         srcSpace == NVVM::kGlobalMemorySpace ||
+                         srcSpace == NVVM::kSharedMemorySpace;
+   if (!supportedSpace)
+     return emitOpError("expected source pointer in memory "
+                        "space 0, 1, 3");
+
+   if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
+                                        getEltype(), getFrag()) == 0)
+     return emitOpError() << "invalid attribute combination";
+
+   auto [fragEltTy, fragCount] = inferMMATypeFromMNK(
+       getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+   Type expectedTy = LLVM::LLVMStructType::getLiteral(
+       getContext(), SmallVector<Type, 8>(fragCount, fragEltTy));
+   if (getType() == expectedTy)
+     return success();
+   return emitOpError("expected destination type is a structure of ")
+          << fragCount << " elements of type " << fragEltTy;
+ }
+
+ /// Verifies WMMAStoreOp: the destination pointer's address space (0, global,
+ /// or shared), that the attributes map to an intrinsic, and that the data
+ /// operands match the accumulator fragment (count and element type).
+ LogicalResult NVVM::WMMAStoreOp::verify() {
+   unsigned dstSpace =
+       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+   bool supportedSpace = dstSpace == 0 ||
+                         dstSpace == NVVM::kGlobalMemorySpace ||
+                         dstSpace == NVVM::kSharedMemorySpace;
+   if (!supportedSpace)
+     return emitOpError("expected operands to be a source pointer in memory "
+                        "space 0, 1, 3");
+
+   if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
+                                         getEltype()) == 0)
+     return emitOpError() << "invalid attribute combination";
+
+   // Stored data is always an accumulator (C) fragment.
+   std::pair<Type, unsigned> fragInfo = inferMMATypeFromMNK(
+       getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
+   Type dataTy = fragInfo.first;
+   unsigned numData = fragInfo.second;
+
+   if (getArgs().size() != numData)
+     return emitOpError() << "expected " << numData << " data operands";
+   for (Value arg : getArgs()) {
+     if (arg.getType() != dataTy)
+       return emitOpError() << "expected data operands of type " << dataTy;
+   }
+   return success();
+ }
+
+ /// Verifies WMMAMmaOp: the attributes must map to an intrinsic, the flat
+ /// argument list must be the A, B, and C fragments in that order, and the
+ /// result must be a struct shaped like the C fragment.
+ LogicalResult NVVM::WMMAMmaOp::verify() {
+   if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
+                                       getLayoutB(), getEltypeA(),
+                                       getEltypeB()) == 0)
+     return emitOpError() << "invalid attribute combination";
+
+   // eltypeA types both multiplicand fragments (A and B); eltypeB types the
+   // accumulator fragment C and therefore the result.
+   MLIRContext *ctx = getContext();
+   std::pair<Type, unsigned> fragA = inferMMATypeFromMNK(
+       getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), ctx);
+   std::pair<Type, unsigned> fragB = inferMMATypeFromMNK(
+       getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), ctx);
+   std::pair<Type, unsigned> fragC = inferMMATypeFromMNK(
+       getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), ctx);
+
+   // Expected flat argument types: all of A, then all of B, then all of C.
+   SmallVector<Type, 32> expectedArgTys;
+   expectedArgTys.append(fragA.second, fragA.first);
+   expectedArgTys.append(fragB.second, fragB.first);
+   expectedArgTys.append(fragC.second, fragC.first);
+
+   unsigned numExpected = expectedArgTys.size();
+   if (getArgs().size() != numExpected)
+     return emitOpError() << "expected " << numExpected << " arguments";
+   for (unsigned idx = 0; idx != numExpected; ++idx) {
+     if (getArgs()[idx].getType() != expectedArgTys[idx])
+       return emitOpError() << "expected argument " << idx << " to be of type "
+                            << expectedArgTys[idx];
+   }
+
+   Type expectedResultTy = LLVM::LLVMStructType::getLiteral(
+       ctx, SmallVector<Type, 8>(fragC.second, fragC.first));
+   if (getType() == expectedResultTy)
+     return success();
+   return emitOpError("expected destination type is a structure of ")
+          << fragC.second << " elements of type " << fragC.first;
+ }
+
+ /// Verifies LdMatrixOp:
+ ///  - the source pointer must be in the shared address space;
+ ///  - `num` must be 1, 2 or 4;
+ ///  - the result is a plain i32 for num == 1, and otherwise a literal struct
+ ///    of `num` i32 values.
+ LogicalResult NVVM::LdMatrixOp::verify() {
+   auto srcPtrTy = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
+   if (srcPtrTy.getAddressSpace() != NVVM::kSharedMemorySpace)
+     return emitOpError("expected source pointer in memory space 3");
+
+   auto numMatrices = getNum();
+   if (numMatrices != 1 && numMatrices != 2 && numMatrices != 4)
+     return emitOpError("expected num attribute to be 1, 2 or 4");
+
+   Type i32 = IntegerType::get(getContext(), 32);
+   if (numMatrices == 1) {
+     if (getType() != i32)
+       return emitOpError("expected destination type is i32");
+     return success();
+   }
+
+   // num is 2 or 4 here.
+   Type expectedTy = LLVM::LLVMStructType::getLiteral(
+       getContext(), SmallVector<Type>(numMatrices, i32));
+   if (getType() != expectedTy)
+     return emitOpError("expected destination type is a structure of ")
+            << numMatrices << " elements of type i32";
+   return success();
+ }
+
+ /// Verifies StMatrixOp: the destination pointer must be in the shared address
+ /// space, and the number of source operands (one per matrix) must be 1, 2
+ /// or 4.
+ LogicalResult NVVM::StMatrixOp::verify() {
+   unsigned ptrSpace =
+       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+   if (ptrSpace != NVVM::kSharedMemorySpace)
+     return emitOpError("expected source pointer in memory space 3");
+
+   switch (getSources().size()) {
+   case 1:
+   case 2:
+   case 4:
+     return success();
+   default:
+     return emitOpError("expected num attribute to be 1, 2 or 4");
+   }
+ }
+
+ /// Returns the only K extent this verifier accepts for a WGMMA whose A
+ /// operand has type `typeA`, or failure if `typeA` is not a multiplicand
+ /// type.
+ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
+   switch (typeA) {
+   case NVVM::WGMMATypes::tf32:
+     return 8;
+   case NVVM::WGMMATypes::f16:
+   case NVVM::WGMMATypes::bf16:
+     return 16;
+   case NVVM::WGMMATypes::s8:
+   case NVVM::WGMMATypes::u8:
+   case NVVM::WGMMATypes::e4m3:
+   case NVVM::WGMMATypes::e5m2:
+     return 32;
+   case NVVM::WGMMATypes::b1:
+     return 256;
+   case NVVM::WGMMATypes::f32:
+   case NVVM::WGMMATypes::s32:
+     // Accumulator-only types have no valid K.
+     break;
+   }
+   return failure();
+ }
+
+ /// Returns success if `typeD += typeA * typeB` is a data-type combination
+ /// this verifier accepts for WGMMA, failure otherwise.
+ ///
+ /// f32 and s32 are accumulator-only types. When they are used as the A type,
+ /// the combination is rejected with failure(). `typeA` comes from the op's
+ /// attributes, so a verifier must diagnose it rather than reach
+ /// llvm_unreachable (a crash in debug builds, UB in release builds).
+ LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
+                                      NVVM::WGMMATypes typeA,
+                                      NVVM::WGMMATypes typeB) {
+   switch (typeA) {
+   case NVVM::WGMMATypes::f16:
+     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+         typeB == NVVM::WGMMATypes::f16)
+       return success();
+     break;
+   case NVVM::WGMMATypes::tf32:
+     if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
+       return success();
+     break;
+   case NVVM::WGMMATypes::u8:
+   case NVVM::WGMMATypes::s8:
+     if (typeD == NVVM::WGMMATypes::s32 &&
+         (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
+       return success();
+     break;
+   case NVVM::WGMMATypes::b1:
+     if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
+       return success();
+     break;
+   case NVVM::WGMMATypes::bf16:
+     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+         typeB == NVVM::WGMMATypes::bf16)
+       return success();
+     break;
+   case NVVM::WGMMATypes::e4m3:
+   case NVVM::WGMMATypes::e5m2:
+     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
+         (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
+       return success();
+     break;
+   case NVVM::WGMMATypes::f32:
+   case NVVM::WGMMATypes::s32:
+     // Not valid multiplicand types: report an unsupported combination.
+     break;
+   }
+   return failure();
+ }
+
+ /// Returns success if `sizeN` is an N extent this verifier accepts for a WGMMA
+ /// whose A operand has type `typeA`.
+ ///
+ /// The candidate lists are compile-time tables rather than SmallVectors, so
+ /// no heap allocation happens on each call. Accumulator-only types (f32, s32)
+ /// yield failure rather than llvm_unreachable.
+ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
+   // Every multiple of 8 in [8, 256].
+   static constexpr int allowedN[] = {8,   16,  24,  32,  40,  48,  56,  64,
+                                      72,  80,  88,  96,  104, 112, 120, 128,
+                                      136, 144, 152, 160, 168, 176, 184, 192,
+                                      200, 208, 216, 224, 232, 240, 248, 256};
+   // 8, 16, 24, 32, then every multiple of 16 up to 256.
+   static constexpr int allowedNshort[] = {8,   16,  24,  32,  48,  64,
+                                           80,  96,  112, 128, 144, 160,
+                                           176, 192, 208, 224, 240, 256};
+   switch (typeA) {
+   case WGMMATypes::f16:
+   case WGMMATypes::tf32:
+   case WGMMATypes::bf16:
+   case WGMMATypes::e4m3:
+   case WGMMATypes::e5m2:
+     if (llvm::is_contained(allowedN, sizeN))
+       return success();
+     break;
+   case WGMMATypes::u8:
+   case WGMMATypes::s8:
+   case WGMMATypes::b1:
+     if (llvm::is_contained(allowedNshort, sizeN))
+       return success();
+     break;
+   case WGMMATypes::f32:
+   case WGMMATypes::s32:
+     // Not valid multiplicand types; no N is allowed.
+     break;
+   }
+   return failure();
+ }
+
+ /// Verifies WgmmaMmaAsyncOp. It checks, in order:
+ ///  - the result struct is homogeneous;
+ ///  - the accumulator type and scale modifiers are valid;
+ ///  - the D/A/B data-type combination is supported;
+ ///  - the M/K/N shape is legal for the input type;
+ ///  - the operand layouts are legal (transpose only for f16/bf16);
+ ///  - the result register count matches the shape;
+ ///  - `satfinite` is only used with an s32 accumulator.
+ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
+   Value outValue = getResults();
+   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
+   if (!stype)
+     return emitOpError() << "expected results to be struct";
+   int outputSize = stype.getBody().size();
+   WGMMATypes typeD = getTypeD();
+   WGMMATypes typeA = getTypeA();
+   WGMMATypes typeB = getTypeB();
+
+   for (Type t : stype.getBody()) {
+     if (t != stype.getBody().front())
+       return emitOpError()
+              << "all elements in struct must be same type but there is " << t;
+   }
+
+   if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
+       typeD != WGMMATypes::s32) {
+     return emitOpError() << "does not support the given output type "
+                          << NVVM::stringifyWGMMATypes(typeD);
+   }
+   if (typeD == WGMMATypes::s32 &&
+       (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
+     return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
+   }
+
+   if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
+     return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
+                          << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
+                          << NVVM::stringifyWGMMATypes(typeB)
+                          << ", it is not supported.";
+   }
+
+   // Check M
+   if (getShape().getM() != 64)
+     return emitOpError() << "shape 'm' must be 64";
+
+   // Check K. A failed FailureOr holds no value, so handle that case before
+   // using the allowed K in the diagnostic.
+   FailureOr<int> allowedK = getAllowedSizeK(typeA);
+   if (failed(allowedK))
+     return emitOpError() << "shape 'k' is not supported for input type "
+                          << NVVM::stringifyWGMMATypes(typeA);
+   if (*allowedK != getShape().getK())
+     return emitOpError() << "shape 'k' must be " << *allowedK
+                          << " for input type "
+                          << NVVM::stringifyWGMMATypes(typeA);
+
+   // Check N
+   if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
+     return emitOpError() << "has input type "
+                          << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
+                          << getShape().getN() << ", it is not supported.";
+   }
+
+   // Check transpose (only available for f16/bf16).
+   // WGMMA expects A in row-major and B in column-major layout. Any other
+   // layout (A col-major, or B row-major) requires transposing that operand.
+   // PTX only permits the transpose (imm-trans-a / imm-trans-b) for f16/bf16
+   // inputs.
+   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
+       (getLayoutA() == mlir::NVVM::MMALayout::col ||
+        getLayoutB() == mlir::NVVM::MMALayout::row)) {
+     return emitOpError()
+            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
+            << " and layout_b = " << stringifyMMALayout(getLayoutB())
+            << " for input types " << stringifyWGMMATypes(typeA) << " and "
+            << stringifyWGMMATypes(typeB)
+            << " requires transpose. However, this is only supported for: "
+            << stringifyMMATypes(MMATypes::f16) << " and "
+            << stringifyMMATypes(MMATypes::bf16);
+   }
+
+   // Check result registers: N/2 for 32-bit accumulators, N/4 for f16.
+   int expectedOutput = 0;
+   if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
+     expectedOutput = getShape().getN() / 2;
+   if (typeD == WGMMATypes::f16)
+     expectedOutput = getShape().getN() / 4;
+   if (outputSize != expectedOutput) {
+     return emitOpError() << "results " << expectedOutput
+                          << ", however output struct has " << outputSize
+                          << " elements";
+   }
+   // Check satfinite (only available for s32 accumulator)
+   if (typeD != WGMMATypes::s32 &&
+       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+           NVVM::MMAIntOverflow::satfinite) {
+     return emitOpError()
+            << " `satfinite` can be only used with s32 accumulator, however "
+               "the current accumulator is "
+            << NVVM::stringifyWGMMATypes(typeD);
+   }
+
+   return success();
+ }
+
+ // Builds the inline-PTX string for this op: a scoped block that sets
+ // predicate `p` from the scale-d operand and then issues
+ // `wgmma.mma_async.sync.aligned.m<M>n<N>k<K>.<D>.<A>.<B>[.satfinite]`.
+ // Operands are referenced as `$i` placeholders.
+ // NOTE(review): the numbering below assumes the asm operands are numbered
+ // in the order getAsmValues() pushes them: E result registers, E in/out
+ // registers, descA, descB, scale-d, [scale-a, scale-b], [layout-a, layout-b],
+ // where E = expectedOutputRegisters. Confirm against the PTX builder.
+ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
+
+ int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
+ // True for both f16 and bf16 inputs: the only types that take transpose
+ // immediates.
+ bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
+
+ StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
+
+ // Accumulator register count: N/4 for f16 accumulators, N/2 otherwise.
+ int expectedOutputRegisters = 0;
+ if (getTypeD() == WGMMATypes::f16)
+ expectedOutputRegisters = getShape().getN() / 4;
+ else
+ expectedOutputRegisters = getShape().getN() / 2;
+
+ std::string ptx;
+ llvm::raw_string_ostream ss(ptx);
+
+ // Predicate p is taken from operand $(2E + 2), which is the scale-d value
+ // under the numbering noted above.
+ ss << "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, $"
+ << ((expectedOutputRegisters * 2) + 2)
+ << ", 0;\n"
+ "wgmma.mma_async.sync.aligned.m"
+ << m << "n" << n << "k" << k << "." << outputTypeName << "."
+ << stringifyWGMMATypes(getTypeA()) << "."
+ << stringifyWGMMATypes(getTypeB());
+ if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
+ NVVM::MMAIntOverflow::satfinite)
+ ss << ".satfinite";
+ // Destination register list: {$0, ..., $(E-1)}.
+ ss << " {";
+ int regCnt = 0;
+ for (; regCnt < expectedOutputRegisters; ++regCnt) {
+ ss << "$" << regCnt;
+ if (regCnt != expectedOutputRegisters - 1)
+ ss << ", ";
+ }
+
+ ss << "},";
+ // Need to map read/write registers correctly.
+ // Skip the E in/out registers: the next operands start at $(2E), which
+ // holds the two matrix descriptors followed by the predicate.
+ regCnt = (regCnt * 2);
+ ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+ // scale-a / scale-b immediates are only emitted for non-s32 accumulators;
+ // getAsmValues() only pushes them in that case too.
+ if (getTypeD() != WGMMATypes::s32) {
+ ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
+ }
+ // Don't add transpose parameters unless needed.
+ // (Only f16/bf16 inputs; those never use an s32 accumulator, so $(2E+3) and
+ // $(2E+4) are always present when these are.)
+ if (isF16) {
+ ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
+ }
+ ss << ";\n"
+ << "}\n";
+ ss.flush();
+ return ptx;
+ }
+
+ // Collects the inline-asm operands for this op together with their
+ // read/write modifiers. The push order is significant: getPtx() refers to
+ // these values by position.
+ // Order: results (write), inouts (read-write), descriptor A, descriptor B,
+ // scale-d, then scale-a/scale-b (non-s32 accumulators only), then
+ // layout-a/layout-b (f16/bf16 inputs only).
+ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
+ if (getResults())
+ asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
+ if (getInouts())
+ asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
+ asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
+ mlir::NVVM::PTXRegisterMod::Read});
+ // Scale immediates: -1 for a negated operand, 1 otherwise.
+ if (getTypeD() != WGMMATypes::s32) {
+ asmValues.push_back(
+ {makeConstantI32(rewriter,
+ getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+ mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back(
+ {makeConstantI32(rewriter,
+ getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
+ mlir::NVVM::PTXRegisterMod::Read});
+ }
+ // Transpose immediates (f16/bf16 only). layout-b is inverted because B's
+ // non-transposed layout is column-major.
+ // NOTE(review): assumes MMALayout::row == 0 and MMALayout::col == 1, so
+ // trans-a = 1 for a col-major A and trans-b = 1 for a row-major B. Confirm
+ // against the enum definition.
+ if (isF16) {
+ asmValues.push_back(
+ {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
+ mlir::NVVM::PTXRegisterMod::Read});
+ asmValues.push_back(
+ {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
+ mlir::NVVM::PTXRegisterMod::Read});
+ }
+ }
+ /// Verifies FenceProxyOp: the `space` attribute is required for, and
+ /// exclusive to, the async_shared proxy kind.
+ LogicalResult NVVM::FenceProxyOp::verify() {
+   const bool isAsyncShared = getKind() == NVVM::ProxyKind::async_shared;
+   const bool hasSpace = getSpace().has_value();
+   if (isAsyncShared && !hasSpace)
+     return emitOpError() << "async_shared fence requires space attribute";
+   if (!isAsyncShared && hasSpace)
+     return emitOpError() << "only async_shared fence can have space attribute";
+   return success();
+ }
+
+ /// Verifies SetMaxRegisterOp: the requested register count must be a
+ /// multiple of 8 within [24, 256].
+ LogicalResult NVVM::SetMaxRegisterOp::verify() {
+   auto regCount = getRegCount();
+   if (regCount % 8 != 0)
+     return emitOpError("new register size must be multiple of 8");
+   if (regCount < 24 || regCount > 256)
+     return emitOpError("new register size must be in between 24 to 256");
+   return success();
+ }
+
+ /// Verifies BarrierOp: a thread-count operand may only be given together
+ /// with an explicit barrier id.
+ LogicalResult NVVM::BarrierOp::verify() {
+   if (!getNumberOfThreads() || getBarrierId())
+     return success();
+   return emitOpError(
+       "barrier id is missing, it should be set between 0 to 15");
+ }
+
+//===----------------------------------------------------------------------===//
+// NVVMDialect initialization, type parsing, and registration.
+//===----------------------------------------------------------------------===//
+
+ // TODO: This should be the llvm.nvvm dialect once this is supported.
+ // Registers the dialect's ODS-generated operations and attribute
+ // definitions, allows unregistered ops, and declares interfaces whose
+ // implementations are promised to be attached elsewhere.
+ void NVVMDialect::initialize() {
+ // The operation list is produced by the generated NVVMOps.cpp.inc.
+ addOperations<
+ #define GET_OP_LIST
+ #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
+ >();
+ // The attribute list is produced by the generated NVVMOpsAttributes.cpp.inc.
+ addAttributes<
+ #define GET_ATTRDEF_LIST
+ #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
+ >();
+
+ // Support unknown operations because not all NVVM operations are
+ // registered.
+ allowUnknownOperations();
+ // Promised interfaces: the ConvertToLLVM patterns for this dialect and the
+ // gpu::TargetAttrInterface for NVVMTargetAttr. These are presumably
+ // registered by separate extension code — TODO confirm where.
+ declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
+ declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
+ }
+
+ /// Verifies NVVM discardable attributes attached to an operation:
+ ///  - the kernel marker may only appear on LLVM functions;
+ ///  - maxntid / reqntid must be non-empty i32 arrays of at most 3 entries;
+ ///  - minctasm / maxnreg must be integer attributes.
+ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
+                                                     NamedAttribute attr) {
+   StringAttr attrName = attr.getName();
+
+   if (attrName == NVVMDialect::getKernelFuncAttrName() &&
+       !isa<LLVM::LLVMFuncOp>(op))
+     return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
+                            << "' attribute attached to unexpected op";
+
+   const bool isThreadDimsAttr = attrName == NVVMDialect::getMaxntidAttrName() ||
+                                 attrName == NVVMDialect::getReqntidAttrName();
+   if (isThreadDimsAttr) {
+     auto dims = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
+     if (!dims || dims.empty() || dims.size() > 3)
+       return op->emitError()
+              << "'" << attrName
+              << "' attribute must be integer array with maximum 3 index";
+   }
+
+   const bool isIntegerAttrName =
+       attrName == NVVMDialect::getMinctasmAttrName() ||
+       attrName == NVVMDialect::getMaxnregAttrName();
+   if (isIntegerAttrName && !llvm::isa<IntegerAttr>(attr.getValue()))
+     return op->emitError()
+            << "'" << attrName << "' attribute must be integer constant";
+
+   return success();
+ }
+
+ /// Verifies NVVM argument attributes on function-like ops. Only the
+ /// grid-constant attribute is checked: it must be a unit attribute, only on
+ /// kernel functions, and only on arguments that also carry `byval`.
+ LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+                                                     unsigned regionIndex,
+                                                     unsigned argIndex,
+                                                     NamedAttribute argAttr) {
+   auto funcOp = dyn_cast<FunctionOpInterface>(op);
+   if (!funcOp)
+     return success();
+
+   StringAttr attrName = argAttr.getName();
+   if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+     if (!op->hasAttr(NVVMDialect::getKernelFuncAttrName()))
+       return op->emitError()
+              << "'" << attrName
+              << "' attribute must be present only on kernel arguments";
+     if (!isa<UnitAttr>(argAttr.getValue()))
+       return op->emitError() << "'" << attrName << "' must be a unit attribute";
+     if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+       return op->emitError()
+              << "'" << attrName
+              << "' attribute requires the argument to also have attribute '"
+              << LLVM::LLVMDialect::getByValAttrName() << "'";
+   }
+   return success();
+ }
+
+//===----------------------------------------------------------------------===//
+// NVVM target attribute.
+//===----------------------------------------------------------------------===//
+ /// Verifies NVVMTargetAttr parameters:
+ ///  - optLevel is in [0, 3];
+ ///  - triple and chip are non-empty;
+ ///  - every entry of `files` (if present) is a StringAttr.
+ /// Each failed check emits exactly one diagnostic.
+ LogicalResult
+ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                        int optLevel, StringRef triple, StringRef chip,
+                        StringRef features, DictionaryAttr flags,
+                        ArrayAttr files) {
+   auto fail = [&](const char *message) -> LogicalResult {
+     emitError() << message;
+     return failure();
+   };
+
+   if (optLevel < 0 || optLevel > 3)
+     return fail("The optimization level must be a number between 0 and 3.");
+   if (triple.empty())
+     return fail("The target triple cannot be empty.");
+   if (chip.empty())
+     return fail("The target chip cannot be empty.");
+   if (files) {
+     const bool allStrings = llvm::all_of(files, [](::mlir::Attribute attr) {
+       return attr && mlir::isa<StringAttr>(attr);
+     });
+     if (!allStrings)
+       return fail("All the elements in the `link` array must be strings.");
+   }
+   return success();
+ }
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
>From 740ba25c18140a87f5311725daf9ae74663dc067 Mon Sep 17 00:00:00 2001
From: bangyu shen <94283495+shubaoyu2 at users.noreply.github.com>
Date: Wed, 3 Jul 2024 17:08:00 +0800
Subject: [PATCH 4/4] change check cases when ab cannot be transposed in wgmma
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 036a9a15af838..48f44165ccc58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -880,7 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
// Check transpose (only available for f16/bf16)
if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
(getLayoutA() == mlir::NVVM::MMALayout::col ||
- getLayoutB() == mlir::NVVM::MMALayout::col)) {
+ getLayoutB() == mlir::NVVM::MMALayout::row)) {
return emitOpError()
<< "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
<< " and layout_b = " << stringifyMMALayout(getLayoutB())
More information about the Mlir-commits
mailing list