[Mlir-commits] [mlir] f05173d - Implement callee/caller type checking for llvm.call
Mehdi Amini
llvmlistbot at llvm.org
Sun Oct 4 13:15:20 PDT 2020
Author: Mehdi Amini
Date: 2020-10-04T20:15:06Z
New Revision: f05173d0bfc268530afd9fb9622caf7f1db5f8d6
URL: https://github.com/llvm/llvm-project/commit/f05173d0bfc268530afd9fb9622caf7f1db5f8d6
DIFF: https://github.com/llvm/llvm-project/commit/f05173d0bfc268530afd9fb9622caf7f1db5f8d6.diff
LOG: Implement callee/caller type checking for llvm.call
This aligns the behavior with the standard call as well as the LLVM verifier.
Reviewed By: ftynse, dcaballe
Differential Revision: https://reviews.llvm.org/D88362
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b5b8e45eb21f..94f6809349f1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -417,11 +417,7 @@ def LLVM_CallOp : LLVM_Op<"call">,
$_state.addAttributes(attributes);
$_state.addOperands(operands);
}]>];
- let verifier = [{
- if (getNumResults() > 1)
- return emitOpError("must have 0 or 1 result");
- return success();
- }];
+ let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseCallOp(parser, result); }];
let printer = [{ printCallOp(p, *this); }];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e1f7aede782d..8da2e0b36300 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -531,9 +531,83 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser,
}
//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CallOp.
+// Verifying/Printing/parsing for LLVM::CallOp.
//===----------------------------------------------------------------------===//
+static LogicalResult verify(CallOp &op) {
+ if (op.getNumResults() > 1)
+ return op.emitOpError("must have 0 or 1 result");
+
+ // Type for the callee, we'll get it differently depending if it is a direct
+ // or indirect call.
+ LLVMType fnType;
+
+ bool isIndirect = false;
+
+ // If this is an indirect call, the callee attribute is missing.
+ Optional<StringRef> calleeName = op.callee();
+ if (!calleeName) {
+ isIndirect = true;
+ if (!op.getNumOperands())
+ return op.emitOpError(
+ "must have either a `callee` attribute or at least an operand");
+ fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
+ if (!fnType)
+ return op.emitOpError("indirect call to a non-llvm type: ")
+ << op.getOperand(0).getType();
+ auto ptrType = fnType.dyn_cast<LLVMPointerType>();
+ if (!ptrType)
+ return op.emitOpError("indirect call expects a pointer as callee: ")
+ << fnType;
+ fnType = ptrType.getElementType();
+ } else {
+ Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
+ if (!callee)
+ return op.emitOpError()
+ << "'" << *calleeName
+ << "' does not reference a symbol in the current scope";
+ auto fn = dyn_cast<LLVMFuncOp>(callee);
+ if (!fn)
+ return op.emitOpError() << "'" << *calleeName
+ << "' does not reference a valid LLVM function";
+
+ fnType = fn.getType();
+ }
+ if (!fnType.isFunctionTy())
+ return op.emitOpError("callee does not have a functional type: ") << fnType;
+
+ // Verify that the operand and result types match the callee.
+
+ if (!fnType.isFunctionVarArg() &&
+ fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
+ return op.emitOpError()
+ << "incorrect number of operands ("
+ << (op.getNumOperands() - isIndirect)
+ << ") for callee (expecting: " << fnType.getFunctionNumParams()
+ << ")";
+
+ if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
+ return op.emitOpError() << "incorrect number of operands ("
+ << (op.getNumOperands() - isIndirect)
+ << ") for varargs callee (expecting at least: "
+ << fnType.getFunctionNumParams() << ")";
+
+ for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
+ if (op.getOperand(i + isIndirect).getType() !=
+ fnType.getFunctionParamType(i))
+ return op.emitOpError() << "operand type mismatch for operand " << i
+ << ": " << op.getOperand(i + isIndirect).getType()
+ << " != " << fnType.getFunctionParamType(i);
+
+ if (op.getNumResults() &&
+ op.getResult(0).getType() != fnType.getFunctionResultType())
+ return op.emitOpError()
+ << "result type mismatch: " << op.getResult(0).getType()
+ << " != " << fnType.getFunctionResultType();
+
+ return success();
+}
+
static void printCallOp(OpAsmPrinter &p, CallOp &op) {
auto callee = op.callee();
bool isDirect = callee.hasValue();
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c19795e98b68..322d5397a417 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -125,6 +125,75 @@ func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
// -----
+func @invalid_call() {
+ // expected-error @+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
+ "llvm.call"() : () -> ()
+}
+
+// -----
+
+func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
+ // expected-error @+1 {{expected function type}}
+ llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
+}
+
+// -----
+
+func @call_unknown_symbol() {
+ // expected-error @+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}}
+ llvm.call @missing_callee() : () -> ()
+}
+
+// -----
+
+func @standard_func_callee()
+
+func @call_non_llvm() {
+ // expected-error @+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}}
+ llvm.call @standard_func_callee() : () -> ()
+}
+
+// -----
+
+func @call_non_llvm_indirect(%arg0 : i32) {
+ // expected-error @+1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}}
+ "llvm.call"(%arg0) : (i32) -> ()
+}
+
+// -----
+
+llvm.func @callee_func(!llvm.i8) -> ()
+
+func @callee_arg_mismatch(%arg0 : !llvm.i32) {
+ // expected-error @+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
+ llvm.call @callee_func(%arg0) : (!llvm.i32) -> ()
+}
+
+// -----
+
+func @indirect_callee_arg_mismatch(%arg0 : !llvm.i32, %callee : !llvm.ptr<func<void(i8)>>) {
+ // expected-error @+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
+ "llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, !llvm.i32) -> ()
+}
+
+// -----
+
+llvm.func @callee_func() -> (!llvm.i8)
+
+func @callee_return_mismatch() {
+ // expected-error @+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
+ %res = llvm.call @callee_func() : () -> (!llvm.i32)
+}
+
+// -----
+
+func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) {
+ // expected-error @+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
+ "llvm.call"(%callee) : (!llvm.ptr<func<i8()>>) -> (!llvm.i32)
+}
+
+// -----
+
func @call_too_many_results(%callee : () -> (i32,i32)) {
// expected-error @+1 {{expected function with 0 or 1 result}}
llvm.call %callee() : () -> (i32, i32)
More information about the Mlir-commits
mailing list