[Mlir-commits] [mlir] [MLIR][NVVM] Fixed assertion failure for insufficient parsing validation of nvvm dialect PureSpecialRangeableRegisterOp (PR #163434)

Stefan Mada llvmlistbot at llvm.org
Tue Oct 14 13:45:50 PDT 2025


https://github.com/smada3 updated https://github.com/llvm/llvm-project/pull/163434

>From ba87fab55f460014ab0d07667e93404b05b21546 Mon Sep 17 00:00:00 2001
From: Stefan Mada <smada at nvidia.com>
Date: Tue, 14 Oct 2025 19:02:01 +0000
Subject: [PATCH 1/4] Fixed assertion failure for insufficient parsing
 validation of the NVVM dialect's PureSpecialRangeableRegisterOp

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 17 +++++++++++++++++
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 10 ++++++++++
 2 files changed, 27 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 89fbeb7270a38..e4e23ecf77d8b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -279,6 +279,23 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
         SetIntRangeFn setResultRanges) {
         nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges);
     }
+
+    // Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
+    ::llvm::LogicalResult $cppClass::verify() {
+      auto rangeAttr = getRange();
+      if (!rangeAttr)
+        return ::mlir::success(); // No range specified, validation passes
+      
+      const ::llvm::APInt &lower = rangeAttr->getLower();
+      const ::llvm::APInt &upper = rangeAttr->getUpper();
+      
+      // Check LLVM ConstantRange constructor condition
+      if (!(lower != upper || (lower.isMaxValue() || lower.isMinValue()))) {
+        return emitOpError("invalid range attribute: range must be a valid constant range");
+      }
+      
+      return ::mlir::success();
+    }
   }];
 
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..27727d9bb5836 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -559,3 +559,13 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
   %res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1
   llvm.return
 }
+
+
+// -----
+
+// Test for range validation - invalid range where lower == upper but not at extremes
+func.func @invalid_range_equal_bounds() {
+  // expected-error @below {{invalid range attribute: range must be a valid constant range}}
+  %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
+  return
+}

>From 6d128fd65a3f3444d4294e584cf9d3b9996eac79 Mon Sep 17 00:00:00 2001
From: Stefan Mada <smada at nvidia.com>
Date: Tue, 14 Oct 2025 19:09:36 +0000
Subject: [PATCH 2/4] Set hasVerifier = 1 for NVVM_PureSpecialRangeableRegisterOp

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e4e23ecf77d8b..13ef13a2a8f30 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
   let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
   let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
   let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
+  let hasVerifier = 1;
 
   // Backwards-compatibility builder for an unspecified range.
   let builders = [

>From c26f5dbeabbc17fa367aa8863ecc8175acbe9e84 Mon Sep 17 00:00:00 2001
From: Stefan Mada <smada at nvidia.com>
Date: Tue, 14 Oct 2025 20:39:34 +0000
Subject: [PATCH 3/4] Added more tests for rangeable register ops, clarified
 error msg

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 9 +++++++--
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 2 +-
 mlir/test/Target/LLVMIR/nvvmir.mlir         | 4 ++++
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 13ef13a2a8f30..7d7e2dba745d4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -291,8 +291,13 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
       const ::llvm::APInt &upper = rangeAttr->getUpper();
       
       // Check LLVM ConstantRange constructor condition
-      if (!(lower != upper || (lower.isMaxValue() || lower.isMinValue()))) {
-        return emitOpError("invalid range attribute: range must be a valid constant range");
+      if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+        unsigned bitWidth = lower.getBitWidth();
+        ::llvm::APInt minVal = ::llvm::APInt::getMinValue(bitWidth);
+        ::llvm::APInt maxVal = ::llvm::APInt::getMaxValue(bitWidth);
+        return emitOpError("invalid range attribute: Lower == Upper, but they aren't min (")
+               << ::llvm::toString(minVal, 10, false) << ") or max (" 
+               << ::llvm::toString(maxVal, 10, false) << ") value! This is an invalid constant range.";
       }
       
       return ::mlir::success();
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 27727d9bb5836..81c5e6b773c65 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -565,7 +565,7 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ
 
 // Test for range validation - invalid range where lower == upper but not at extremes
 func.func @invalid_range_equal_bounds() {
-  // expected-error @below {{invalid range attribute: range must be a valid constant range}}
+  // expected-error @below {{invalid range attribute: Lower == Upper, but they aren't min (0) or max (4294967295) value! This is an invalid constant range.}}
   %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
   return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 00a479d1f877d..594ae4849e3eb 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -152,6 +152,10 @@ llvm.func @nvvm_special_regs() -> i32 {
   %74 = nvvm.read.ptx.sreg.lanemask.ge : i32
   //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt
   %75 = nvvm.read.ptx.sreg.lanemask.gt : i32
+  // CHECK: %76 = call range(i32 0, 0) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %76 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 0> : i32
+  // CHECK: %77 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %77 = nvvm.read.ptx.sreg.tid.x range <i32, 4294967295, 4294967295> : i32
   llvm.return %1 : i32
 }
 

>From c791da34c0fa8e64d40c1aefab33eddc77566b84 Mon Sep 17 00:00:00 2001
From: Stefan Mada <smada at nvidia.com>
Date: Tue, 14 Oct 2025 20:45:24 +0000
Subject: [PATCH 4/4] Refactored verification code to NVVMDialect.cpp

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 19 +--------------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 26 +++++++++++++++++++++
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7d7e2dba745d4..077125b7983b2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -283,24 +283,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits =
 
     // Verify the range attribute satisfies LLVM ConstantRange constructor requirements.
     ::llvm::LogicalResult $cppClass::verify() {
-      auto rangeAttr = getRange();
-      if (!rangeAttr)
-        return ::mlir::success(); // No range specified, validation passes
-      
-      const ::llvm::APInt &lower = rangeAttr->getLower();
-      const ::llvm::APInt &upper = rangeAttr->getUpper();
-      
-      // Check LLVM ConstantRange constructor condition
-      if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
-        unsigned bitWidth = lower.getBitWidth();
-        ::llvm::APInt minVal = ::llvm::APInt::getMinValue(bitWidth);
-        ::llvm::APInt maxVal = ::llvm::APInt::getMaxValue(bitWidth);
-        return emitOpError("invalid range attribute: Lower == Upper, but they aren't min (")
-               << ::llvm::toString(minVal, 10, false) << ") or max (" 
-               << ::llvm::toString(maxVal, 10, false) << ") value! This is an invalid constant range.";
-      }
-      
-      return ::mlir::success();
+      return verifyConstantRangeAttr(getOperation(), getRange());
     }
   }];
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40bd2d32..d39467e9629e0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2306,6 +2306,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
+/// Verify that the optional `range` attribute of an NVVM special-register op
+/// can be converted to an llvm::ConstantRange without tripping its assertion.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+  if (!rangeAttr)
+    return success();
+
+  const llvm::APInt &lower = rangeAttr->getLower();
+  const llvm::APInt &upper = rangeAttr->getUpper();
+
+  // Mirror ConstantRange's ctor assert: equal bounds must be min or max.
+  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+    unsigned bitWidth = lower.getBitWidth();
+    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+    return op->emitOpError(
+               "invalid range attribute: Lower == Upper, but they aren't min (")
+           << llvm::toString(minVal, 10, false) << ") or max ("
+           << llvm::toString(maxVal, 10, false)
+           << ") value! This is an invalid constant range.";
+  }
+
+  return success();
+}
+
 static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                    llvm::IRBuilderBase &builder) {
   return builder.CreateBitCast(arg,



More information about the Mlir-commits mailing list