[Mlir-commits] [mlir] 54c5521 - [mlir][spirv] Use `verifySymbolUses` for `spirv.FunctionCall`. (#159399)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 18 14:34:46 PDT 2025


Author: Erick Ochoa Lopez
Date: 2025-09-18T17:34:42-04:00
New Revision: 54c55219ea3fbb44046d385acefcff0b73d3f8f4

URL: https://github.com/llvm/llvm-project/commit/54c55219ea3fbb44046d385acefcff0b73d3f8f4
DIFF: https://github.com/llvm/llvm-project/commit/54c55219ea3fbb44046d385acefcff0b73d3f8f4.diff

LOG: [mlir][spirv] Use `verifySymbolUses` for `spirv.FunctionCall`. (#159399)

`spirv.FunctionCall`'s verifier was too aggressive: it verified
non-local properties by inspecting the callee's definition.

This caused problems when the callee itself had verification errors,
and could lead to a null pointer dereference.

According to the [MLIR Developer Guide](https://mlir.llvm.org/getting_started/DeveloperGuide/#ir-verifier):

> TLDR: only verify local aspects of an operation,
> in particular don’t follow def-use chains
> (don’t look at the producer of any operand or the user
>  of any results).

The fix adds `SymbolUserOpInterface` to `spirv.FunctionCall` and moves
the checks that depend on the callee's definition into
`verifySymbolUses`. Only the local check on the number of results
remains in `verify`.

Fixes #159295

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index ef6682ab3630c..acb6467132be9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -15,6 +15,7 @@
 #define MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
 
 include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -187,7 +188,8 @@ def SPIRV_BranchConditionalOp : SPIRV_Op<"BranchConditional", [
 // -----
 
 def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
-    InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>]> {
+    InFunctionScope, DeclareOpInterfaceMethods<CallOpInterface>,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "Call a function.";
 
   let description = [{

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 890406df74e72..f0b46e61965f4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -151,10 +151,20 @@ LogicalResult BranchConditionalOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult FunctionCallOp::verify() {
+  if (getNumResults() > 1) {
+    return emitOpError(
+               "expected callee function to have 0 or 1 result, but provided ")
+           << getNumResults();
+  }
+  return success();
+}
+
+LogicalResult
+FunctionCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto fnName = getCalleeAttr();
 
-  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
+  auto funcOp =
+      symbolTable.lookupNearestSymbolFrom<spirv::FuncOp>(*this, fnName);
   if (!funcOp) {
     return emitOpError("callee function '")
            << fnName.getValue() << "' not found in nearest symbol table";
@@ -162,12 +172,6 @@ LogicalResult FunctionCallOp::verify() {
 
   auto functionType = funcOp.getFunctionType();
 
-  if (getNumResults() > 1) {
-    return emitOpError(
-               "expected callee function to have 0 or 1 result, but provided ")
-           << getNumResults();
-  }
-
   if (functionType.getNumInputs() != getNumOperands()) {
     return emitOpError("has incorrect number of operands for callee: expected ")
            << functionType.getNumInputs() << ", but provided "

diff  --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..8e29ff6679068 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -262,6 +262,35 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+"builtin.module"() ({
+  "spirv.module"() <{
+    addressing_model = #spirv.addressing_model<Logical>,
+    memory_model = #spirv.memory_model<GLSL450>
+  }> ({
+    "spirv.func"() <{
+      function_control = #spirv.function_control<None>,
+      function_type = (f32) -> f32,
+      sym_name = "bar"
+    }> ({
+    ^bb0(%arg0: f32):
+      %0 = "spirv.FunctionCall"(%arg0) <{callee = @foo}> : (f32) -> f32
+      "spirv.ReturnValue"(%0) : (f32) -> ()
+    }) : () -> ()
+    // expected-error @+1 {{requires attribute 'function_type'}}
+    "spirv.func"() <{
+      function_control = #spirv.function_control<None>,
+      message = "2nd parent",
+      sym_name = "foo"
+      // This is invalid MLIR because function_type is missing from spirv.func.
+    }> ({
+    ^bb0(%arg0: f32):
+      "spirv.ReturnValue"(%arg0) : (f32) -> ()
+    }) : () -> ()
+  }) : () -> ()
+}) : () -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.loop
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list