[Mlir-commits] [mlir] [MLIR][NVVM] Update Op verifiers to prevent ungraceful exits (PR #165677)

Srinivasa Ravi llvmlistbot at llvm.org
Mon Nov 3 21:09:02 PST 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/165677

>From f566fa78261aab88fcb77f06437b53119e52fb3f Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 30 Oct 2025 08:27:40 +0000
Subject: [PATCH 1/7] [MLIR][NVVM] Update Op verifiers to prevent ungraceful
 exits

Updates the verifiers of the following Ops so that certain cases of
incorrect usage no longer exit ungracefully with a stack dump, but
instead fail gracefully with a more informative error message:

- tcgen05.ld
- shfl.sync
---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 24 ++++++++++++++-------
 mlir/test/Dialect/LLVMIR/invalid.mlir       |  7 ++++++
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir |  8 +++++++
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4dbcc1d4b..402c90fba0f2d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -867,15 +867,20 @@ LogicalResult MmaOp::verify() {
 }
 
 LogicalResult ShflOp::verify() {
-  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
-    return success();
   auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
-  auto elementType = (type && type.getBody().size() == 2)
-                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
-                         : nullptr;
-  if (!elementType || elementType.getWidth() != 1)
-    return emitError("expected return type to be a two-element struct with "
-                     "i1 as the second element");
+
+  if ((*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) {
+    auto elementType = (type && type.getBody().size() == 2)
+                           ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
+                           : nullptr;
+    if (!elementType || elementType.getWidth() != 1)
+      return emitOpError("expected return type to be a two-element struct with "
+                         "i1 as the second element");
+  } else {
+    if (type)
+      return emitOpError("\"return_value_and_is_valid\" attribute must be "
+                         "specified when returning the predicate");
+  }
   return success();
 }
 
@@ -2450,6 +2455,9 @@ LogicalResult Tcgen05LdOp::verify() {
   LogicalResult result = success();
   if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
     result = emitError("shape 16x32bx2 requires offset argument");
+  
+  if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
+    result = emitError("offset argument is only supported for shape 16x32bx2");
 
   auto resTy = getRes().getType();
   unsigned resLen = isa<VectorType>(resTy)
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index aaf9f8024bfbe..90208aa55bd55 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -684,6 +684,13 @@ func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3
 
 // -----
 
+func.func @nvvm_invalid_shfl_pred_4(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+  // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}}
+  %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
+}
+
+// -----
+
 func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 09b8f593154b5..8cb7b068498fd 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -621,3 +621,11 @@ func.func @invalid_range_equal_bounds() {
   %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
   return
 }
+
+// -----
+
+llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
+  // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
+  %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
+  llvm.return
+}

>From 045ce6c99786cb8634ca885dfca5b94a58ce2b8a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 30 Oct 2025 08:37:55 +0000
Subject: [PATCH 2/7] fix formatting

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 402c90fba0f2d..a23245f92cee7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2455,7 +2455,7 @@ LogicalResult Tcgen05LdOp::verify() {
   LogicalResult result = success();
   if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())
     result = emitError("shape 16x32bx2 requires offset argument");
-  
+
   if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset())
     result = emitError("offset argument is only supported for shape 16x32bx2");
 

>From 70fe4cc7f10f3f4c4f1787d838676bb2b71c0bcd Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 31 Oct 2025 06:00:23 +0000
Subject: [PATCH 3/7] address comments

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td          | 8 ++++++--
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp           | 4 ++--
 mlir/test/Dialect/LLVMIR/invalid.mlir                | 7 -------
 mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir  | 8 ++++++++
 mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir | 9 +++++++++
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir          | 8 --------
 6 files changed, 25 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4f483859ac18d..1e915e3027d58 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1341,9 +1341,9 @@ def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
 
 def NVVM_ShflOp :
   NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
-  Results<(outs LLVM_Type:$res)>,
+  Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
-                 LLVM_Type:$val,
+                 AnyTypeOf<[I32, F32]>:$val,
                  I32:$offset,
                  I32:$mask_and_clamp,
                  ShflKindAttr:$kind,
@@ -1359,6 +1359,10 @@ def NVVM_ShflOp :
     a mask for logically splitting warps into sub-segments and an upper bound
     for clamping the source lane index.
     
+    Optionally, `return_value_and_is_valid` can be specified to return a 
+    two-element struct with the result and a predicate indicating if the 
+    computed source lane index is valid.
+    
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
   }];
   string llvmBuilder = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a23245f92cee7..b5b07929bab6a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -870,10 +870,10 @@ LogicalResult ShflOp::verify() {
   auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
 
   if ((*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) {
-    auto elementType = (type && type.getBody().size() == 2)
+    auto predicateType = (type && type.getBody().size() == 2)
                            ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                            : nullptr;
-    if (!elementType || elementType.getWidth() != 1)
+    if (!predicateType || predicateType.getWidth() != 1)
       return emitOpError("expected return type to be a two-element struct with "
                          "i1 as the second element");
   } else {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 90208aa55bd55..aaf9f8024bfbe 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -684,13 +684,6 @@ func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3
 
 // -----
 
-func.func @nvvm_invalid_shfl_pred_4(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
-  // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}}
-  %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
-}
-
-// -----
-
 func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
                          %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                          %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
new file mode 100644
index 0000000000000..d2fe21c841a76
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// -----
+
+func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+  // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}}
+  %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
new file mode 100644
index 0000000000000..1b93f20c15b99
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-ld-invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+// -----
+
+llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
+  // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
+  %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8cb7b068498fd..09b8f593154b5 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -621,11 +621,3 @@ func.func @invalid_range_equal_bounds() {
   %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
   return
 }
-
-// -----
-
-llvm.func @nvvm_tcgen05_ld_32x32b_offset(%tmemAddr : !llvm.ptr<6>, %offset : i64) -> () {
-  // expected-error@+1 {{offset argument is only supported for shape 16x32bx2}}
-  %ldv2 = nvvm.tcgen05.ld %tmemAddr, %offset { pack, shape = #nvvm.tcgen05_ldst_shape<shape_32x32b>} : vector<2 x i32>
-  llvm.return
-}

>From 29aa80ec9c3cc5231393d16ffb4d2aaa5c588b69 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 31 Oct 2025 06:06:38 +0000
Subject: [PATCH 4/7] fix formatting

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index b5b07929bab6a..33f9a256a78d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -871,8 +871,8 @@ LogicalResult ShflOp::verify() {
 
   if ((*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) {
     auto predicateType = (type && type.getBody().size() == 2)
-                           ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
-                           : nullptr;
+                             ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
+                             : nullptr;
     if (!predicateType || predicateType.getWidth() != 1)
       return emitOpError("expected return type to be a two-element struct with "
                          "i1 as the second element");

>From 867f33c14e877dd89d0e5f2e04be6b996482fb88 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 3 Nov 2025 12:25:56 +0000
Subject: [PATCH 5/7] update and refactor shfl.sync Op verifier

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  5 ++-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 38 +++++++++++++------
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  6 +--
 .../Target/LLVMIR/nvvm/shfl-sync-invalid.mlir | 16 +++++++-
 4 files changed, 47 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1e915e3027d58..4d6b0acffe862 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1359,8 +1359,9 @@ def NVVM_ShflOp :
     a mask for logically splitting warps into sub-segments and an upper bound
     for clamping the source lane index.
     
-    Optionally, `return_value_and_is_valid` can be specified to return a 
-    two-element struct with the result and a predicate indicating if the 
+    The `return_value_and_is_valid` unit attribute can be specified to indicate 
+    that the return value is a two-element struct, where the first element is 
+    the result value and the second element is a predicate indicating if the 
     computed source lane index is valid.
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 33f9a256a78d0..1367f63d4b7a3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -867,19 +867,33 @@ LogicalResult MmaOp::verify() {
 }
 
 LogicalResult ShflOp::verify() {
-  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
-
-  if ((*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) {
-    auto predicateType = (type && type.getBody().size() == 2)
-                             ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
-                             : nullptr;
-    if (!predicateType || predicateType.getWidth() != 1)
-      return emitOpError("expected return type to be a two-element struct with "
-                         "i1 as the second element");
+  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
+
+  if (returnStructType && !getReturnValueAndIsValid())
+    return emitOpError("\"return_value_and_is_valid\" attribute must be "
+                       "specified when the return type is a struct type");
+
+  if (getReturnValueAndIsValid()) {
+    if (!returnStructType || returnStructType.getBody().size() != 2)
+      return emitOpError("expected return type to be a two-element struct");
+
+    llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
+
+    auto resultType = returnStruct[0];
+    if (resultType != getVal().getType())
+      return emitOpError(
+                 "expected first element in the returned struct to be of type ")
+             << getVal().getType() << " but got " << resultType << " instead.";
+
+    auto predicateType = returnStruct[1];
+    if (!predicateType.isInteger(1))
+      return emitOpError("expected second element in the returned struct to be "
+                         "of type 'i1' but got ")
+             << predicateType << " instead.";
   } else {
-    if (type)
-      return emitOpError("\"return_value_and_is_valid\" attribute must be "
-                         "specified when returning the predicate");
+    if (getType() != getVal().getType())
+      return emitOpError("expected return type to be of type ")
+             << getVal().getType() << " but got " << getType() << " instead.";
   }
   return success();
 }
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index aaf9f8024bfbe..ba74d4d9585a3 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -664,21 +664,21 @@ func.func @zero_non_llvm_type() {
 // -----
 
 func.func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
-  // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
+  // expected-error@+1 {{expected return type to be a two-element struct}}
   %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
 }
 
 // -----
 
 func.func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
-  // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
+  // expected-error@+1 {{expected return type to be a two-element struct}}
   %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
 }
 
 // -----
 
 func.func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
-  // expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
+  // expected-error@+1 {{expected second element in the returned struct to be of type 'i1' but got 'i32' instead.}}
   %0 = nvvm.shfl.sync bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
index d2fe21c841a76..cd65eab977216 100644
--- a/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/shfl-sync-invalid.mlir
@@ -3,6 +3,20 @@
 // -----
 
 func.func @nvvm_invalid_shfl_pred(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
-  // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when returning the predicate}}
+  // expected-error@+1 {{"return_value_and_is_valid" attribute must be specified when the return type is a struct type}}
   %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> !llvm.struct<(f32, i1)>
 }
+
+// -----
+
+func.func @nvvm_invalid_shfl_invalid_return_type_1(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+  // expected-error@+1 {{expected return type to be of type 'f32' but got 'i32' instead.}}
+  %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 : f32 -> i32
+}
+
+// -----
+
+func.func @nvvm_invalid_shfl_invalid_return_type_2(%arg0 : i32, %arg1 : f32, %arg2 : i32, %arg3 : i32) {
+  // expected-error@+1 {{expected first element in the returned struct to be of type 'f32' but got 'i32' instead.}}
+  %0 = nvvm.shfl.sync bfly %arg0, %arg1, %arg2, %arg3 {return_value_and_is_valid} : f32 -> !llvm.struct<(i32, i1)>
+}

>From 9241b6a72aa98b4560226d57c2d1772d8a1afd38 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 3 Nov 2025 13:12:57 +0000
Subject: [PATCH 6/7] refactor shfl.sync verifier

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 33 +++++++++++++---------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1367f63d4b7a3..beda3d08d1ba6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -869,31 +869,38 @@ LogicalResult MmaOp::verify() {
 LogicalResult ShflOp::verify() {
   auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
 
-  if (returnStructType && !getReturnValueAndIsValid())
-    return emitOpError("\"return_value_and_is_valid\" attribute must be "
-                       "specified when the return type is a struct type");
+  auto mismatchedType = [&](Twine desc, Type expectedType,
+                            Type actualType) -> LogicalResult {
+    return emitOpError("expected " + desc + " to be of type ")
+           << expectedType << " but got " << actualType << " instead.";
+  };
+
+  if (returnStructType) {
+    if (!getReturnValueAndIsValid())
+      return emitOpError("\"return_value_and_is_valid\" attribute must be "
+                         "specified when the return type is a struct type");
 
-  if (getReturnValueAndIsValid()) {
-    if (!returnStructType || returnStructType.getBody().size() != 2)
+    if (returnStructType.getBody().size() != 2)
       return emitOpError("expected return type to be a two-element struct");
 
     llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
 
     auto resultType = returnStruct[0];
     if (resultType != getVal().getType())
-      return emitOpError(
-                 "expected first element in the returned struct to be of type ")
-             << getVal().getType() << " but got " << resultType << " instead.";
+      return mismatchedType("first element in the returned struct",
+                            getVal().getType(), resultType);
 
     auto predicateType = returnStruct[1];
     if (!predicateType.isInteger(1))
-      return emitOpError("expected second element in the returned struct to be "
-                         "of type 'i1' but got ")
-             << predicateType << " instead.";
+      return mismatchedType("second element in the returned struct",
+                            mlir::IntegerType::get(getContext(), 1),
+                            predicateType);
   } else {
+    if (getReturnValueAndIsValid())
+      return emitOpError("expected return type to be a two-element struct");
+
     if (getType() != getVal().getType())
-      return emitOpError("expected return type to be of type ")
-             << getVal().getType() << " but got " << getType() << " instead.";
+      return mismatchedType("return type", getVal().getType(), getType());
   }
   return success();
 }

>From 4c722d4476365052a16ae106f77798c6bc71a3bf Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 4 Nov 2025 05:08:24 +0000
Subject: [PATCH 7/7] address comments

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index beda3d08d1ba6..afc2161a9acdc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -869,8 +869,8 @@ LogicalResult MmaOp::verify() {
 LogicalResult ShflOp::verify() {
   auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
 
-  auto mismatchedType = [&](Twine desc, Type expectedType,
-                            Type actualType) -> LogicalResult {
+  auto verifyTypeError = [&](Twine desc, Type expectedType,
+                             Type actualType) -> LogicalResult {
     return emitOpError("expected " + desc + " to be of type ")
            << expectedType << " but got " << actualType << " instead.";
   };
@@ -884,23 +884,22 @@ LogicalResult ShflOp::verify() {
       return emitOpError("expected return type to be a two-element struct");
 
     llvm::ArrayRef<Type> returnStruct = returnStructType.getBody();
-
     auto resultType = returnStruct[0];
     if (resultType != getVal().getType())
-      return mismatchedType("first element in the returned struct",
-                            getVal().getType(), resultType);
+      return verifyTypeError("first element in the returned struct",
+                             getVal().getType(), resultType);
 
     auto predicateType = returnStruct[1];
     if (!predicateType.isInteger(1))
-      return mismatchedType("second element in the returned struct",
-                            mlir::IntegerType::get(getContext(), 1),
-                            predicateType);
+      return verifyTypeError("second element in the returned struct",
+                             mlir::IntegerType::get(getContext(), 1),
+                             predicateType);
   } else {
     if (getReturnValueAndIsValid())
       return emitOpError("expected return type to be a two-element struct");
 
     if (getType() != getVal().getType())
-      return mismatchedType("return type", getVal().getType(), getType());
+      return verifyTypeError("return type", getVal().getType(), getType());
   }
   return success();
 }



More information about the Mlir-commits mailing list