[Mlir-commits] [mlir] 51bc82d - [mlir] Implement SymbolUserOpInterface in LLVM::CallOp

Eugene Zhulenev llvmlistbot at llvm.org
Fri Aug 5 13:50:35 PDT 2022


Author: Eugene Zhulenev
Date: 2022-08-05T13:50:31-07:00
New Revision: 51bc82d147f8205dc516a50c66a3938249116f41

URL: https://github.com/llvm/llvm-project/commit/51bc82d147f8205dc516a50c66a3938249116f41
DIFF: https://github.com/llvm/llvm-project/commit/51bc82d147f8205dc516a50c66a3938249116f41.diff

LOG: [mlir] Implement SymbolUserOpInterface in LLVM::CallOp

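Avoid expensive calls to `SymbolTable::lookupNearestSymbolFrom` in the verifier by moving the callee checks into `verifySymbolUses`, which resolves the callee through the provided `SymbolTableCollection` instead.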

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D131285

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index e65827195340..7dcd48f0c5e8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -647,7 +647,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 
 def LLVM_CallOp : LLVM_Op<"call",
                           [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
-                           DeclareOpInterfaceMethods<CallOpInterface>]> {
+                           DeclareOpInterfaceMethods<CallOpInterface>,
+                           DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "Call to an LLVM function.";
   let description = [{
 
@@ -701,8 +702,8 @@ def LLVM_CallOp : LLVM_Op<"call",
             StringAttr::get($_builder.getContext(), callee), operands);
     }]>];
   let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
 }
+
 def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
   let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
   let results = (outs LLVM_Type:$res);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3d9ec1798e50..4cb6a5658c51 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1160,7 +1160,7 @@ Operation::operand_range CallOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
-LogicalResult CallOp::verify() {
+LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
 
@@ -1184,7 +1184,7 @@ LogicalResult CallOp::verify() {
     fnType = ptrType.getElementType();
   } else {
     Operation *callee =
-        SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
+        symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
     if (!callee)
       return emitOpError()
              << "'" << calleeName.getValue()

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 459d6b188753..4ac9724b0838 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -191,6 +191,7 @@ func.func @store_malformed_elem_type(%foo: !llvm.ptr, %bar: f32) {
 func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
   // expected-error at +1 {{expected function type}}
   llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
+  llvm.return
 }
 
 // -----
@@ -198,6 +199,7 @@ func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 func.func @invalid_call() {
   // expected-error at +1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
   "llvm.call"() : () -> ()
+  llvm.return
 }
 
 // -----
@@ -205,6 +207,7 @@ func.func @invalid_call() {
 func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
   // expected-error at +1 {{expected function type}}
   llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
+  llvm.return
 }
 
 // -----
@@ -212,6 +215,7 @@ func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
 func.func @call_unknown_symbol() {
   // expected-error at +1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}}
   llvm.call @missing_callee() : () -> ()
+  llvm.return
 }
 
 // -----
@@ -221,6 +225,7 @@ func.func private @standard_func_callee()
 func.func @call_non_llvm() {
   // expected-error at +1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}}
   llvm.call @standard_func_callee() : () -> ()
+  llvm.return
 }
 
 // -----
@@ -228,6 +233,7 @@ func.func @call_non_llvm() {
 func.func @call_non_llvm_indirect(%arg0 : tensor<*xi32>) {
   // expected-error at +1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type}}
   "llvm.call"(%arg0) : (tensor<*xi32>) -> ()
+  llvm.return
 }
 
 // -----
@@ -237,6 +243,7 @@ llvm.func @callee_func(i8) -> ()
 func.func @callee_arg_mismatch(%arg0 : i32) {
   // expected-error at +1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}}
   llvm.call @callee_func(%arg0) : (i32) -> ()
+  llvm.return
 }
 
 // -----
@@ -244,6 +251,7 @@ func.func @callee_arg_mismatch(%arg0 : i32) {
 func.func @indirect_callee_arg_mismatch(%arg0 : i32, %callee : !llvm.ptr<func<void(i8)>>) {
   // expected-error at +1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}}
   "llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, i32) -> ()
+  llvm.return
 }
 
 // -----
@@ -253,6 +261,7 @@ llvm.func @callee_func() -> (i8)
 func.func @callee_return_mismatch() {
   // expected-error at +1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}}
   %res = llvm.call @callee_func() : () -> (i32)
+  llvm.return
 }
 
 // -----
@@ -260,6 +269,7 @@ func.func @callee_return_mismatch() {
 func.func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) {
   // expected-error at +1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}}
   "llvm.call"(%callee) : (!llvm.ptr<func<i8()>>) -> (i32)
+  llvm.return
 }
 
 // -----
@@ -267,6 +277,7 @@ func.func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) {
 func.func @call_too_many_results(%callee : () -> (i32,i32)) {
   // expected-error at +1 {{expected function with 0 or 1 result}}
   llvm.call %callee() : () -> (i32, i32)
+  llvm.return
 }
 
 // -----
@@ -274,6 +285,7 @@ func.func @call_too_many_results(%callee : () -> (i32,i32)) {
 func.func @call_non_llvm_result(%callee : () -> (tensor<*xi32>)) {
   // expected-error at +1 {{expected result to have LLVM type}}
   llvm.call %callee() : () -> (tensor<*xi32>)
+  llvm.return
 }
 
 // -----
@@ -281,6 +293,7 @@ func.func @call_non_llvm_result(%callee : () -> (tensor<*xi32>)) {
 func.func @call_non_llvm_input(%callee : (tensor<*xi32>) -> (), %arg : tensor<*xi32>) {
   // expected-error at +1 {{expected LLVM types as inputs}}
   llvm.call %callee(%arg) : (tensor<*xi32>) -> ()
+  llvm.return
 }
 
 // -----


        


More information about the Mlir-commits mailing list