[Mlir-commits] [mlir] d5fb4c0 - [MLIR][NVVM] Enable nvvm intrinsics import to LLVMIR (#68843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 20:31:59 PST 2023


Author: Ivan R. Ivanov
Date: 2023-12-12T13:31:55+09:00
New Revision: d5fb4c0f118b47db74233af2d99ae075e1dbe148

URL: https://github.com/llvm/llvm-project/commit/d5fb4c0f118b47db74233af2d99ae075e1dbe148
DIFF: https://github.com/llvm/llvm-project/commit/d5fb4c0f118b47db74233af2d99ae075e1dbe148.diff

LOG: [MLIR][NVVM] Enable nvvm intrinsics import to LLVMIR   (#68843)

Co-authored-by: Tobias Gysi <tobias.gysi at nextsilicon.com>
Co-authored-by: Christian Ulmann <christianulmann at gmail.com>

Added: 
    mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h
    mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
    mlir/test/Target/LLVMIR/Import/nvvmir.ll

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/Target/LLVMIR/Dialect/All.h
    mlir/lib/Target/LLVMIR/Dialect/NVVM/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 64de028c7fe406..8e41fcc05a161e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -56,6 +56,8 @@ add_mlir_dialect(NVVMOps nvvm)
 add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
 set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
+mlir_tablegen(NVVMFromLLVMIRConversions.inc -gen-intr-from-llvmir-conversions)
+mlir_tablegen(NVVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
 mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index 5dfc15afb75931..0b37e23e45118b 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -22,6 +22,7 @@
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
@@ -74,6 +75,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry &registry) {
 static inline void
 registerAllFromLLVMIRTranslations(DialectRegistry &registry) {
   registerLLVMDialectImport(registry);
+  registerNVVMDialectImport(registry);
 }
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h
new file mode 100644
index 00000000000000..02ee83284dd331
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h
@@ -0,0 +1,31 @@
+//===- LLVMIRToNVVMTranslation.h - LLVM IR to NVVM Dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for LLVM IR to NVVM dialect translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_NVVM_LLVMIRTONVVMTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_NVVM_LLVMIRTONVVMTRANSLATION_H
+
+namespace mlir {
+
+// Forward declarations keep this header free of heavy MLIR includes.
+class DialectRegistry;
+class MLIRContext;
+
+/// Registers the NVVM dialect in the given registry and attaches the dialect
+/// interface that imports supported LLVM IR NVVM intrinsic calls as NVVM
+/// dialect operations.
+void registerNVVMDialectImport(DialectRegistry &registry);
+
+/// Convenience overload: performs the same registration into a fresh registry
+/// and appends that registry to the given context.
+void registerNVVMDialectImport(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_NVVM_LLVMIRTONVVMTRANSLATION_H

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/NVVM/CMakeLists.txt
index 9f3935b0c3f473..a90de157981612 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/CMakeLists.txt
@@ -1,3 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+  LLVMIRToNVVMTranslation.cpp
+  NVVMToLLVMIRTranslation.cpp
+  )
+
+add_mlir_translation_library(MLIRLLVMIRToNVVMTranslation
+  LLVMIRToNVVMTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRNVVMDialect
+  MLIRSupport
+  MLIRTargetLLVMIRImport
+  )
+
 add_mlir_translation_library(MLIRNVVMToLLVMIRTranslation
   NVVMToLLVMIRTranslation.cpp
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
new file mode 100644
index 00000000000000..855abc12a909ef
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -0,0 +1,93 @@
+//===- LLVMIRToNVVMTranslation.cpp - Translate LLVM IR to NVVM dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between LLVM IR and the MLIR NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Target/LLVMIR/ModuleImport.h"
+
+#include "llvm/IR/IntrinsicsNVPTX.h"
+
+using namespace mlir;
+using namespace mlir::NVVM;
+
+/// Checks whether `id` names an LLVM IR intrinsic that has an NVVM dialect
+/// counterpart, i.e. whether it appears in the TableGen-generated list of
+/// convertible intrinsic identifiers.
+static bool isConvertibleIntrinsic(llvm::Intrinsic::ID id) {
+  // Function-local static: the lookup set is built once, on first call.
+  static const DenseSet<unsigned> knownIntrinsicIds = {
+#include "mlir/Dialect/LLVMIR/NVVMConvertibleLLVMIRIntrinsics.inc"
+  };
+  return knownIntrinsicIds.count(id) != 0;
+}
+
+/// Exposes the TableGen-generated identifiers of all LLVM IR intrinsics that
+/// have an MLIR NVVM dialect counterpart, as a stable array view.
+static ArrayRef<unsigned> getSupportedIntrinsicsImpl() {
+  // Function-local static so the returned view never dangles.
+  static const SmallVector<unsigned> intrinsicIdList = {
+#include "mlir/Dialect/LLVMIR/NVVMConvertibleLLVMIRIntrinsics.inc"
+  };
+  return ArrayRef<unsigned>(intrinsicIdList);
+}
+
+/// Converts the LLVM intrinsic call `inst` into the equivalent MLIR NVVM
+/// dialect operation if a conversion exists. Returns failure otherwise.
+static LogicalResult convertIntrinsicImpl(OpBuilder &odsBuilder,
+                                          llvm::CallInst *inst,
+                                          LLVM::ModuleImport &moduleImport) {
+  llvm::Intrinsic::ID intrinsicID = inst->getIntrinsicID();
+
+  // Bail out early for intrinsics without an NVVM dialect counterpart.
+  if (!isConvertibleIntrinsic(intrinsicID))
+    return failure();
+
+  // Copy the call arguments into an LLVM operands array reference for the
+  // conversion. NOTE: the included generated code presumably refers to the
+  // locals above (`odsBuilder`, `inst`, `moduleImport`, `intrinsicID`,
+  // `llvmOperands`) by name, so keep these identifiers stable.
+  SmallVector<llvm::Value *> args(inst->args());
+  ArrayRef<llvm::Value *> llvmOperands(args);
+#include "mlir/Dialect/LLVMIR/NVVMFromLLVMIRConversions.inc"
+
+  return failure();
+}
+
+namespace {
+
+/// Dialect interface, attached to the NVVM dialect, that imports LLVM IR
+/// intrinsic calls as MLIR NVVM dialect operations.
+class NVVMDialectLLVMIRImportInterface : public LLVMImportDialectInterface {
+public:
+  using LLVMImportDialectInterface::LLVMImportDialectInterface;
+
+  /// Converts the LLVM intrinsic to an MLIR NVVM dialect operation if a
+  /// conversion exists. Returns failure otherwise.
+  LogicalResult convertIntrinsic(OpBuilder &builder, llvm::CallInst *inst,
+                                 LLVM::ModuleImport &moduleImport) const final {
+    return convertIntrinsicImpl(builder, inst, moduleImport);
+  }
+
+  /// Returns the list of LLVM IR intrinsic identifiers that are convertible to
+  /// MLIR NVVM dialect operations.
+  ArrayRef<unsigned> getSupportedIntrinsics() const final {
+    return getSupportedIntrinsicsImpl();
+  }
+};
+
+} // namespace
+
+void mlir::registerNVVMDialectImport(DialectRegistry &registry) {
+  // Make the NVVM dialect available, then register an extension that attaches
+  // the LLVM IR import interface to it.
+  registry.insert<NVVM::NVVMDialect>();
+  auto attachImportInterface = +[](MLIRContext *, NVVM::NVVMDialect *nvvm) {
+    nvvm->addInterfaces<NVVMDialectLLVMIRImportInterface>();
+  };
+  registry.addExtension(attachImportInterface);
+}
+
+void mlir::registerNVVMDialectImport(MLIRContext &context) {
+  // Build a throwaway registry holding only the NVVM import registration and
+  // merge it into the context.
+  DialectRegistry nvvmImportRegistry;
+  registerNVVMDialectImport(nvvmImportRegistry);
+  context.appendDialectRegistry(nvvmImportRegistry);
+}

diff  --git a/mlir/test/Target/LLVMIR/Import/nvvmir.ll b/mlir/test/Target/LLVMIR/Import/nvvmir.ll
new file mode 100644
index 00000000000000..e4a8773e2dd806
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nvvmir.ll
@@ -0,0 +1,355 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK-LABEL: @nvvm_special_regs
+define i32 @nvvm_special_regs() {
+  ; CHECK: = nvvm.read.ptx.sreg.tid.x : i32
+  %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.tid.y : i32
+  %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.tid.z : i32
+  %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.ntid.x : i32
+  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.ntid.y : i32
+  %5 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.ntid.z : i32
+  %6 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.ctaid.x : i32
+  %7 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.ctaid.y : i32
+  %8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.ctaid.z : i32
+  %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.nctaid.x : i32
+  %10 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.nctaid.y : i32
+  %11 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.nctaid.z : i32
+  %12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.warpsize : i32
+  %13 = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
+  ; CHECK: = nvvm.read.ptx.sreg.laneid : i32
+  %14 = call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+  ; CHECK: = nvvm.read.ptx.sreg.clusterid.x : i32
+  %15 = call i32 @llvm.nvvm.read.ptx.sreg.clusterid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.clusterid.y : i32
+  %16 = call i32 @llvm.nvvm.read.ptx.sreg.clusterid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.clusterid.z : i32
+  %17 = call i32 @llvm.nvvm.read.ptx.sreg.clusterid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.nclusterid.x : i32
+  %18 = call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.nclusterid.y : i32
+  %19 = call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.nclusterid.z : i32
+  %20 = call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.ctaid.x : i32
+  %21 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.ctaid.y : i32
+  %22 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.ctaid.z : i32
+  %23 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.nctaid.x : i32
+  %24 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid.x()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.nctaid.y : i32
+  %25 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid.y()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.nctaid.z : i32
+  %26 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid.z()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.ctarank : i32
+  %27 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank()
+  ; CHECK: = nvvm.read.ptx.sreg.cluster.nctarank : i32
+  %28 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank()
+  ret i32 %1
+}
+
+; CHECK-LABEL: @nvvm_rcp
+define float @nvvm_rcp(float %0) {
+  ; CHECK: = nvvm.rcp.approx.ftz.f %{{.*}} : f32
+  %2 = call float @llvm.nvvm.rcp.approx.ftz.f(float %0)
+  ret float %2
+}
+
+; TODO: Support the intrinsics below once they derive from NVVM_IntrOp rather than from NVVM_Op.
+
+; define void @llvm_nvvm_barrier0() {
+;   call void @llvm.nvvm.barrier0()
+;   ret void
+; }
+;
+; define i32 @nvvm_shfl(i32 %0, i32 %1, i32 %2, i32 %3, float %4) {
+;   %6 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %7 = call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %0, float %4, i32 %1, i32 %2)
+;   %8 = call i32 @llvm.nvvm.shfl.sync.up.i32(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %9 = call float @llvm.nvvm.shfl.sync.up.f32(i32 %0, float %4, i32 %1, i32 %2)
+;   %10 = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %11 = call float @llvm.nvvm.shfl.sync.down.f32(i32 %0, float %4, i32 %1, i32 %2)
+;   %12 = call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %13 = call float @llvm.nvvm.shfl.sync.idx.f32(i32 %0, float %4, i32 %1, i32 %2)
+;   ret i32 %6
+; }
+;
+; define { i32, i1 } @nvvm_shfl_pred(i32 %0, i32 %1, i32 %2, i32 %3, float %4) {
+;   %6 = call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %7 = call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %0, float %4, i32 %1, i32 %2)
+;   %8 = call { i32, i1 } @llvm.nvvm.shfl.sync.up.i32p(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %9 = call { float, i1 } @llvm.nvvm.shfl.sync.up.f32p(i32 %0, float %4, i32 %1, i32 %2)
+;   %10 = call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %11 = call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %0, float %4, i32 %1, i32 %2)
+;   %12 = call { i32, i1 } @llvm.nvvm.shfl.sync.idx.i32p(i32 %0, i32 %3, i32 %1, i32 %2)
+;   %13 = call { float, i1 } @llvm.nvvm.shfl.sync.idx.f32p(i32 %0, float %4, i32 %1, i32 %2)
+;   ret { i32, i1 } %6
+; }
+;
+; define i32 @nvvm_vote(i32 %0, i1 %1) {
+;   %3 = call i32 @llvm.nvvm.vote.ballot.sync(i32 %0, i1 %1)
+;   ret i32 %3
+; }
+;
+; define { float, float, float, float, float, float, float, float } @nvvm_mma_mn8n8k4_row_col_f32_f32(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, float %4, float %5, float %6, float %7, float %8, float %9, float %10, float %11) {
+;   %13 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, float %4, float %5, float %6, float %7, float %8, float %9, float %10, float %11)
+;   ret { float, float, float, float, float, float, float, float } %13
+; }
+;
+; define { <2 x half>, <2 x half> } @nvvm_mma_m16n8k16_f16_f16(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, <2 x half> %6, <2 x half> %7) {
+;   %9 = call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, <2 x half> %6, <2 x half> %7)
+;   ret { <2 x half>, <2 x half> } %9
+; }
+;
+; define { float, float, float, float } @nvvm_mma_m16n8k16_f32_f16(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, <2 x half> %6, <2 x half> %7) {
+;   %9 = call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, <2 x half> %6, <2 x half> %7)
+;   ret { float, float, float, float } %9
+; }
+;
+; define { <2 x half>, <2 x half> } @nvvm_mma_m16n8k16_f16_f32(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, float %6, float %7, float %8, float %9) {
+;   %11 = call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, float %6, float %7, float %8, float %9)
+;   ret { <2 x half>, <2 x half> } %11
+; }
+;
+; define { float, float, float, float } @nvvm_mma_m16n8k16_f32_f32(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, float %6, float %7, float %8, float %9) {
+;   %11 = call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, float %6, float %7, float %8, float %9)
+;   ret { float, float, float, float } %11
+; }
+;
+; define { i32, i32, i32, i32 } @nvvm_mma_m16n8k16_s8_s8(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6) {
+;   %8 = call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6)
+;   ret { i32, i32, i32, i32 } %8
+; }
+;
+; define { i32, i32, i32, i32 } @nvvm_mma_m16n8k16_s8_u8(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6) {
+;   %8 = call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6)
+;   ret { i32, i32, i32, i32 } %8
+; }
+;
+; define { i32, i32, i32, i32 } @nvvm_mma_m16n8k128_b1_b1(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6) {
+;   %8 = call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6)
+;   ret { i32, i32, i32, i32 } %8
+; }
+;
+; define { i32, i32, i32, i32 } @nvvm_mma_m16n8k32_s4_s4(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6) {
+;   %8 = call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6)
+;   ret { i32, i32, i32, i32 } %8
+; }
+;
+; define { double, double } @nvvm_mma_m8n8k4_f64_f64(double %0, double %1, double %2, double %3) {
+;   %5 = call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64(double %0, double %1, double %2, double %3)
+;   ret { double, double } %5
+; }
+;
+; define { float, float, float, float } @nvvm_mma_m16n8k4_tf32_f32(i32 %0, i32 %1, i32 %2, float %3, float %4, float %5, float %6) {
+;   %8 = call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 %0, i32 %1, i32 %2, float %3, float %4, float %5, float %6)
+;   ret { float, float, float, float } %8
+; }
+;
+; define void @gpu_wmma_load_op(ptr addrspace(3) %0, i32 %1) {
+;   %3 = call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3(ptr addrspace(3) %0, i32 %1)
+;   ret void
+; }
+;
+; define void @gpu_wmma_store_op(ptr addrspace(3) %0, i32 %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) {
+;   call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3(ptr addrspace(3) %0, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, i32 %1)
+;   ret void
+; }
+;
+; define void @gpu_wmma_mma_op(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, <2 x half> %6, <2 x half> %7, <2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %19) {
+;   %21 = call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> %0, <2 x half> %1, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5, <2 x half> %6, <2 x half> %7, <2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %19)
+;   ret void
+; }
+;
+; define void @nvvm_wmma_load_tf32(ptr %0, i32 %1) {
+;   %3 = call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0(ptr %0, i32 %1)
+;   ret void
+; }
+;
+; define void @nvvm_wmma_mma(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, float %8, float %9, float %10, float %11, float %12, float %13, float %14, float %15) {
+;   %17 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, float %8, float %9, float %10, float %11, float %12, float %13, float %14, float %15)
+;   ret void
+; }
+;
+; define void @cp_async(ptr addrspace(3) %0, ptr addrspace(1) %1) {
+;   call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %0, ptr addrspace(1) %1)
+;   call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %0, ptr addrspace(1) %1)
+;   call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %0, ptr addrspace(1) %1)
+;   call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %0, ptr addrspace(1) %1)
+;   call void @llvm.nvvm.cp.async.commit.group()
+;   call void @llvm.nvvm.cp.async.wait.group(i32 0)
+;   ret void
+; }
+;
+; define void @ld_matrix(ptr addrspace(3) %0) {
+;   %2 = call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %0)
+;   %3 = call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %0)
+;   %4 = call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %0)
+;   %5 = call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %0)
+;   %6 = call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %0)
+;   %7 = call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %0)
+;   ret void
+; }
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.warpsize()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.clusterid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.clusterid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.clusterid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nclusterid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nclusterid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.nclusterid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid.x()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid.y()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid.z()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank()
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank()
+
+declare float @llvm.nvvm.rcp.approx.ftz.f(float)
+
+declare void @llvm.nvvm.barrier0()
+
+declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32)
+
+declare float @llvm.nvvm.shfl.sync.bfly.f32(i32, float, i32, i32)
+
+declare i32 @llvm.nvvm.shfl.sync.up.i32(i32, i32, i32, i32)
+
+declare float @llvm.nvvm.shfl.sync.up.f32(i32, float, i32, i32)
+
+declare i32 @llvm.nvvm.shfl.sync.down.i32(i32, i32, i32, i32)
+
+declare float @llvm.nvvm.shfl.sync.down.f32(i32, float, i32, i32)
+
+declare i32 @llvm.nvvm.shfl.sync.idx.i32(i32, i32, i32, i32)
+
+declare float @llvm.nvvm.shfl.sync.idx.f32(i32, float, i32, i32)
+
+declare { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32, i32, i32, i32)
+
+declare { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32, float, i32, i32)
+
+declare { i32, i1 } @llvm.nvvm.shfl.sync.up.i32p(i32, i32, i32, i32)
+
+declare { float, i1 } @llvm.nvvm.shfl.sync.up.f32p(i32, float, i32, i32)
+
+declare { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32, i32, i32, i32)
+
+declare { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32, float, i32, i32)
+
+declare { i32, i1 } @llvm.nvvm.shfl.sync.idx.i32p(i32, i32, i32, i32)
+
+declare { float, i1 } @llvm.nvvm.shfl.sync.idx.f32p(i32, float, i32, i32)
+
+declare i32 @llvm.nvvm.vote.ballot.sync(i32, i1)
+
+declare { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32(<2 x half>, <2 x half>, <2 x half>, <2 x half>, float, float, float, float, float, float, float, float)
+
+declare { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16(<2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>)
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16(<2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>)
+
+declare { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32(<2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, float, float, float, float)
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32(<2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, float, float, float, float)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8(i32, i32, i32, i32, i32, i32, i32)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8(i32, i32, i32, i32, i32, i32, i32)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1(i32, i32, i32, i32, i32, i32, i32)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4(i32, i32, i32, i32, i32, i32, i32)
+
+declare { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64(double, double, double, double)
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float)
+
+declare { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3(ptr addrspace(3) nocapture readonly, i32)
+
+declare void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3(ptr addrspace(3) nocapture writeonly, <2 x half>, <2 x half>, <2 x half>, <2 x half>, i32)
+
+declare { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0(ptr nocapture readonly, i32)
+
+declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32, i32, i32, i32, i32, i32, i32, i32, float, float, float, float, float, float, float, float)
+
+declare void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) noalias writeonly, ptr addrspace(1) noalias readonly)
+
+declare void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) noalias writeonly, ptr addrspace(1) noalias readonly)
+
+declare void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) noalias writeonly, ptr addrspace(1) noalias readonly)
+
+declare void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) noalias writeonly, ptr addrspace(1) noalias readonly)
+
+declare void @llvm.nvvm.cp.async.commit.group()
+
+declare void @llvm.nvvm.cp.async.wait.group(i32 immarg)
+
+declare i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) nocapture readonly)
+
+declare { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) nocapture readonly)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) nocapture readonly)
+
+declare i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) nocapture readonly)
+
+declare { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) nocapture readonly)
+
+declare { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) nocapture readonly)


        


More information about the Mlir-commits mailing list