[clang] [llvm] [NVPTX] Revamp NVVMIntrRange pass (PR #94422)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 5 08:35:18 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/94422

>From 708374e03f1bf70006f2472f19edad1bd621e2d6 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Mon, 3 Jun 2024 16:46:36 +0000
Subject: [PATCH] [NVPTX] Revamp NVVMIntrRange pass

---
 clang/test/CodeGenCUDA/cuda-builtin-vars.cu  |  24 +--
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp    |  32 ++--
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp |   6 +-
 llvm/lib/Target/NVPTX/NVPTXUtilities.cpp     |  58 ++++--
 llvm/lib/Target/NVPTX/NVPTXUtilities.h       |  16 +-
 llvm/lib/Target/NVPTX/NVVMIntrRange.cpp      | 177 ++++++++++---------
 llvm/test/CodeGen/NVPTX/intr-range.ll        |  60 +++++++
 llvm/test/CodeGen/NVPTX/intrinsic-old.ll     |  43 ++---
 8 files changed, 249 insertions(+), 167 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/intr-range.ll

diff --git a/clang/test/CodeGenCUDA/cuda-builtin-vars.cu b/clang/test/CodeGenCUDA/cuda-builtin-vars.cu
index ba5e5f13ebe70..dba0a76af21dd 100644
--- a/clang/test/CodeGenCUDA/cuda-builtin-vars.cu
+++ b/clang/test/CodeGenCUDA/cuda-builtin-vars.cu
@@ -6,21 +6,21 @@
 __attribute__((global))
 void kernel(int *out) {
   int i = 0;
-  out[i++] = threadIdx.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x()
-  out[i++] = threadIdx.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.y()
-  out[i++] = threadIdx.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  out[i++] = threadIdx.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  out[i++] = threadIdx.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+  out[i++] = threadIdx.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.tid.z()
 
-  out[i++] = blockIdx.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
-  out[i++] = blockIdx.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
-  out[i++] = blockIdx.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
+  out[i++] = blockIdx.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+  out[i++] = blockIdx.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
+  out[i++] = blockIdx.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
 
-  out[i++] = blockDim.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
-  out[i++] = blockDim.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
-  out[i++] = blockDim.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+  out[i++] = blockDim.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  out[i++] = blockDim.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  out[i++] = blockDim.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
 
-  out[i++] = gridDim.x; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
-  out[i++] = gridDim.y; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
-  out[i++] = gridDim.z; // CHECK: call noundef i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
+  out[i++] = gridDim.x; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+  out[i++] = gridDim.y; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
+  out[i++] = gridDim.z; // CHECK: call noundef {{.*}} i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
 
   out[i++] = warpSize; // CHECK: store i32 32,
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index f63697916d902..82770f8660850 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -542,30 +542,24 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
   // If the NVVM IR has some of reqntid* specified, then output
   // the reqntid directive, and set the unspecified ones to 1.
   // If none of Reqntid* is specified, don't output reqntid directive.
-  unsigned Reqntidx, Reqntidy, Reqntidz;
-  Reqntidx = Reqntidy = Reqntidz = 1;
-  bool ReqSpecified = false;
-  ReqSpecified |= getReqNTIDx(F, Reqntidx);
-  ReqSpecified |= getReqNTIDy(F, Reqntidy);
-  ReqSpecified |= getReqNTIDz(F, Reqntidz);
+  std::optional<unsigned> Reqntidx = getReqNTIDx(F);
+  std::optional<unsigned> Reqntidy = getReqNTIDy(F);
+  std::optional<unsigned> Reqntidz = getReqNTIDz(F);
 
-  if (ReqSpecified)
-    O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
-      << "\n";
+  if (Reqntidx || Reqntidy || Reqntidz)
+    O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
+      << ", " << Reqntidz.value_or(1) << "\n";
 
   // If the NVVM IR has some of maxntid* specified, then output
   // the maxntid directive, and set the unspecified ones to 1.
   // If none of maxntid* is specified, don't output maxntid directive.
-  unsigned Maxntidx, Maxntidy, Maxntidz;
-  Maxntidx = Maxntidy = Maxntidz = 1;
-  bool MaxSpecified = false;
-  MaxSpecified |= getMaxNTIDx(F, Maxntidx);
-  MaxSpecified |= getMaxNTIDy(F, Maxntidy);
-  MaxSpecified |= getMaxNTIDz(F, Maxntidz);
-
-  if (MaxSpecified)
-    O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
-      << "\n";
+  std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
+  std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
+  std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
+
+  if (Maxntidx || Maxntidy || Maxntidz)
+    O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
+      << ", " << Maxntidz.value_or(1) << "\n";
 
   unsigned Mincta = 0;
   if (getMinCTASm(F, Mincta))
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 4dc3cea4bd8e7..657decb3308b3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -233,9 +233,9 @@ void NVPTXTargetMachine::registerPassBuilderCallbacks(
       [this](ModulePassManager &PM, OptimizationLevel Level) {
         FunctionPassManager FPM;
         FPM.addPass(NVVMReflectPass(Subtarget.getSmVersion()));
-        // FIXME: NVVMIntrRangePass is causing numerical discrepancies,
-        // investigate and re-enable.
-        // FPM.addPass(NVVMIntrRangePass(Subtarget.getSmVersion()));
+        // Note: NVVMIntrRangePass was causing numerical discrepancies at one
+        // point; if issues crop up, consider disabling it again.
+        FPM.addPass(NVVMIntrRangePass(Subtarget.getSmVersion()));
         PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
       });
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 013afe916e86c..4305bbc99c969 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -128,6 +128,15 @@ bool findOneNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
   return true;
 }
 
+static std::optional<unsigned>
+findOneNVVMAnnotation(const GlobalValue &GV, const std::string &PropName) {
+  unsigned RetVal;
+  bool Found = findOneNVVMAnnotation(&GV, PropName, RetVal);
+  if (Found)
+    return RetVal;
+  return std::nullopt;
+}
+
 bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
                            std::vector<unsigned> &retval) {
   auto &AC = getAnnotationCache();
@@ -252,32 +261,57 @@ std::string getSamplerName(const Value &val) {
   return std::string(val.getName());
 }
 
-bool getMaxNTIDx(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "maxntidx", x);
+std::optional<unsigned> getMaxNTIDx(const Function &F) {
+  return findOneNVVMAnnotation(F, "maxntidx");
 }
 
-bool getMaxNTIDy(const Function &F, unsigned &y) {
-  return findOneNVVMAnnotation(&F, "maxntidy", y);
+std::optional<unsigned> getMaxNTIDy(const Function &F) {
+  return findOneNVVMAnnotation(F, "maxntidy");
 }
 
-bool getMaxNTIDz(const Function &F, unsigned &z) {
-  return findOneNVVMAnnotation(&F, "maxntidz", z);
+std::optional<unsigned> getMaxNTIDz(const Function &F) {
+  return findOneNVVMAnnotation(F, "maxntidz");
+}
+
+std::optional<unsigned> getMaxNTID(const Function &F) {
+  // Note: The semantics here are a bit strange. The PTX ISA states the
+  // following (11.4.2. Performance-Tuning Directives: .maxntid):
+  //
+  //  Note that this directive guarantees that the total number of threads does
+  //  not exceed the maximum, but does not guarantee that the limit in any
+  //  particular dimension is not exceeded.
+  std::optional<unsigned> MaxNTIDx = getMaxNTIDx(F);
+  std::optional<unsigned> MaxNTIDy = getMaxNTIDy(F);
+  std::optional<unsigned> MaxNTIDz = getMaxNTIDz(F);
+  if (MaxNTIDx || MaxNTIDy || MaxNTIDz)
+    return MaxNTIDx.value_or(1) * MaxNTIDy.value_or(1) * MaxNTIDz.value_or(1);
+  return std::nullopt;
 }
 
 bool getMaxClusterRank(const Function &F, unsigned &x) {
   return findOneNVVMAnnotation(&F, "maxclusterrank", x);
 }
 
-bool getReqNTIDx(const Function &F, unsigned &x) {
-  return findOneNVVMAnnotation(&F, "reqntidx", x);
+std::optional<unsigned> getReqNTIDx(const Function &F) {
+  return findOneNVVMAnnotation(F, "reqntidx");
+}
+
+std::optional<unsigned> getReqNTIDy(const Function &F) {
+  return findOneNVVMAnnotation(F, "reqntidy");
 }
 
-bool getReqNTIDy(const Function &F, unsigned &y) {
-  return findOneNVVMAnnotation(&F, "reqntidy", y);
+std::optional<unsigned> getReqNTIDz(const Function &F) {
+  return findOneNVVMAnnotation(F, "reqntidz");
 }
 
-bool getReqNTIDz(const Function &F, unsigned &z) {
-  return findOneNVVMAnnotation(&F, "reqntidz", z);
+std::optional<unsigned> getReqNTID(const Function &F) {
+  // Note: The semantics here are a bit strange. See getMaxNTID.
+  std::optional<unsigned> ReqNTIDx = getReqNTIDx(F);
+  std::optional<unsigned> ReqNTIDy = getReqNTIDy(F);
+  std::optional<unsigned> ReqNTIDz = getReqNTIDz(F);
+  if (ReqNTIDx || ReqNTIDy || ReqNTIDz)
+    return ReqNTIDx.value_or(1) * ReqNTIDy.value_or(1) * ReqNTIDz.value_or(1);
+  return std::nullopt;
 }
 
 bool getMinCTASm(const Function &F, unsigned &x) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index 2872db9fa2131..e020bc0f02e96 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -48,13 +48,15 @@ std::string getTextureName(const Value &);
 std::string getSurfaceName(const Value &);
 std::string getSamplerName(const Value &);
 
-bool getMaxNTIDx(const Function &, unsigned &);
-bool getMaxNTIDy(const Function &, unsigned &);
-bool getMaxNTIDz(const Function &, unsigned &);
-
-bool getReqNTIDx(const Function &, unsigned &);
-bool getReqNTIDy(const Function &, unsigned &);
-bool getReqNTIDz(const Function &, unsigned &);
+std::optional<unsigned> getMaxNTIDx(const Function &);
+std::optional<unsigned> getMaxNTIDy(const Function &);
+std::optional<unsigned> getMaxNTIDz(const Function &);
+std::optional<unsigned> getMaxNTID(const Function &F);
+
+std::optional<unsigned> getReqNTIDx(const Function &);
+std::optional<unsigned> getReqNTIDy(const Function &);
+std::optional<unsigned> getReqNTIDz(const Function &);
+std::optional<unsigned> getReqNTID(const Function &);
 
 bool getMaxClusterRank(const Function &, unsigned &);
 bool getMinCTASm(const Function &, unsigned &);
diff --git a/llvm/lib/Target/NVPTX/NVVMIntrRange.cpp b/llvm/lib/Target/NVPTX/NVVMIntrRange.cpp
index 5381646434eb8..a7ab350c40a87 100644
--- a/llvm/lib/Target/NVPTX/NVVMIntrRange.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMIntrRange.cpp
@@ -1,4 +1,4 @@
-//===- NVVMIntrRange.cpp - Set !range metadata for NVVM intrinsics --------===//
+//===- NVVMIntrRange.cpp - Set range attributes for NVVM intrinsics -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,19 +6,21 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass adds appropriate !range metadata for calls to NVVM
+// This pass adds appropriate range attributes for calls to NVVM
 // intrinsics that return a limited range of values.
 //
 //===----------------------------------------------------------------------===//
 
 #include "NVPTX.h"
-#include "llvm/IR/Constants.h"
+#include "NVPTXUtilities.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/CommandLine.h"
+#include <cstdint>
 
 using namespace llvm;
 
@@ -26,25 +28,24 @@ using namespace llvm;
 
 namespace llvm { void initializeNVVMIntrRangePass(PassRegistry &); }
 
-// Add !range metadata based on limits of given SM variant.
+// Add range attributes based on limits of given SM variant.
 static cl::opt<unsigned> NVVMIntrRangeSM("nvvm-intr-range-sm", cl::init(20),
                                          cl::Hidden, cl::desc("SM variant"));
 
 namespace {
 class NVVMIntrRange : public FunctionPass {
- private:
-   unsigned SmVersion;
+  unsigned SmVersion;
 
- public:
-   static char ID;
-   NVVMIntrRange() : NVVMIntrRange(NVVMIntrRangeSM) {}
-   NVVMIntrRange(unsigned int SmVersion)
-       : FunctionPass(ID), SmVersion(SmVersion) {
+public:
+  static char ID;
+  NVVMIntrRange() : NVVMIntrRange(NVVMIntrRangeSM) {}
+  NVVMIntrRange(unsigned int SmVersion)
+      : FunctionPass(ID), SmVersion(SmVersion) {
 
-     initializeNVVMIntrRangePass(*PassRegistry::getPassRegistry());
-   }
+    initializeNVVMIntrRangePass(*PassRegistry::getPassRegistry());
+  }
 
-   bool runOnFunction(Function &) override;
+  bool runOnFunction(Function &) override;
 };
 }
 
@@ -58,17 +59,17 @@ INITIALIZE_PASS(NVVMIntrRange, "nvvm-intr-range",
 
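-// Adds the passed-in [Low,High) range information as metadata to the
-// passed-in call instruction.
+// Adds the passed-in [Low,High) range information as a range return attribute
+// to the passed-in intrinsic call.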
-static bool addRangeMetadata(uint64_t Low, uint64_t High, CallInst *C) {
-  // This call already has range metadata, nothing to do.
-  if (C->getMetadata(LLVMContext::MD_range))
+static bool addRangeAttr(uint64_t Low, uint64_t High, IntrinsicInst *II) {
+  if (II->getMetadata(LLVMContext::MD_range))
     return false;
 
-  LLVMContext &Context = C->getParent()->getContext();
-  IntegerType *Int32Ty = Type::getInt32Ty(Context);
-  Metadata *LowAndHigh[] = {
-      ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Low)),
-      ConstantAsMetadata::get(ConstantInt::get(Int32Ty, High))};
-  C->setMetadata(LLVMContext::MD_range, MDNode::get(Context, LowAndHigh));
+  const uint64_t BitWidth = II->getType()->getIntegerBitWidth();
+  ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High));
+
+  if (auto CurrentRange = II->getRange())
+    Range = Range.intersectWith(CurrentRange.value());
+
+  II->addRangeRetAttr(Range);
   return true;
 }
 
@@ -76,9 +77,13 @@ static bool runNVVMIntrRange(Function &F, unsigned SmVersion) {
   struct {
     unsigned x, y, z;
   } MaxBlockSize, MaxGridSize;
-  MaxBlockSize.x = 1024;
-  MaxBlockSize.y = 1024;
-  MaxBlockSize.z = 64;
+
+  const unsigned MetadataNTID = getReqNTID(F).value_or(
+      getMaxNTID(F).value_or(std::numeric_limits<unsigned>::max()));
+
+  MaxBlockSize.x = std::min(1024u, MetadataNTID);
+  MaxBlockSize.y = std::min(1024u, MetadataNTID);
+  MaxBlockSize.z = std::min(64u, MetadataNTID);
 
   MaxGridSize.x = SmVersion >= 30 ? 0x7fffffff : 0xffff;
   MaxGridSize.y = 0xffff;
@@ -87,69 +92,67 @@ static bool runNVVMIntrRange(Function &F, unsigned SmVersion) {
   // Go through the calls in this function.
   bool Changed = false;
   for (Instruction &I : instructions(F)) {
-    CallInst *Call = dyn_cast<CallInst>(&I);
-    if (!Call)
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
+    if (!II)
       continue;
 
-    if (Function *Callee = Call->getCalledFunction()) {
-      switch (Callee->getIntrinsicID()) {
-      // Index within block
-      case Intrinsic::nvvm_read_ptx_sreg_tid_x:
-        Changed |= addRangeMetadata(0, MaxBlockSize.x, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_tid_y:
-        Changed |= addRangeMetadata(0, MaxBlockSize.y, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_tid_z:
-        Changed |= addRangeMetadata(0, MaxBlockSize.z, Call);
-        break;
-
-      // Block size
-      case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
-        Changed |= addRangeMetadata(1, MaxBlockSize.x+1, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
-        Changed |= addRangeMetadata(1, MaxBlockSize.y+1, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
-        Changed |= addRangeMetadata(1, MaxBlockSize.z+1, Call);
-        break;
-
-      // Index within grid
-      case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
-        Changed |= addRangeMetadata(0, MaxGridSize.x, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
-        Changed |= addRangeMetadata(0, MaxGridSize.y, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
-        Changed |= addRangeMetadata(0, MaxGridSize.z, Call);
-        break;
-
-      // Grid size
-      case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
-        Changed |= addRangeMetadata(1, MaxGridSize.x+1, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
-        Changed |= addRangeMetadata(1, MaxGridSize.y+1, Call);
-        break;
-      case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
-        Changed |= addRangeMetadata(1, MaxGridSize.z+1, Call);
-        break;
-
-      // warp size is constant 32.
-      case Intrinsic::nvvm_read_ptx_sreg_warpsize:
-        Changed |= addRangeMetadata(32, 32+1, Call);
-        break;
-
-      // Lane ID is [0..warpsize)
-      case Intrinsic::nvvm_read_ptx_sreg_laneid:
-        Changed |= addRangeMetadata(0, 32, Call);
-        break;
-
-      default:
-        break;
-      }
+    switch (II->getIntrinsicID()) {
+    // Index within block
+    case Intrinsic::nvvm_read_ptx_sreg_tid_x:
+      Changed |= addRangeAttr(0, MaxBlockSize.x, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_tid_y:
+      Changed |= addRangeAttr(0, MaxBlockSize.y, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_tid_z:
+      Changed |= addRangeAttr(0, MaxBlockSize.z, II);
+      break;
+
+    // Block size
+    case Intrinsic::nvvm_read_ptx_sreg_ntid_x:
+      Changed |= addRangeAttr(1, MaxBlockSize.x + 1, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_ntid_y:
+      Changed |= addRangeAttr(1, MaxBlockSize.y + 1, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_ntid_z:
+      Changed |= addRangeAttr(1, MaxBlockSize.z + 1, II);
+      break;
+
+    // Index within grid
+    case Intrinsic::nvvm_read_ptx_sreg_ctaid_x:
+      Changed |= addRangeAttr(0, MaxGridSize.x, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_ctaid_y:
+      Changed |= addRangeAttr(0, MaxGridSize.y, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_ctaid_z:
+      Changed |= addRangeAttr(0, MaxGridSize.z, II);
+      break;
+
+    // Grid size
+    case Intrinsic::nvvm_read_ptx_sreg_nctaid_x:
+      Changed |= addRangeAttr(1, MaxGridSize.x + 1, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_nctaid_y:
+      Changed |= addRangeAttr(1, MaxGridSize.y + 1, II);
+      break;
+    case Intrinsic::nvvm_read_ptx_sreg_nctaid_z:
+      Changed |= addRangeAttr(1, MaxGridSize.z + 1, II);
+      break;
+
+    // warp size is constant 32.
+    case Intrinsic::nvvm_read_ptx_sreg_warpsize:
+      Changed |= addRangeAttr(32, 32 + 1, II);
+      break;
+
+    // Lane ID is [0..warpsize)
+    case Intrinsic::nvvm_read_ptx_sreg_laneid:
+      Changed |= addRangeAttr(0, 32, II);
+      break;
+
+    default:
+      break;
     }
   }
 
diff --git a/llvm/test/CodeGen/NVPTX/intr-range.ll b/llvm/test/CodeGen/NVPTX/intr-range.ll
new file mode 100644
index 0000000000000..3fd1672759903
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/intr-range.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-attributes --version 5
+; RUN: opt < %s -S -mtriple=nvptx-nvidia-cuda -mcpu=sm_20 -passes=nvvm-intr-range | FileCheck %s
+
+define i32 @test_maxntid() {
+; CHECK-LABEL: define i32 @test_maxntid(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i32 0, 96) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; CHECK-NEXT:    [[TMP2:%.*]] = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+; CHECK-NEXT:    [[TMP4:%.*]] = call range(i32 1, 97) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    ret i32 [[TMP5]]
+;
+  %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  %4 = add i32 %1, %2
+  %5 = add i32 %4, %3
+  ret i32 %5
+}
+
+define i32 @test_reqntid() {
+; CHECK-LABEL: define i32 @test_reqntid(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i32 0, 20) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; CHECK-NEXT:    [[TMP2:%.*]] = call range(i32 0, 20) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+; CHECK-NEXT:    [[TMP3:%.*]] = call range(i32 1, 21) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[TMP5]]
+;
+  %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  %4 = add i32 %1, %2
+  %5 = add i32 %4, %3
+  ret i32 %5
+}
+
+;; A case like this could occur if a function with the sreg intrinsic was
+;; inlined into a kernel where the tid metadata is present; ensure the existing
+;; range is narrowed to match.
+define i32 @test_inlined() {
+; CHECK-LABEL: define i32 @test_inlined(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call range(i32 0, 4) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; CHECK-NEXT:    ret i32 [[TMP1]]
+;
+  %1 = call range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  ret i32 %1
+}
+
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+
+!nvvm.annotations = !{!0, !1, !2}
+!0 = !{ptr @test_maxntid, !"kernel", i32 1, !"maxntidx", i32 32, !"maxntidz", i32 3}
+!1 = !{ptr @test_reqntid, !"kernel", i32 1, !"reqntidx", i32 20}
+!2 = !{ptr @test_inlined, !"kernel", i32 1, !"maxntidx", i32 4}
diff --git a/llvm/test/CodeGen/NVPTX/intrinsic-old.ll b/llvm/test/CodeGen/NVPTX/intrinsic-old.ll
index 3930e6d774183..a53e538241e31 100644
--- a/llvm/test/CodeGen/NVPTX/intrinsic-old.ll
+++ b/llvm/test/CodeGen/NVPTX/intrinsic-old.ll
@@ -15,7 +15,7 @@
 
 define ptx_device i32 @test_tid_x() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %tid.x;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[BLK_IDX_XY:[0-9]+]]
+; RANGE: call range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
 	ret i32 %x
@@ -23,7 +23,7 @@ define ptx_device i32 @test_tid_x() {
 
 define ptx_device i32 @test_tid_y() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %tid.y;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.tid.y(), !range ![[BLK_IDX_XY]]
+; RANGE: call range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.y()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
 	ret i32 %x
@@ -31,7 +31,7 @@ define ptx_device i32 @test_tid_y() {
 
 define ptx_device i32 @test_tid_z() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %tid.z;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.tid.z(), !range ![[BLK_IDX_Z:[0-9]+]]
+; RANGE: call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.z()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
 	ret i32 %x
@@ -46,7 +46,7 @@ define ptx_device i32 @test_tid_w() {
 
 define ptx_device i32 @test_ntid_x() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %ntid.x;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.ntid.x(), !range ![[BLK_SIZE_XY:[0-9]+]]
+; RANGE: call range(i32 1, 1025) i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
 	ret i32 %x
@@ -54,7 +54,7 @@ define ptx_device i32 @test_ntid_x() {
 
 define ptx_device i32 @test_ntid_y() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %ntid.y;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.ntid.y(), !range ![[BLK_SIZE_XY]]
+; RANGE: call range(i32 1, 1025) i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
 	ret i32 %x
@@ -62,7 +62,7 @@ define ptx_device i32 @test_ntid_y() {
 
 define ptx_device i32 @test_ntid_z() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %ntid.z;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.ntid.z(), !range ![[BLK_SIZE_Z:[0-9]+]]
+; RANGE: call range(i32 1, 65) i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
 	ret i32 %x
@@ -77,7 +77,7 @@ define ptx_device i32 @test_ntid_w() {
 
 define ptx_device i32 @test_laneid() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %laneid;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.laneid(), !range ![[LANEID:[0-9]+]]
+; RANGE: call range(i32 0, 32) i32 @llvm.nvvm.read.ptx.sreg.laneid()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.laneid()
 	ret i32 %x
@@ -85,7 +85,7 @@ define ptx_device i32 @test_laneid() {
 
 define ptx_device i32 @test_warpsize() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, WARP_SZ;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.warpsize(), !range ![[WARPSIZE:[0-9]+]]
+; RANGE: call range(i32 32, 33) i32 @llvm.nvvm.read.ptx.sreg.warpsize()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
 	ret i32 %x
@@ -107,7 +107,7 @@ define ptx_device i32 @test_nwarpid() {
 
 define ptx_device i32 @test_ctaid_y() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %ctaid.y;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range ![[GRID_IDX_YZ:[0-9]+]]
+; RANGE: call range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
 	ret i32 %x
@@ -115,7 +115,7 @@ define ptx_device i32 @test_ctaid_y() {
 
 define ptx_device i32 @test_ctaid_z() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %ctaid.z;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z(), !range ![[GRID_IDX_YZ]]
+; RANGE: call range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
 	ret i32 %x
@@ -123,8 +123,8 @@ define ptx_device i32 @test_ctaid_z() {
 
 define ptx_device i32 @test_ctaid_x() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %ctaid.x;
-; RANGE_30: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[GRID_IDX_X:[0-9]+]]
-; RANGE_20: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[GRID_IDX_YZ]]
+; RANGE_30: call range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+; RANGE_20: call range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
 	ret i32 %x
@@ -139,7 +139,7 @@ define ptx_device i32 @test_ctaid_w() {
 
 define ptx_device i32 @test_nctaid_y() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %nctaid.y;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y(), !range ![[GRID_SIZE_YZ:[0-9]+]]
+; RANGE: call range(i32 1, 65536) i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
 	ret i32 %x
@@ -147,7 +147,7 @@ define ptx_device i32 @test_nctaid_y() {
 
 define ptx_device i32 @test_nctaid_z() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %nctaid.z;
-; RANGE: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z(), !range ![[GRID_SIZE_YZ]]
+; RANGE: call range(i32 1, 65536) i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
 	ret i32 %x
@@ -155,8 +155,8 @@ define ptx_device i32 @test_nctaid_z() {
 
 define ptx_device i32 @test_nctaid_x() {
 ; CHECK: mov.u32 %r{{[0-9]+}}, %nctaid.x;
-; RANGE_30: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x(), !range ![[GRID_SIZE_X:[0-9]+]]
-; RANGE_20: call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x(), !range ![[GRID_SIZE_YZ]]
+; RANGE_30: call range(i32 1, -2147483648) i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+; RANGE_20: call range(i32 1, 65536) i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
 ; CHECK: ret;
 	%x = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
 	ret i32 %x
@@ -327,14 +327,3 @@ declare void @llvm.nvvm.bar.sync(i32 %i)
 
 !0 = !{i32 0, i32 19}
 ; RANGE-DAG: ![[ALREADY]] = !{i32 0, i32 19}
-; RANGE-DAG: ![[BLK_IDX_XY]] = !{i32 0, i32 1024}
-; RANGE-DAG: ![[BLK_IDX_XY]] = !{i32 0, i32 1024}
-; RANGE-DAG: ![[BLK_IDX_Z]] = !{i32 0, i32 64}
-; RANGE-DAG: ![[BLK_SIZE_XY]] = !{i32 1, i32 1025}
-; RANGE-DAG: ![[BLK_SIZE_Z]] = !{i32 1, i32 65}
-; RANGE-DAG: ![[LANEID]] = !{i32 0, i32 32}
-; RANGE-DAG: ![[WARPSIZE]] = !{i32 32, i32 33}
-; RANGE_30-DAG: ![[GRID_IDX_X]] = !{i32 0, i32 2147483647}
-; RANGE-DAG: ![[GRID_IDX_YZ]] = !{i32 0, i32 65535}
-; RANGE_30-DAG: ![[GRID_SIZE_X]] = !{i32 1, i32 -2147483648}
-; RANGE-DAG: ![[GRID_SIZE_YZ]] = !{i32 1, i32 65536}



More information about the llvm-commits mailing list