[llvm] [LLVM][NVPTX]: Add intrinsic for setmaxnreg (PR #77289)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 8 01:34:59 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-support
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
This patch adds an intrinsic for the `setmaxnreg` PTX instruction.
* PTX Doc link for this instruction: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg
* The first argument, `i32 flags`, is a compile-time constant that selects the modifier: `.inc` (flags = 0) or `.dec` (flags = 1).
* The second argument, an immediate value, specifies the actual absolute register count for the instruction.
* The `setmaxnreg` instruction is available in SM90a, so this patch adds a `hasSM90a` predicate for use in the NVPTX backend.
* lit tests are added to verify the lowering of the intrinsic.
The modifiers are encoded into `flags` so that the same intrinsic can be extended with more options in the future, without having to add separate intrinsics.
The flags are defined in Support/NVVMIntrinsicFlags.h, to facilitate usage by both upstream and downstream clients.
---
Full diff: https://github.com/llvm/llvm-project/pull/77289.diff
7 Files Affected:
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+6)
- (added) llvm/include/llvm/Support/NVVMIntrinsicFlags.h (+39)
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+20)
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+4)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+13)
- (added) llvm/test/CodeGen/NVPTX/setmaxnreg.ll (+15)
``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 6fd8e80013cee5..81c56ca3c6ee03 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -4710,4 +4710,10 @@ def int_nvvm_is_explicit_cluster
[IntrNoMem, IntrSpeculatable, NoUndef<RetIndex>],
"llvm.nvvm.is_explicit_cluster">;
+// Setmaxnreg intrinsic
+def int_nvvm_setmaxnreg_sync_aligned_u32
+ : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_i32_ty],
+ [IntrConvergent, IntrNoMem, IntrHasSideEffects, ImmArg<ArgIndex<1>>],
+ "llvm.nvvm.setmaxnreg.sync.aligned.u32">;
+
} // let TargetPrefix = "nvvm"
diff --git a/llvm/include/llvm/Support/NVVMIntrinsicFlags.h b/llvm/include/llvm/Support/NVVMIntrinsicFlags.h
new file mode 100644
index 00000000000000..23c265831ae4e2
--- /dev/null
+++ b/llvm/include/llvm/Support/NVVMIntrinsicFlags.h
@@ -0,0 +1,39 @@
+//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains the definitions of the enumerations and flags
+/// associated with NVVM Intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+#define LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+
+#include <stdint.h>
+
+namespace llvm {
+namespace nvvm {
+
+enum SetMaxNRegAction {
+ ACTION_INC = 0,
+ ACTION_DEC = 1,
+};
+
+typedef union {
+ uint32_t V;
+ struct {
+ uint32_t Action : 1; // inc(0) or dec(1)
+ uint32_t reserved : 31;
+ } U;
+} SetMaxNRegFlags;
+
+} // namespace nvvm
+} // namespace llvm
+
+#endif // LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index b7a20c351f5ff6..c5e575e805ab93 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -20,6 +20,7 @@
#include "llvm/MC/MCSymbol.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/NVVMIntrinsicFlags.h"
#include <cctype>
using namespace llvm;
@@ -340,3 +341,22 @@ void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
break;
}
}
+
+void NVPTXInstPrinter::printSetMaxNRegActionFlag(const MCInst *MI, int OpNum,
+ raw_ostream &O,
+ const char *Modifier) {
+ nvvm::SetMaxNRegFlags Flags;
+ Flags.V = (int)MI->getOperand(OpNum).getImm();
+
+ using Action = nvvm::SetMaxNRegAction;
+ switch (Flags.U.Action) {
+ case Action::ACTION_INC:
+ O << ".inc";
+ break;
+ case Action::ACTION_DEC:
+ O << ".dec";
+ break;
+ default:
+ llvm_unreachable("Invalid action flag for setmaxnreg intrinsic");
+ }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index e6954f861cd10e..234d5f139ad496 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -49,6 +49,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
raw_ostream &O, const char *Modifier = nullptr);
void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
+ void printSetMaxNRegActionFlag(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 13665985f52eba..aa02814061da4f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -164,6 +164,10 @@ def True : Predicate<"true">;
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
+// Explicit records for arch-accelerated SM versions
+def hasSM90a : Predicate<"Subtarget->getSmVersion() == 90"
+ "&& Subtarget->getFullSmVersion() == 901">;
+
// non-sync shfl instructions are not available on sm_70+ in PTX6.4+
def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
"&& Subtarget->getPTXVersion() >= 64)">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 85eae44f349aa3..9594d47fedf5ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6727,3 +6727,16 @@ def is_explicit_cluster: NVPTXInst<(outs Int1Regs:$d), (ins),
"mov.pred\t$d, %is_explicit_cluster;",
[(set Int1Regs:$d, (int_nvvm_is_explicit_cluster))]>,
Requires<[hasSM<90>, hasPTX<78>]>;
+
+// setmaxnreg intrinsic
+def SetMaxNRegFlags : Operand<i32> {
+ let PrintMethod = "printSetMaxNRegActionFlag";
+}
+
+let isConvergent = true in {
+def INT_SET_MAXNREG : NVPTXInst<(outs),
+ (ins SetMaxNRegFlags:$flags, i32imm:$reg_count),
+ "setmaxnreg${flags:action}.sync.aligned.u32 $reg_count;",
+ [(int_nvvm_setmaxnreg_sync_aligned_u32 imm:$flags, timm:$reg_count)]>,
+ Requires<[hasSM90a, hasPTX<80>]>;
+} // isConvergent
diff --git a/llvm/test/CodeGen/NVPTX/setmaxnreg.ll b/llvm/test/CodeGen/NVPTX/setmaxnreg.ll
new file mode 100644
index 00000000000000..25698088e1cf7b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/setmaxnreg.ll
@@ -0,0 +1,15 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90a -mattr=+ptx80| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_90a -mattr=+ptx80| %ptxas-verify -arch=sm_90a %}
+
+declare void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 %flags, i32 %reg_count)
+
+; CHECK-LABEL: test_set_maxn_reg
+define void @test_set_maxn_reg() {
+ ; CHECK: setmaxnreg.inc.sync.aligned.u32 96;
+ call void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 0, i32 96)
+
+ ; CHECK: setmaxnreg.dec.sync.aligned.u32 64;
+ call void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 1, i32 64)
+
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/77289
More information about the llvm-commits
mailing list