[Mlir-commits] [mlir] db884da - [mlir] Explicitly track branch instructions in translation to LLVM IR
Alex Zinenko
llvmlistbot at llvm.org
Thu Dec 10 02:09:06 PST 2020
Author: Alex Zinenko
Date: 2020-12-10T11:08:58+01:00
New Revision: db884dafb7b5771e6ae01e8252f1520fac3e1c77
URL: https://github.com/llvm/llvm-project/commit/db884dafb7b5771e6ae01e8252f1520fac3e1c77
DIFF: https://github.com/llvm/llvm-project/commit/db884dafb7b5771e6ae01e8252f1520fac3e1c77.diff
LOG: [mlir] Explicitly track branch instructions in translation to LLVM IR
The current implementation of the translation to LLVM IR relies on the
existence of a one-to-one mapping between MLIR blocks and LLVM IR basic blocks
in order to configure PHI nodes with appropriate source blocks. The one-to-one
mapping model is broken in presence of OpenMP operations that use LLVM's
OpenMPIRBuilder, which produces multiple blocks under the hood. This can lead
to invalid LLVM IR being emitted if OpenMPIRBuilder moved the branch operation
into a basic block different from the one it was originally created in;
specifically, a block that is not a direct predecessor could be used in the PHI
node. Instead, keep track of the mapping between MLIR LLVM dialect branch
operations and their LLVM IR counterparts and take the parent basic block of
the LLVM IR instruction at the moment of connecting the PHI nodes to
predecessors.
This behavior cannot be triggered as of now, but will be once we introduce the
conversion of OpenMP workshare loops.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D92845
Added:
Modified:
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 996c8de06e37..d3d289414b38 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -151,6 +151,11 @@ class ModuleTranslation {
llvm::StringMap<llvm::Function *> functionMapping;
DenseMap<Value, llvm::Value *> valueMapping;
DenseMap<Block *, llvm::BasicBlock *> blockMapping;
+
+ /// A mapping between MLIR LLVM dialect terminators and LLVM IR terminators
+ /// they are converted to. This allows for connecting PHI nodes to the source
+ /// values after all operations are converted.
+ DenseMap<Operation *, llvm::Instruction *> branchMapping;
};
} // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index c20b19f2d5ca..057f57409940 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -340,9 +340,10 @@ static Value getPHISourceValue(Block *current, Block *pred,
/// Connect the PHI nodes to the results of preceding blocks.
template <typename T>
-static void
-connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
- const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
+static void connectPHINodes(
+ T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
+ const DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
+ const DenseMap<Operation *, llvm::Instruction *> &branchMapping) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
@@ -355,9 +356,17 @@ connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
auto &phiNode = numberedPhiNode.value();
unsigned index = numberedPhiNode.index();
for (auto *pred : bb->getPredecessors()) {
+ // Find the LLVM IR block that contains the converted terminator
+ // instruction and use it in the PHI node. Note that this block is not
+ // necessarily the same as blockMapping.lookup(pred); some operations
+ // (in particular, OpenMP operations using OpenMPIRBuilder) may have
+ // split the blocks.
+ llvm::Instruction *terminator =
+ branchMapping.lookup(pred->getTerminator());
+ assert(terminator && "missing the mapping for a terminator");
phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
bb, pred, numArguments, index)),
- blockMapping.lookup(pred));
+ terminator->getParent());
}
}
}
@@ -476,7 +485,7 @@ void ModuleTranslation::convertOmpOpRegions(
}
// Finally, after all blocks have been traversed and values mapped,
// connect the PHI nodes to the results of preceding blocks.
- connectPHINodes(region, valueMapping, blockMapping);
+ connectPHINodes(region, valueMapping, blockMapping, branchMapping);
}
LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
@@ -682,7 +691,9 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
// Emit branches. We need to look up the remapped blocks and ignore the block
// arguments that were transformed into PHI nodes.
if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
- builder.CreateBr(blockMapping[brOp.getSuccessor()]);
+ llvm::BranchInst *branch =
+ builder.CreateBr(blockMapping[brOp.getSuccessor()]);
+ branchMapping.try_emplace(&opInst, branch);
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
@@ -699,9 +710,11 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
.createBranchWeights(static_cast<uint32_t>(trueWeight),
static_cast<uint32_t>(falseWeight));
}
- builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
- blockMapping[condbrOp.getSuccessor(0)],
- blockMapping[condbrOp.getSuccessor(1)], branchWeights);
+ llvm::BranchInst *branch = builder.CreateCondBr(
+ valueMapping.lookup(condbrOp.getOperand(0)),
+ blockMapping[condbrOp.getSuccessor(0)],
+ blockMapping[condbrOp.getSuccessor(1)], branchWeights);
+ branchMapping.try_emplace(&opInst, branch);
return success();
}
@@ -893,10 +906,11 @@ forwardPassthroughAttributes(Location loc, Optional<ArrayAttr> attributes,
}
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
- // Clear the block and value mappings, they are only relevant within one
+ // Clear the block, branch, and value mappings; they are only relevant within one
// function.
blockMapping.clear();
valueMapping.clear();
+ branchMapping.clear();
llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
// Translate the debug information for this function.
@@ -964,7 +978,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Finally, after all blocks have been traversed and values mapped, connect
// the PHI nodes to the results of preceding blocks.
- connectPHINodes(func, valueMapping, blockMapping);
+ connectPHINodes(func, valueMapping, blockMapping, branchMapping);
return success();
}
More information about the Mlir-commits
mailing list