[llvm] [LLVM][NVPTX]: Add intrinsic for setmaxnreg (PR #77289)

Durgadoss R via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 8 01:34:30 PST 2024


https://github.com/durga4github created https://github.com/llvm/llvm-project/pull/77289

This patch adds an intrinsic for the `setmaxnreg` PTX instruction.
* PTX Doc link for this instruction: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

* The first argument, an `i32 flags`, is a compile-time constant, indicating the inc(flag=0)/dec(flag=1) modifiers.
* The second argument, an immediate value, specifies the actual absolute register count for the instruction.
* The `setmaxnreg` instruction is available only on sm_90a, so this patch adds a 'hasSM90a' predicate for use in the NVPTX backend.
* lit tests are added to verify the lowering of the intrinsic.

The modifiers are encoded into `flags` so that the same intrinsic can be extended with more options in the future without having to add separate intrinsics.

The flags are defined in Support/NVVMIntrinsicFlags.h, to facilitate usage by both upstream and downstream clients.

>From a10cfe5a83b2c5f86971fa3820e4349e7eab1c4d Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Fri, 5 Jan 2024 23:36:12 +0530
Subject: [PATCH] [LLVM][NVPTX]: Add intrinsic for setmaxnreg

This patch adds an intrinsic for the setmaxnreg instruction.
* PTX Doc link for this instruction:
  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

* The first argument, an `i32 flags`, is a compile-time
  constant, indicating the inc(flag=0)/dec(flag=1) modifiers.
* The second argument, an immediate value, specifies the actual
  absolute register count for the instruction.
* The `setmaxnreg` instruction is available only on sm_90a,
  so this patch adds a 'hasSM90a' predicate for use in
  the NVPTX backend.
* lit tests are added to verify the lowering of the intrinsic.

The modifiers are encoded into `flags` so that the same
intrinsic can be extended with more options in the future
without having to add separate intrinsics.

The flags are defined in Support/NVVMIntrinsicFlags.h, to
facilitate usage by both upstream and downstream clients.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 llvm/include/llvm/IR/IntrinsicsNVVM.td        |  6 +++
 .../include/llvm/Support/NVVMIntrinsicFlags.h | 39 +++++++++++++++++++
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   | 20 ++++++++++
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |  2 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  4 ++
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td      | 13 +++++++
 llvm/test/CodeGen/NVPTX/setmaxnreg.ll         | 15 +++++++
 7 files changed, 99 insertions(+)
 create mode 100644 llvm/include/llvm/Support/NVVMIntrinsicFlags.h
 create mode 100644 llvm/test/CodeGen/NVPTX/setmaxnreg.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 6fd8e80013cee5..81c56ca3c6ee03 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -4710,4 +4710,10 @@ def int_nvvm_is_explicit_cluster
               [IntrNoMem, IntrSpeculatable, NoUndef<RetIndex>],
               "llvm.nvvm.is_explicit_cluster">;
 
+// setmaxnreg: arg0 = flags (NVVMIntrinsicFlags.h), arg1 = ImmArg reg count.
+def int_nvvm_setmaxnreg_sync_aligned_u32
+  : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_i32_ty],
+              [IntrConvergent, IntrNoMem, IntrHasSideEffects, ImmArg<ArgIndex<1>>],
+              "llvm.nvvm.setmaxnreg.sync.aligned.u32">;
+
 } // let TargetPrefix = "nvvm"
diff --git a/llvm/include/llvm/Support/NVVMIntrinsicFlags.h b/llvm/include/llvm/Support/NVVMIntrinsicFlags.h
new file mode 100644
index 00000000000000..23c265831ae4e2
--- /dev/null
+++ b/llvm/include/llvm/Support/NVVMIntrinsicFlags.h
@@ -0,0 +1,39 @@
+//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains the definitions of the enumerations and flags
+/// associated with NVVM Intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+#define LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+
+#include <stdint.h>
+
+namespace llvm {
+namespace nvvm {
+
+// setmaxnreg action modifier: flags value 0 selects .inc, 1 selects .dec.
+enum SetMaxNRegAction { ACTION_INC = 0, ACTION_DEC = 1 };
+
+// Bit-field view of the i32 flags. Bit-field order is implementation-defined,
+// so prefer getSetMaxNRegAction() to decode the action portably.
+union SetMaxNRegFlags {
+  uint32_t V;
+  struct { uint32_t Action : 1; uint32_t reserved : 31; } U; // inc(0)/dec(1)
+};
+// Returns the action encoded in bit 0 of a raw flags value.
+inline SetMaxNRegAction getSetMaxNRegAction(uint32_t V) {
+  return static_cast<SetMaxNRegAction>(V & 0x1u);
+}
+} // namespace nvvm
+} // namespace llvm
+
+#endif // LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index b7a20c351f5ff6..c5e575e805ab93 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -20,6 +20,7 @@
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/NVVMIntrinsicFlags.h"
 #include <cctype>
 using namespace llvm;
 
@@ -340,3 +341,22 @@ void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
     break;
   }
 }
+
+void NVPTXInstPrinter::printSetMaxNRegActionFlag(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O,
+                                                 const char *Modifier) {
+  // Print the .inc/.dec modifier selected by bit 0 of the flags immediate.
+  // Mask the bit instead of reading nvvm::SetMaxNRegFlags::U: that reads an
+  // inactive union member, and bit-field order differs across host ABIs.
+  const uint64_t Flags = static_cast<uint64_t>(MI->getOperand(OpNum).getImm());
+  switch (Flags & 0x1) {
+  case nvvm::ACTION_INC:
+    O << ".inc";
+    break;
+  case nvvm::ACTION_DEC:
+    O << ".dec";
+    break;
+  default:
+    llvm_unreachable("Invalid action flag for setmaxnreg intrinsic");
+  }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index e6954f861cd10e..234d5f139ad496 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -49,6 +49,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                        raw_ostream &O, const char *Modifier = nullptr);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
                      const char *Modifier = nullptr);
+  void printSetMaxNRegActionFlag(const MCInst *MI, int OpNum, raw_ostream &O,
+                                 const char *Modifier = nullptr); // .inc/.dec
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 13665985f52eba..aa02814061da4f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -164,6 +164,10 @@ def True : Predicate<"true">;
 class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
 class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
 
+// Arch-accelerated targets (sm_XXa) are matched on the exact full SM version.
+def hasSM90a : Predicate<"Subtarget->getSmVersion() == 90"
+                          "&& Subtarget->getFullSmVersion() == 901">;
+
 // non-sync shfl instructions are not available on sm_70+ in PTX6.4+
 def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
                           "&& Subtarget->getPTXVersion() >= 64)">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 85eae44f349aa3..9594d47fedf5ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6727,3 +6727,16 @@ def is_explicit_cluster: NVPTXInst<(outs Int1Regs:$d), (ins),
               "mov.pred\t$d, %is_explicit_cluster;",
               [(set Int1Regs:$d, (int_nvvm_is_explicit_cluster))]>,
     Requires<[hasSM<90>, hasPTX<78>]>;
+
+// Operand that prints the setmaxnreg .inc/.dec modifier from the flags imm.
+def SetMaxNRegFlags : Operand<i32> {
+  let PrintMethod = "printSetMaxNRegActionFlag";
+}
+// Both operands must be constants: $flags as imm, $reg_count as timm.
+let isConvergent = true in {
+def INT_SET_MAXNREG : NVPTXInst<(outs),
+    (ins SetMaxNRegFlags:$flags, i32imm:$reg_count),
+    "setmaxnreg${flags:action}.sync.aligned.u32 $reg_count;",
+    [(int_nvvm_setmaxnreg_sync_aligned_u32 imm:$flags, timm:$reg_count)]>,
+    Requires<[hasSM90a, hasPTX<80>]>;
+} // isConvergent
diff --git a/llvm/test/CodeGen/NVPTX/setmaxnreg.ll b/llvm/test/CodeGen/NVPTX/setmaxnreg.ll
new file mode 100644
index 00000000000000..25698088e1cf7b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/setmaxnreg.ll
@@ -0,0 +1,15 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90a -mattr=+ptx80| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_90a -mattr=+ptx80| %ptxas-verify -arch=sm_90a %}
+; Arg 0 (flags) selects the modifier: 0 -> .inc, 1 -> .dec. Arg 1 is the count.
+declare void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 %flags, i32 %reg_count)
+
+; CHECK-LABEL: test_set_maxn_reg
+define void @test_set_maxn_reg() {
+  ; CHECK: setmaxnreg.inc.sync.aligned.u32 96;
+  call void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 0, i32 96)
+
+  ; CHECK: setmaxnreg.dec.sync.aligned.u32 64;
+  call void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 1, i32 64)
+
+  ret void
+}



More information about the llvm-commits mailing list