[Mlir-commits] [mlir] 9bb5bff - [mlir] Add an assertion on creating an Operation with null result types
Alex Zinenko
llvmlistbot at llvm.org
Thu Nov 19 13:28:47 PST 2020
Author: Alex Zinenko
Date: 2020-11-19T22:28:38+01:00
New Revision: 9bb5bff570140d4fc5b1750ca7352b840dd58ed7
URL: https://github.com/llvm/llvm-project/commit/9bb5bff570140d4fc5b1750ca7352b840dd58ed7
DIFF: https://github.com/llvm/llvm-project/commit/9bb5bff570140d4fc5b1750ca7352b840dd58ed7.diff
LOG: [mlir] Add an assertion on creating an Operation with null result types
Null types are commonly used as an error marker. Catch them in the constructor
of Operation if they are present in the result type list, as otherwise this
could lead to further surprising behavior when querying op result types.
Fix AsyncToLLVM and StandardToLLVM, which were passing null types (instead of an
empty TypeRange) when constructing operations with no results.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D91770
Added:
Modified:
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/TypeRange.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 8195e17fe0c3a..0cbf3debd8942 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -223,7 +223,7 @@ static void addResumeFunction(ModuleOp module) {
auto *block = resumeOp.addEntryBlock();
OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
- blockBuilder.create<LLVM::CallOp>(loc, Type(),
+ blockBuilder.create<LLVM::CallOp>(loc, TypeRange(),
blockBuilder.getSymbolRefAttr(kCoroResume),
resumeOp.getArgument(0));
@@ -343,7 +343,8 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
ValueRange({coroId.getResult(0), coroHdl.getResult(0)}));
// Free the memory.
- builder.create<LLVM::CallOp>(loc, Type(), builder.getSymbolRefAttr(kFree),
+ builder.create<LLVM::CallOp>(loc, TypeRange(),
+ builder.getSymbolRefAttr(kFree),
ValueRange(coroMem.getResult(0)));
// Branch into the suspend block.
builder.create<BranchOp>(loc, suspendBlock);
@@ -503,7 +504,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Call async runtime API to execute a coroutine in the managed thread.
SmallVector<Value, 2> executeArgs = {coro.coroHandle, resumePtr.res()};
- builder.create<CallOp>(loc, Type(), kExecute, executeArgs);
+ builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs);
// Split the entry block before the terminator.
Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
@@ -524,7 +525,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// to async runtime to emplace the result token.
for (Operation &op : execute.body().getOps()) {
if (isa<async::YieldOp>(op)) {
- builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
+ builder.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
continue;
}
builder.clone(op, valueMapping);
@@ -671,7 +672,7 @@ class AwaitOpLoweringBase : public ConversionPattern {
// Inside regular function we convert await operation to the blocking
// async API await function call.
if (!isInCoroutine)
- rewriter.create<CallOp>(loc, Type(), blockingAwaitFuncName,
+ rewriter.create<CallOp>(loc, TypeRange(), blockingAwaitFuncName,
ValueRange(op->getOperand(0)));
// Inside the coroutine we convert await operation into coroutine suspension
@@ -696,7 +697,7 @@ class AwaitOpLoweringBase : public ConversionPattern {
// the async await argument becomes ready.
SmallVector<Value, 3> awaitAndExecuteArgs = {
await.getOperand(), coro.coroHandle, resumePtr.res()};
- builder.create<CallOp>(loc, Type(), coroAwaitFuncName,
+ builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName,
awaitAndExecuteArgs);
// Split the entry block before the await operation.
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index de540b0832bc7..49942995fc78c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2290,7 +2290,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
auto callOp = cast<CallOpType>(op);
// Pack the result types into a struct.
- Type packedResult;
+ Type packedResult = nullptr;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
@@ -2302,8 +2302,9 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
auto promoted = this->typeConverter.promoteOperands(
op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter);
- auto newOp = rewriter.create<LLVM::CallOp>(op->getLoc(), packedResult,
- promoted, op->getAttrs());
+ auto newOp = rewriter.create<LLVM::CallOp>(
+ op->getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
+ promoted, op->getAttrs());
SmallVector<Value, 4> results;
if (numResults < 2) {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 5fed2ec8e7532..e725dd87d93f2 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -167,6 +167,8 @@ Operation::Operation(Location location, OperationName name,
: location(location), numSuccs(numSuccessors), numRegions(numRegions),
hasOperandStorage(hasOperandStorage), hasSingleResult(false), name(name),
attrs(attributes) {
+ assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
+ "unexpected null result type");
if (!resultTypes.empty()) {
// If there is a single result it is stored in-place, otherwise use a tuple.
hasSingleResult = resultTypes.size() == 1;
diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
index f3f6fb54c707b..dadcb627b0930 100644
--- a/mlir/lib/IR/TypeRange.cpp
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -14,7 +14,10 @@ using namespace mlir;
// TypeRange
TypeRange::TypeRange(ArrayRef<Type> types)
- : TypeRange(types.data(), types.size()) {}
+ : TypeRange(types.data(), types.size()) {
+ assert(llvm::all_of(types, [](Type t) { return t; }) &&
+ "attempting to construct a TypeRange with null types");
+}
TypeRange::TypeRange(OperandRange values)
: TypeRange(values.begin().getBase(), values.size()) {}
TypeRange::TypeRange(ResultRange values)
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 4d97fe9446b86..9c709a653b27b 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -203,3 +203,13 @@ func @get_gv3_memref() {
return
}
+// This should not trigger an assertion by creating an LLVM::CallOp with a
+// nullptr result type.
+
+// CHECK-LABEL: @call_zero_result_func
+func @call_zero_result_func() {
+ // CHECK: call @zero_result_func
+ call @zero_result_func() : () -> ()
+ return
+}
+func private @zero_result_func()
More information about the Mlir-commits mailing list