[Mlir-commits] [mlir] f05173d - Implement callee/caller type checking for llvm.call

Mehdi Amini llvmlistbot at llvm.org
Sun Oct 4 13:15:20 PDT 2020


Author: Mehdi Amini
Date: 2020-10-04T20:15:06Z
New Revision: f05173d0bfc268530afd9fb9622caf7f1db5f8d6

URL: https://github.com/llvm/llvm-project/commit/f05173d0bfc268530afd9fb9622caf7f1db5f8d6
DIFF: https://github.com/llvm/llvm-project/commit/f05173d0bfc268530afd9fb9622caf7f1db5f8d6.diff

LOG: Implement callee/caller type checking for llvm.call

This aligns the behavior with the standard call as well as the LLVM verifier.

Reviewed By: ftynse, dcaballe

Differential Revision: https://reviews.llvm.org/D88362

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b5b8e45eb21f..94f6809349f1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -417,11 +417,7 @@ def LLVM_CallOp : LLVM_Op<"call">,
       $_state.addAttributes(attributes);
       $_state.addOperands(operands);
     }]>];
-  let verifier = [{
-    if (getNumResults() > 1)
-      return emitOpError("must have 0 or 1 result");
-    return success();
-  }];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseCallOp(parser, result); }];
   let printer = [{ printCallOp(p, *this); }];
 }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e1f7aede782d..8da2e0b36300 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -531,9 +531,83 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser,
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CallOp.
+// Verifying/Printing/parsing for LLVM::CallOp.
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verify(CallOp &op) {
+  if (op.getNumResults() > 1)
+    return op.emitOpError("must have 0 or 1 result");
+
+  // Type for the callee, we'll get it differently depending if it is a direct
+  // or indirect call.
+  LLVMType fnType;
+
+  bool isIndirect = false;
+
+  // If this is an indirect call, the callee attribute is missing.
+  Optional<StringRef> calleeName = op.callee();
+  if (!calleeName) {
+    isIndirect = true;
+    if (!op.getNumOperands())
+      return op.emitOpError(
+          "must have either a `callee` attribute or at least an operand");
+    fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
+    if (!fnType)
+      return op.emitOpError("indirect call to a non-llvm type: ")
+             << op.getOperand(0).getType();
+    auto ptrType = fnType.dyn_cast<LLVMPointerType>();
+    if (!ptrType)
+      return op.emitOpError("indirect call expects a pointer as callee: ")
+             << fnType;
+    fnType = ptrType.getElementType();
+  } else {
+    Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
+    if (!callee)
+      return op.emitOpError()
+             << "'" << *calleeName
+             << "' does not reference a symbol in the current scope";
+    auto fn = dyn_cast<LLVMFuncOp>(callee);
+    if (!fn)
+      return op.emitOpError() << "'" << *calleeName
+                              << "' does not reference a valid LLVM function";
+
+    fnType = fn.getType();
+  }
+  if (!fnType.isFunctionTy())
+    return op.emitOpError("callee does not have a functional type: ") << fnType;
+
+  // Verify that the operand and result types match the callee.
+
+  if (!fnType.isFunctionVarArg() &&
+      fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
+    return op.emitOpError()
+           << "incorrect number of operands ("
+           << (op.getNumOperands() - isIndirect)
+           << ") for callee (expecting: " << fnType.getFunctionNumParams()
+           << ")";
+
+  if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
+    return op.emitOpError() << "incorrect number of operands ("
+                            << (op.getNumOperands() - isIndirect)
+                            << ") for varargs callee (expecting at least: "
+                            << fnType.getFunctionNumParams() << ")";
+
+  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
+    if (op.getOperand(i + isIndirect).getType() !=
+        fnType.getFunctionParamType(i))
+      return op.emitOpError() << "operand type mismatch for operand " << i
+                              << ": " << op.getOperand(i + isIndirect).getType()
+                              << " != " << fnType.getFunctionParamType(i);
+
+  if (op.getNumResults() &&
+      op.getResult(0).getType() != fnType.getFunctionResultType())
+    return op.emitOpError()
+           << "result type mismatch: " << op.getResult(0).getType()
+           << " != " << fnType.getFunctionResultType();
+
+  return success();
+}
+
 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   auto callee = op.callee();
   bool isDirect = callee.hasValue();

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c19795e98b68..322d5397a417 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -125,6 +125,75 @@ func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
 
 // -----
 
+func @invalid_call() {
+  // expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
+  "llvm.call"() : () -> ()
+}
+
+// -----
+
+func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
+  // expected-error@+1 {{expected function type}}
+  llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
+}
+
+// -----
+
+func @call_unknown_symbol() {
+  // expected-error@+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}}
+  llvm.call @missing_callee() : () -> ()
+}
+
+// -----
+
+func @standard_func_callee()
+
+func @call_non_llvm() {
+  // expected-error@+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}}
+  llvm.call @standard_func_callee() : () -> ()
+}
+
+// -----
+
+func @call_non_llvm_indirect(%arg0 : i32) {
+  // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}}
+  "llvm.call"(%arg0) : (i32) -> ()
+}
+
+// -----
+
+llvm.func @callee_func(!llvm.i8) -> ()
+
+func @callee_arg_mismatch(%arg0 : !llvm.i32) {
+  // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
+  llvm.call @callee_func(%arg0) : (!llvm.i32) -> ()
+}
+
+// -----
+
+func @indirect_callee_arg_mismatch(%arg0 : !llvm.i32, %callee : !llvm.ptr<func<void(i8)>>) {
+  // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
+  "llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, !llvm.i32) -> ()
+}
+
+// -----
+
+llvm.func @callee_func() -> (!llvm.i8)
+
+func @callee_return_mismatch() {
+  // expected-error@+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
+  %res = llvm.call @callee_func() : () -> (!llvm.i32)
+}
+
+// -----
+
+func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) {
+  // expected-error@+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
+  "llvm.call"(%callee) : (!llvm.ptr<func<i8()>>) -> (!llvm.i32)
+}
+
+// -----
+
 func @call_too_many_results(%callee : () -> (i32,i32)) {
   // expected-error@+1 {{expected function with 0 or 1 result}}
   llvm.call %callee() : () -> (i32, i32)


        


More information about the Mlir-commits mailing list