[Mlir-commits] [mlir] 0881a4f - [mlir] make ModuleTranslation mapping fields private

Alex Zinenko llvmlistbot at llvm.org
Thu Feb 11 05:50:59 PST 2021


Author: Alex Zinenko
Date: 2021-02-11T14:50:49+01:00
New Revision: 0881a4f1bf769a588e7d6d6af8501a230a2f42c8

URL: https://github.com/llvm/llvm-project/commit/0881a4f1bf769a588e7d6d6af8501a230a2f42c8
DIFF: https://github.com/llvm/llvm-project/commit/0881a4f1bf769a588e7d6d6af8501a230a2f42c8.diff

LOG: [mlir] make ModuleTranslation mapping fields private

ModuleTranslation contains multiple fields that keep track of the mappings
between various MLIR and LLVM IR components. The original ModuleTranslation
extension model was based on inheritance, with these fields being protected and
thus accessible in ModuleTranslation and its derived classes. The
inheritance-based model doesn't scale to translating more than one dialect
through derived classes and will be progressively replaced with a more flexible
one based on
dialect interfaces and a translation state that is separate from
ModuleTranslation. This change prepares the replacement by making the mappings
private and providing public methods to access them.

Depends On D96436

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96437

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d1f42e06ae5d..ebe9a7c83d59 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -77,6 +77,63 @@ class ModuleTranslation {
   /// module requirements.
   static Block &getModuleBody(Operation *m) { return m->getRegion(0).front(); }
 
+  /// Stores the mapping between a function name and its LLVM IR representation.
+  void mapFunction(StringRef name, llvm::Function *func) {
+    auto result = functionMapping.try_emplace(name, func);
+    (void)result;
+    assert(result.second &&
+           "attempting to map a function that is already mapped");
+  }
+
+  /// Finds an LLVM IR function by its name.
+  llvm::Function *lookupFunction(StringRef name) const {
+    return functionMapping.lookup(name);
+  }
+
+  /// Stores the mapping between an MLIR value and its LLVM IR counterpart.
+  void mapValue(Value mlir, llvm::Value *llvm) { mapValue(mlir) = llvm; }
+
+  /// Provides write-once access to store the LLVM IR value corresponding to the
+  /// given MLIR value.
+  llvm::Value *&mapValue(Value value) {
+    llvm::Value *&llvm = valueMapping[value];
+    assert(llvm == nullptr &&
+           "attempting to map a value that is already mapped");
+    return llvm;
+  }
+
+  /// Finds an LLVM IR value corresponding to the given MLIR value.
+  llvm::Value *lookupValue(Value value) const {
+    return valueMapping.lookup(value);
+  }
+
+  /// Stores the mapping between an MLIR block and LLVM IR basic block.
+  void mapBlock(Block *mlir, llvm::BasicBlock *llvm) {
+    auto result = blockMapping.try_emplace(mlir, llvm);
+    (void)result;
+    assert(result.second && "attempting to map a block that is already mapped");
+  }
+
+  /// Finds an LLVM IR basic block that corresponds to the given MLIR block.
+  llvm::BasicBlock *lookupBlock(Block *block) const {
+    return blockMapping.lookup(block);
+  }
+
+  /// Stores the mapping between an MLIR operation with successors and a
+  /// corresponding LLVM IR instruction.
+  void mapBranch(Operation *mlir, llvm::Instruction *llvm) {
+    auto result = branchMapping.try_emplace(mlir, llvm);
+    (void)result;
+    assert(result.second &&
+           "attempting to map a branch that is already mapped");
+  }
+
+  /// Finds an LLVM IR instruction that corresponds to the given MLIR operation
+  /// with successors.
+  llvm::Instruction *lookupBranch(Operation *op) const {
+    return branchMapping.lookup(op);
+  }
+
 protected:
   /// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
   /// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
@@ -94,8 +151,6 @@ class ModuleTranslation {
   virtual LogicalResult convertOmpMaster(Operation &op,
                                          llvm::IRBuilder<> &builder);
   void convertOmpOpRegions(Region &region, StringRef blockName,
-                           DenseMap<Value, llvm::Value *> &valueMapping,
-                           DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
                            llvm::BasicBlock &sourceBlock,
                            llvm::BasicBlock &continuationBlock,
                            llvm::IRBuilder<> &builder,
@@ -147,7 +202,7 @@ class ModuleTranslation {
   /// A stateful object used to translate types.
   TypeToLLVMIRTranslator typeTranslator;
 
-protected:
+private:
   /// Mappings between original and translated values, used for lookups.
   llvm::StringMap<llvm::Function *> functionMapping;
   DenseMap<Value, llvm::Value *> valueMapping;

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 4924d4bed6d8..dce384634f78 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -94,7 +94,7 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
     } else {
       return type;
     }
-  } while (1);
+  } while (true);
 }
 
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
@@ -119,8 +119,8 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
   if (auto floatAttr = attr.dyn_cast<FloatAttr>())
     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
   if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
-    return llvm::ConstantExpr::getBitCast(
-        functionMapping.lookup(funcAttr.getValue()), llvmType);
+    return llvm::ConstantExpr::getBitCast(lookupFunction(funcAttr.getValue()),
+                                          llvmType);
   if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
     llvm::Type *elementType;
     uint64_t numElements;
@@ -337,7 +337,9 @@ static Value getPHISourceValue(Block *current, Block *pred,
     return condBranchOp.getSuccessor(0) == current
                ? condBranchOp.trueDestOperands()[index]
                : condBranchOp.falseDestOperands()[index];
-  } else if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
+  }
+
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
     // For switches, we take the operands from either the default case, or from
     // the case branch that was taken.
     if (switchOp.defaultDestination() == current)
@@ -353,15 +355,12 @@ static Value getPHISourceValue(Block *current, Block *pred,
 
 /// Connect the PHI nodes to the results of preceding blocks.
 template <typename T>
-static void connectPHINodes(
-    T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
-    const DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
-    const DenseMap<Operation *, llvm::Instruction *> &branchMapping) {
+static void connectPHINodes(T &func, const ModuleTranslation &state) {
   // Skip the first block, it cannot be branched to and its arguments correspond
   // to the arguments of the LLVM function.
   for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
     Block *bb = &*it;
-    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
+    llvm::BasicBlock *llvmBB = state.lookupBlock(bb);
     auto phis = llvmBB->phis();
     auto numArguments = bb->getNumArguments();
     assert(numArguments == std::distance(phis.begin(), phis.end()));
@@ -371,15 +370,15 @@ static void connectPHINodes(
       for (auto *pred : bb->getPredecessors()) {
         // Find the LLVM IR block that contains the converted terminator
         // instruction and use it in the PHI node. Note that this block is not
-        // necessarily the same as blockMapping.lookup(pred), some operations
+        // necessarily the same as state.lookupBlock(pred), some operations
         // (in particular, OpenMP operations using OpenMPIRBuilder) may have
         // split the blocks.
         llvm::Instruction *terminator =
-            branchMapping.lookup(pred->getTerminator());
+            state.lookupBranch(pred->getTerminator());
         assert(terminator && "missing the mapping for a terminator");
-        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
-                                bb, pred, numArguments, index)),
-                            terminator->getParent());
+        phiNode.addIncoming(
+            state.lookupValue(getPHISourceValue(bb, pred, numArguments, index)),
+            terminator->getParent());
       }
     }
   }
@@ -415,9 +414,8 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
                        llvm::BasicBlock &continuationBlock) {
     // ParallelOp has only one region associated with it.
     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
-    convertOmpOpRegions(region, "omp.par.region", valueMapping, blockMapping,
-                        *codeGenIP.getBlock(), continuationBlock, builder,
-                        bodyGenStatus);
+    convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
+                        continuationBlock, builder, bodyGenStatus);
   };
 
   // TODO: Perform appropriate actions according to the data-sharing
@@ -437,10 +435,10 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
 
   llvm::Value *ifCond = nullptr;
   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
-    ifCond = valueMapping.lookup(ifExprVar);
+    ifCond = lookupValue(ifExprVar);
   llvm::Value *numThreads = nullptr;
   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
-    numThreads = valueMapping.lookup(numThreadsVar);
+    numThreads = lookupValue(numThreadsVar);
   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
     pbKind = llvm::omp::getProcBindKind(bind.getValue());
@@ -460,15 +458,13 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
 
 void ModuleTranslation::convertOmpOpRegions(
     Region &region, StringRef blockName,
-    DenseMap<Value, llvm::Value *> &valueMapping,
-    DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
     llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock,
     llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) {
   llvm::LLVMContext &llvmContext = builder.getContext();
   for (Block &bb : region) {
     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
         llvmContext, blockName, builder.GetInsertBlock()->getParent());
-    blockMapping[&bb] = llvmBB;
+    mapBlock(&bb, llvmBB);
   }
 
   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
@@ -477,7 +473,7 @@ void ModuleTranslation::convertOmpOpRegions(
   // defs are converted before uses.
   llvm::SetVector<Block *> blocks = topologicalSort(region);
   for (Block *bb : blocks) {
-    llvm::BasicBlock *llvmBB = blockMapping[bb];
+    llvm::BasicBlock *llvmBB = lookupBlock(bb);
     // Retarget the branch of the entry block to the entry block of the
     // converted region (regions are single-entry).
     if (bb->isEntryBlock()) {
@@ -506,7 +502,7 @@ void ModuleTranslation::convertOmpOpRegions(
   }
   // Finally, after all blocks have been traversed and values mapped,
   // connect the PHI nodes to the results of preceding blocks.
-  connectPHINodes(region, valueMapping, blockMapping, branchMapping);
+  connectPHINodes(region, *this);
 }
 
 LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
@@ -520,9 +516,8 @@ LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
                        llvm::BasicBlock &continuationBlock) {
     // MasterOp has only one region associated with it.
     auto &region = cast<omp::MasterOp>(opInst).getRegion();
-    convertOmpOpRegions(region, "omp.master.region", valueMapping, blockMapping,
-                        *codeGenIP.getBlock(), continuationBlock, builder,
-                        bodyGenStatus);
+    convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
+                        continuationBlock, builder, bodyGenStatus);
   };
 
   // TODO: Perform finalization actions for variables. This has to be
@@ -551,12 +546,12 @@ LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
         "only static (default) loop schedule is currently supported");
 
   // Find the loop configuration.
-  llvm::Value *lowerBound = valueMapping.lookup(loop.lowerBound()[0]);
-  llvm::Value *upperBound = valueMapping.lookup(loop.upperBound()[0]);
-  llvm::Value *step = valueMapping.lookup(loop.step()[0]);
+  llvm::Value *lowerBound = lookupValue(loop.lowerBound()[0]);
+  llvm::Value *upperBound = lookupValue(loop.upperBound()[0]);
+  llvm::Value *step = lookupValue(loop.step()[0]);
   llvm::Type *ivType = step->getType();
   llvm::Value *chunk = loop.schedule_chunk_var()
-                           ? valueMapping[loop.schedule_chunk_var()]
+                           ? lookupValue(loop.schedule_chunk_var())
                            : llvm::ConstantInt::get(ivType, 1);
 
   // Set up the source location value for OpenMP runtime.
@@ -576,16 +571,15 @@ LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
     llvm::IRBuilder<>::InsertPointGuard guard(builder);
 
     // Make sure further conversions know about the induction variable.
-    valueMapping[loop.getRegion().front().getArgument(0)] = iv;
+    mapValue(loop.getRegion().front().getArgument(0), iv);
 
     llvm::BasicBlock *entryBlock = ip.getBlock();
     llvm::BasicBlock *exitBlock =
         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
 
     // Convert the body of the loop.
-    convertOmpOpRegions(loop.region(), "omp.wsloop.region", valueMapping,
-                        blockMapping, *entryBlock, *exitBlock, builder,
-                        bodyGenStatus);
+    convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
+                        *exitBlock, builder, bodyGenStatus);
   };
 
   // Delegate actual loop construction to the OpenMP IRBuilder.
@@ -715,17 +709,14 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
     auto operands = lookupValues(op.getOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
-    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee")) {
-      return builder.CreateCall(functionMapping.lookup(attr.getValue()),
-                                operandsRef);
-    } else {
-      auto *calleePtrType =
-          cast<llvm::PointerType>(operandsRef.front()->getType());
-      auto *calleeType =
-          cast<llvm::FunctionType>(calleePtrType->getElementType());
-      return builder.CreateCall(calleeType, operandsRef.front(),
-                                operandsRef.drop_front());
-    }
+    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
+      return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef);
+    auto *calleePtrType =
+        cast<llvm::PointerType>(operandsRef.front()->getType());
+    auto *calleeType =
+        cast<llvm::FunctionType>(calleePtrType->getElementType());
+    return builder.CreateCall(calleeType, operandsRef.front(),
+                              operandsRef.drop_front());
   };
 
   // Emit calls.  If the called function has a result, remap the corresponding
@@ -733,7 +724,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   if (isa<LLVM::CallOp>(opInst)) {
     llvm::Value *result = convertCall(opInst);
     if (opInst.getNumResults() != 0) {
-      valueMapping[opInst.getResult(0)] = result;
+      mapValue(opInst.getResult(0), result);
       return success();
     }
     // Check that LLVM call returns void for 0-result functions.
@@ -770,7 +761,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     llvm::Value *result =
         builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands()));
     if (opInst.getNumResults() != 0)
-      valueMapping[opInst.getResult(0)] = result;
+      mapValue(opInst.getResult(0), result);
     return success();
   }
 
@@ -778,17 +769,17 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     auto operands = lookupValues(opInst.getOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
-      builder.CreateInvoke(functionMapping.lookup(attr.getValue()),
-                           blockMapping[invOp.getSuccessor(0)],
-                           blockMapping[invOp.getSuccessor(1)], operandsRef);
+      builder.CreateInvoke(lookupFunction(attr.getValue()),
+                           lookupBlock(invOp.getSuccessor(0)),
+                           lookupBlock(invOp.getSuccessor(1)), operandsRef);
     } else {
       auto *calleePtrType =
           cast<llvm::PointerType>(operandsRef.front()->getType());
       auto *calleeType =
           cast<llvm::FunctionType>(calleePtrType->getElementType());
       builder.CreateInvoke(
-          calleeType, operandsRef.front(), blockMapping[invOp.getSuccessor(0)],
-          blockMapping[invOp.getSuccessor(1)], operandsRef.drop_front());
+          calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)),
+          lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front());
     }
     return success();
   }
@@ -799,12 +790,12 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
         builder.CreateLandingPad(ty, lpOp.getNumOperands());
 
     // Add clauses
-    for (auto operand : lookupValues(lpOp.getOperands())) {
+    for (llvm::Value *operand : lookupValues(lpOp.getOperands())) {
       // All operands should be constant - checked by verifier
-      if (auto constOperand = dyn_cast<llvm::Constant>(operand))
+      if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
         lpi->addClause(constOperand);
     }
-    valueMapping[lpOp.getResult()] = lpi;
+    mapValue(lpOp.getResult(), lpi);
     return success();
   }
 
@@ -812,8 +803,8 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   // arguments that were transformed into PHI nodes.
   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
     llvm::BranchInst *branch =
-        builder.CreateBr(blockMapping[brOp.getSuccessor()]);
-    branchMapping.try_emplace(&opInst, branch);
+        builder.CreateBr(lookupBlock(brOp.getSuccessor()));
+    mapBranch(&opInst, branch);
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
@@ -831,10 +822,10 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
                                    static_cast<uint32_t>(falseWeight));
     }
     llvm::BranchInst *branch = builder.CreateCondBr(
-        valueMapping.lookup(condbrOp.getOperand(0)),
-        blockMapping[condbrOp.getSuccessor(0)],
-        blockMapping[condbrOp.getSuccessor(1)], branchWeights);
-    branchMapping.try_emplace(&opInst, branch);
+        lookupValue(condbrOp.getOperand(0)),
+        lookupBlock(condbrOp.getSuccessor(0)),
+        lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
+    mapBranch(&opInst, branch);
     return success();
   }
   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
@@ -849,8 +840,8 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     }
 
     llvm::SwitchInst *switchInst =
-        builder.CreateSwitch(valueMapping[switchOp.value()],
-                             blockMapping[switchOp.defaultDestination()],
+        builder.CreateSwitch(lookupValue(switchOp.value()),
+                             lookupBlock(switchOp.defaultDestination()),
                              switchOp.caseDestinations().size(), branchWeights);
 
     auto *ty =
@@ -860,9 +851,9 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
                    switchOp.caseDestinations()))
       switchInst->addCase(
           llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
-          blockMapping[std::get<1>(i)]);
+          lookupBlock(std::get<1>(i)));
 
-    branchMapping.try_emplace(&opInst, switchInst);
+    mapBranch(&opInst, switchInst);
     return success();
   }
 
@@ -877,9 +868,9 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     assert((global || function) &&
            "referencing an undefined global or function");
 
-    valueMapping[addressOfOp.getResult()] =
-        global ? globalsMapping.lookup(global)
-               : functionMapping.lookup(function.getName());
+    mapValue(addressOfOp.getResult(), global
+                                          ? globalsMapping.lookup(global)
+                                          : lookupFunction(function.getName()));
     return success();
   }
 
@@ -899,7 +890,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
 /// suitable for further insertion into the end of the block.
 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
                                               llvm::IRBuilder<> &builder) {
-  builder.SetInsertPoint(blockMapping[&bb]);
+  builder.SetInsertPoint(lookupBlock(&bb));
   auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
 
   // Before traversing operations, make block arguments available through
@@ -919,7 +910,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
                          "block argument does not have an LLVM type");
       llvm::Type *type = convertType(wrappedType);
       llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
-      valueMapping[arg] = phi;
+      mapValue(arg, phi);
     }
   }
 
@@ -957,11 +948,11 @@ LogicalResult ModuleTranslation::convertGlobals() {
       llvm::IRBuilder<> builder(llvmModule->getContext());
       for (auto &op : initializer->without_terminator()) {
         if (failed(convertOperation(op, builder)) ||
-            !isa<llvm::Constant>(valueMapping.lookup(op.getResult(0))))
+            !isa<llvm::Constant>(lookupValue(op.getResult(0))))
           return emitError(op.getLoc(), "unemittable constant value");
       }
       ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
-      cst = cast<llvm::Constant>(valueMapping.lookup(ret.getOperand(0)));
+      cst = cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
     }
 
     auto linkage = convertLinkageToLLVM(op.linkage());
@@ -1064,7 +1055,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   blockMapping.clear();
   valueMapping.clear();
   branchMapping.clear();
-  llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
+  llvm::Function *llvmFunc = lookupFunction(func.getName());
 
   // Translate the debug information for this function.
   debugTranslation->translate(func, *llvmFunc);
@@ -1118,7 +1109,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
           llvmArg.getType()->getPointerElementType()));
     }
 
-    valueMapping[mlirArg] = &llvmArg;
+    mapValue(mlirArg, &llvmArg);
     argIdx++;
   }
 
@@ -1135,7 +1126,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   for (auto &bb : func) {
     auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
     llvmBB->insertInto(llvmFunc);
-    blockMapping[&bb] = llvmBB;
+    mapBlock(&bb, llvmBB);
   }
 
   // Then, convert blocks one by one in topological order to ensure defs are
@@ -1149,7 +1140,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Finally, after all blocks have been traversed and values mapped, connect
   // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func, valueMapping, blockMapping, branchMapping);
+  connectPHINodes(func, *this);
   return success();
 }
 
@@ -1170,7 +1161,7 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
         cast<llvm::FunctionType>(convertType(function.getType())));
     llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
     llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage()));
-    functionMapping[function.getName()] = llvmFunc;
+    mapFunction(function.getName(), llvmFunc);
 
     // Forward the pass-through attributes to LLVM.
     if (failed(forwardPassthroughAttributes(function.getLoc(),
@@ -1204,10 +1195,8 @@ SmallVector<llvm::Value *, 8>
 ModuleTranslation::lookupValues(ValueRange values) {
   SmallVector<llvm::Value *, 8> remapped;
   remapped.reserve(values.size());
-  for (Value v : values) {
-    assert(valueMapping.count(v) && "referencing undefined value");
-    remapped.push_back(valueMapping.lookup(v));
-  }
+  for (Value v : values)
+    remapped.push_back(lookupValue(v));
   return remapped;
 }
 

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 428b61589134..8384bbffe26c 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -126,14 +126,13 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
     // Then, rewrite the name based on its kind.
     bool isVariadicOperand = isVariadicOperandName(op, name);
     if (isOperandName(op, name)) {
-      auto result = isVariadicOperand
-                        ? formatv("lookupValues(op.{0}())", name)
-                        : formatv("valueMapping.lookup(op.{0}())", name);
+      auto result = isVariadicOperand ? formatv("lookupValues(op.{0}())", name)
+                                      : formatv("lookupValue(op.{0}())", name);
       bs << result;
     } else if (isAttributeName(op, name)) {
       bs << formatv("op.{0}()", name);
     } else if (isResultName(op, name)) {
-      bs << formatv("valueMapping[op.{0}()]", name);
+      bs << formatv("mapValue(op.{0}())", name);
     } else if (name == "_resultType") {
       bs << "convertType(op.getResult().getType())";
     } else if (name == "_hasResult") {


        


More information about the Mlir-commits mailing list