[llvm] add register logging (PR #137664)

Steffi Stumpos via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 28 09:29:13 PDT 2025


https://github.com/stumpOS created https://github.com/llvm/llvm-project/pull/137664

None

>From 1b915430bbc1d7959ee8cdd853f2c28631ff2c8a Mon Sep 17 00:00:00 2001
From: Stephanie <stumpos at modular.com>
Date: Mon, 28 Apr 2025 16:27:07 +0000
Subject: [PATCH] add register logging

---
 llvm/lib/Target/AMDGPU/VOP3PInstructions.td  |  2 +-
 llvm/lib/Target/NVPTX/CMakeLists.txt         |  1 +
 llvm/lib/Target/NVPTX/NVPTX.h                |  2 +-
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp  |  2 +-
 llvm/lib/Target/NVPTX/NVPTXRegCount.cpp      | 64 ++++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 15 +++++
 6 files changed, 83 insertions(+), 3 deletions(-)
 create mode 100644 llvm/lib/Target/NVPTX/NVPTXRegCount.cpp

diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index d8088b8c638fd..46d2165e91d16 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -626,7 +626,7 @@ class VOPProfileMAI<VOPProfile P, RegisterOperand _SrcRC, RegisterOperand _DstRC
   // and with the earlyclobber flag on the dst. This is stricter than the
   // actual HW restriction. In particular earlyclobber also affects src0 and
   // src1 allocation which is not required.
-  bit NoDstOverlap = !gt(DstVT.Size, 128);
+  bit NoDstOverlap = 1; //!gt(DstVT.Size, 128);
 }
 
 class VOPProfileSMFMAC<VOPProfile P, RegisterOperand _DstRC,
diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index 1cffde138eab7..c6d8f699cc3fe 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -27,6 +27,7 @@ set(NVPTXCodeGen_sources
   NVPTXLowerArgs.cpp
   NVPTXLowerAlloca.cpp
   NVPTXLowerUnreachable.cpp
+  NVPTXRegCount.cpp
   NVPTXPeephole.cpp
   NVPTXMCExpr.cpp
   NVPTXPrologEpilogPass.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ff983e52179af..6561b8022e836 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -36,7 +36,7 @@ enum CondCodes {
   GE
 };
 }
-
+FunctionPass *createNVPTXRegCountPass();
 FunctionPass *createNVPTXISelDag(NVPTXTargetMachine &TM,
                                  llvm::CodeGenOptLevel OptLevel);
 ModulePass *createNVPTXAssignValidGlobalNamesPass();
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec1f969494cd1..3a8381319ec3b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1852,7 +1852,7 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
     case 1: {
       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
       SDValue Imm = Ops[0];
-      if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
+      if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
         // Convert immediate to target constant
         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegCount.cpp b/llvm/lib/Target/NVPTX/NVPTXRegCount.cpp
new file mode 100644
index 0000000000000..5caf934ed4f2c
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXRegCount.cpp
@@ -0,0 +1,64 @@
+//===-- NVPTXRegCount.cpp - Log register operand counts for NVPTX ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a pass that logs a register-operand count per function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTXISelDAGToDAG.h"
+#include "NVPTX.h"
+#include "NVPTXUtilities.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
+#include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <optional>
+using namespace llvm;
+
+namespace {
+class NVPTXRegCountPass : public MachineFunctionPass {
+    public:
+      static char ID;
+      NVPTXRegCountPass() : MachineFunctionPass(ID) {}
+      // Logs the largest per-block register-operand tally for MF to errs().
+      bool runOnMachineFunction(MachineFunction &MF) override {
+        unsigned maxRegs = 0; // largest per-block tally seen so far
+        for (const MachineBasicBlock &MBB : MF) {
+          unsigned liveRegs = 0; // tally for this block; restarts per block
+          for (const MachineInstr &MI : MBB) {
+            // Tally each non-zero register operand; repeats are counted again.
+            for (const MachineOperand &MO : MI.operands()) {
+              if (MO.isReg() && MO.getReg())
+                liveRegs++;
+            }
+          }
+          maxRegs = std::max(maxRegs, liveRegs);
+        }
+        errs() << "Function " << MF.getName() << " uses maximum of " 
+               << maxRegs << " registers\n";
+        return false; // MF is only read above, never modified
+      }
+    };
+} // namespace
+
+char NVPTXRegCountPass::ID = 0;
+// INITIALIZE_PASS(NVPTXRegCountPass, "nvptx-count-reg",
+//     "NVPTX count reg", false, false)
+// Returns a newly allocated NVPTXRegCountPass (the register-count logger).
+    FunctionPass *llvm::createNVPTXRegCountPass() {
+      return new NVPTXRegCountPass();
+    }
\ No newline at end of file
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 8a25256ea1e4a..b9dd54ca3be60 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -234,6 +234,19 @@ void NVPTXTargetMachine::registerDefaultAliasAnalyses(AAManager &AAM) {
   AAM.registerFunctionAnalysis<NVPTXAA>();
 }
 
+struct NVPTXModulePrinter : public PassInfoMixin<NVPTXModulePrinter> {
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) { // prints M to a file; AM unused
+    std::error_code EC; // NOTE(review): the path below is a hard-coded absolute path
+    raw_fd_ostream OutFile("/home/ubuntu/modular/delete-me-test_batch_kv_cache_flash_attention_causal_mask_ragged_paged.ll", EC);
+    if (!EC) { // NOTE(review): if opening failed, the dump is skipped silently
+      M.print(OutFile, nullptr);
+    }
+    return PreservedAnalyses::all();
+  }
+  // Always returns true; presumably keeps the pass from being skipped (confirm).
+  static bool isRequired() { return true; }
+};
+
 void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
 #define GET_PASS_REGISTRY "NVPTXPassRegistry.def"
 #include "llvm/Passes/TargetPassRegistry.inc"
@@ -250,6 +263,7 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         FPM.addPass(NVVMIntrRangePass());
         if (EarlyByValArgsCopy)
           FPM.addPass(NVPTXCopyByValArgsPass());
+        //PM.addPass(NVPTXModulePrinter());
         PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
       });
 
@@ -418,6 +432,7 @@ void NVPTXPassConfig::addPreRegAlloc() {
 
 void NVPTXPassConfig::addPostRegAlloc() {
   addPass(createNVPTXPrologEpilogPass());
+  addPass(createNVPTXRegCountPass());
   if (getOptLevel() != CodeGenOptLevel::None) {
     // NVPTXPrologEpilogPass calculates frame object offset and replace frame
     // index with VRFrame register. NVPTXPeephole need to be run after that and



More information about the llvm-commits mailing list