[Mlir-commits] [mlir] 154cabe - [mlir][pdl][NFC] Extract the execution of each bytecode operation into its own function
River Riddle
llvmlistbot at llvm.org
Mon Feb 22 19:07:51 PST 2021
Author: River Riddle
Date: 2021-02-22T19:02:48-08:00
New Revision: 154cabe722de4d9837d49790e913d2b511f17d70
URL: https://github.com/llvm/llvm-project/commit/154cabe722de4d9837d49790e913d2b511f17d70
DIFF: https://github.com/llvm/llvm-project/commit/154cabe722de4d9837d49790e913d2b511f17d70.diff
LOG: [mlir][pdl][NFC] Extract the execution of each bytecode operation into its own function
This makes the implementation of each bytecode operation much easier to reason about, and lets the compiler decide which implementations are beneficial to inline into the main switch.
Differential Revision: https://reviews.llvm.org/D95716
Added:
Modified:
mlir/lib/Rewrite/ByteCode.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 11a9db79b322..04b965869fb7 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -732,6 +732,34 @@ class ByteCodeExecutor {
Optional<Location> mainRewriteLoc = {});
private:
+ /// Internal implementation of executing each of the bytecode commands.
+ void executeApplyConstraint(PatternRewriter &rewriter);
+ void executeApplyRewrite(PatternRewriter &rewriter);
+ void executeAreEqual();
+ void executeBranch();
+ void executeCheckOperandCount();
+ void executeCheckOperationName();
+ void executeCheckResultCount();
+ void executeCreateNative(PatternRewriter &rewriter);
+ void executeCreateOperation(PatternRewriter &rewriter,
+ Location mainRewriteLoc);
+ void executeEraseOp(PatternRewriter &rewriter);
+ void executeGetAttribute();
+ void executeGetAttributeType();
+ void executeGetDefiningOp();
+ void executeGetOperand(unsigned index);
+ void executeGetResult(unsigned index);
+ void executeGetValueType();
+ void executeIsNotNull();
+ void executeRecordMatch(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> &matches);
+ void executeReplaceOp(PatternRewriter &rewriter);
+ void executeSwitchAttribute();
+ void executeSwitchOperandCount();
+ void executeSwitchOperationName();
+ void executeSwitchResultCount();
+ void executeSwitchType();
+
/// Read a value from the bytecode buffer, optionally skipping a certain
/// number of prefix values. These methods always update the buffer to point
/// to the next field after the read data.
@@ -764,7 +792,7 @@ class ByteCodeExecutor {
llvm::dbgs() << " * Value: " << value << "\n"
<< " * Cases: ";
llvm::interleaveComma(cases, llvm::dbgs());
- llvm::dbgs() << "\n\n";
+ llvm::dbgs() << "\n";
});
// Check to see if the attribute value is within the case list. Jump to
@@ -843,6 +871,353 @@ class ByteCodeExecutor {
};
} // end anonymous namespace
+/// Executes an ApplyConstraint command: reads the constraint function index,
+/// its constant parameters and its argument list, invokes the native
+/// constraint, and selects the next jump target from whether it succeeded.
+void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
+  unsigned fnIndex = read();
+  const PDLConstraintFunction &constraint = constraintFunctions[fnIndex];
+  ArrayAttr params = read<ArrayAttr>();
+  SmallVector<PDLValue, 16> operands;
+  readList<PDLValue>(operands);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << " * Arguments: ";
+    llvm::interleaveComma(operands, llvm::dbgs());
+    llvm::dbgs() << "\n * Parameters: " << params << "\n";
+  });
+
+  // Branch on the outcome of the user-provided constraint.
+  LogicalResult outcome = constraint(operands, params, rewriter);
+  selectJump(succeeded(outcome));
+}
+
+/// Executes an ApplyRewrite command: reads the rewrite function index, its
+/// constant parameters, the root operation and the argument list, then
+/// invokes the native rewrite function.
+void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
+  unsigned fnIndex = read();
+  const PDLRewriteFunction &nativeRewrite = rewriteFunctions[fnIndex];
+  ArrayAttr params = read<ArrayAttr>();
+  Operation *rootOp = read<Operation *>();
+  SmallVector<PDLValue, 16> operands;
+  readList<PDLValue>(operands);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << " * Root: " << *rootOp << "\n * Arguments: ";
+    llvm::interleaveComma(operands, llvm::dbgs());
+    llvm::dbgs() << "\n * Parameters: " << params << "\n";
+  });
+
+  nativeRewrite(rootOp, operands, params, rewriter);
+}
+
+/// Executes an AreEqual command: compares two opaque memory values by
+/// identity and selects the next jump target from the result.
+void ByteCodeExecutor::executeAreEqual() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+  // Each read advances the bytecode cursor, so the two operands are read in
+  // separate statements to keep their order well-defined.
+  const void *first = read<const void *>();
+  const void *second = read<const void *>();
+  bool identical = first == second;
+
+  LLVM_DEBUG(llvm::dbgs() << " * " << first << " == " << second << "\n");
+  selectJump(identical);
+}
+
+/// Executes an unconditional Branch: moves the bytecode cursor to the address
+/// encoded in the command.
+void ByteCodeExecutor::executeBranch() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
+  ByteCodeAddr target = read<ByteCodeAddr>();
+  curCodeIt = &code[target];
+}
+
+/// Executes a CheckOperandCount command: branches on whether the operation
+/// read from the bytecode has exactly the expected number of operands.
+void ByteCodeExecutor::executeCheckOperandCount() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
+  Operation *op = read<Operation *>();
+  uint32_t expected = read<uint32_t>();
+  unsigned actual = op->getNumOperands();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Found: " << actual << "\n"
+                          << " * Expected: " << expected << "\n");
+  selectJump(actual == expected);
+}
+
+/// Executes a CheckOperationName command: branches on whether the operation
+/// read from the bytecode has the expected name.
+void ByteCodeExecutor::executeCheckOperationName() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
+  Operation *op = read<Operation *>();
+  OperationName expected = read<OperationName>();
+  OperationName actual = op->getName();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << actual << "\"\n"
+                          << " * Expected: \"" << expected << "\"\n");
+  selectJump(actual == expected);
+}
+
+/// Executes a CheckResultCount command: branches on whether the operation
+/// read from the bytecode has exactly the expected number of results.
+void ByteCodeExecutor::executeCheckResultCount() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
+  Operation *op = read<Operation *>();
+  uint32_t expected = read<uint32_t>();
+  unsigned actual = op->getNumResults();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Found: " << actual << "\n"
+                          << " * Expected: " << expected << "\n");
+  selectJump(actual == expected);
+}
+
+/// Executes a CreateNative command: invokes a native creation function with
+/// the constant parameters and arguments read from the bytecode, and stores
+/// the produced value in the memory slot named by the command.
+void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
+  unsigned fnIndex = read();
+  const PDLCreateFunction &nativeCreate = createFunctions[fnIndex];
+  ByteCodeField destIndex = read();
+  ArrayAttr params = read<ArrayAttr>();
+  SmallVector<PDLValue, 16> operands;
+  readList<PDLValue>(operands);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << " * Arguments: ";
+    llvm::interleaveComma(operands, llvm::dbgs());
+    llvm::dbgs() << "\n * Parameters: " << params << "\n";
+  });
+
+  PDLValue created = nativeCreate(operands, params, rewriter);
+  memory[destIndex] = created.getAsOpaquePointer();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Result: " << created << "\n");
+}
+
+/// Executes a CreateOperation command. The bytecode encodes, in order: the
+/// destination memory slot, the operation name, the operand list, a list of
+/// (name, attribute) pairs, and a list of result types. Null attributes are
+/// not added to the state; null result types are filled in through the
+/// operation's InferTypeOpInterface. The created operation is stored in the
+/// destination memory slot.
+void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
+                                              Location mainRewriteLoc) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
+
+  unsigned memIndex = read();
+  OperationState state(mainRewriteLoc, read<OperationName>());
+  readList<Value>(state.operands);
+  for (unsigned i = 0, e = read(); i != e; ++i) {
+    Identifier name = read<Identifier>();
+    if (Attribute attr = read<Attribute>())
+      state.addAttribute(name, attr);
+  }
+
+  // A null result type marks a type that has to be inferred.
+  bool hasInferredTypes = false;
+  for (unsigned i = 0, e = read(); i != e; ++i) {
+    Type resultType = read<Type>();
+    hasInferredTypes |= !resultType;
+    state.types.push_back(resultType);
+  }
+
+  // Handle the case where the operation has inferred types.
+  if (hasInferredTypes) {
+    // Note: `concept` is a reserved keyword in C++20, so it must not be used
+    // as an identifier here.
+    const AbstractOperation *abstractOp = state.name.getAbstractOperation();
+    assert(abstractOp && "expected a registered operation when result types "
+                         "must be inferred");
+    InferTypeOpInterface::Concept *inferInterface =
+        abstractOp->getInterface<InferTypeOpInterface>();
+    assert(inferInterface && "expected operation with inferred result types "
+                             "to implement InferTypeOpInterface");
+
+    // TODO: Handle failure.
+    // NOTE(review): this `return` only exits this helper. The caller keeps
+    // executing bytecode, and memory[memIndex] is left unwritten. Consider
+    // returning a LogicalResult so the caller can stop execution instead.
+    SmallVector<Type, 2> inferredTypes;
+    if (failed(inferInterface->inferReturnTypes(
+            state.getContext(), state.location, state.operands,
+            state.attributes.getDictionary(state.getContext()), state.regions,
+            inferredTypes)))
+      return;
+
+    for (unsigned i = 0, e = state.types.size(); i != e; ++i)
+      if (!state.types[i])
+        state.types[i] = inferredTypes[i];
+  }
+  Operation *resultOp = rewriter.createOperation(state);
+  memory[memIndex] = resultOp;
+
+  LLVM_DEBUG({
+    llvm::dbgs() << " * Attributes: "
+                 << state.attributes.getDictionary(state.getContext())
+                 << "\n * Operands: ";
+    llvm::interleaveComma(state.operands, llvm::dbgs());
+    llvm::dbgs() << "\n * Result Types: ";
+    llvm::interleaveComma(state.types, llvm::dbgs());
+    llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
+  });
+}
+
+/// Executes an EraseOp command: erases the operation read from the bytecode
+/// through the rewriter.
+void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
+  Operation *doomed = read<Operation *>();
+  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *doomed << "\n");
+  rewriter.eraseOp(doomed);
+}
+
+/// Executes a GetAttribute command: looks up the named attribute on the
+/// operation read from the bytecode and stores it (possibly null) in the
+/// destination memory slot.
+void ByteCodeExecutor::executeGetAttribute() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
+  unsigned slot = read();
+  Operation *owner = read<Operation *>();
+  Identifier name = read<Identifier>();
+  Attribute found = owner->getAttr(name);
+
+  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *owner << "\n"
+                          << " * Attribute: " << name << "\n"
+                          << " * Result: " << found << "\n");
+  memory[slot] = found.getAsOpaquePointer();
+}
+
+/// Executes a GetAttributeType command: stores the type of the attribute read
+/// from the bytecode, or a null type when the attribute itself is null.
+void ByteCodeExecutor::executeGetAttributeType() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
+  unsigned slot = read();
+  Attribute attr = read<Attribute>();
+  Type attrType;
+  if (attr)
+    attrType = attr.getType();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
+                          << " * Result: " << attrType << "\n");
+  memory[slot] = attrType.getAsOpaquePointer();
+}
+
+/// Executes a GetDefiningOp command: stores the operation that defines the
+/// Value read from the bytecode in the destination memory slot. Null is
+/// stored when the value is null or has no defining operation.
+void ByteCodeExecutor::executeGetDefiningOp() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
+  unsigned memIndex = read();
+  Value value = read<Value>();
+  Operation *op = value ? value.getDefiningOp() : nullptr;
+
+  // `op` is legitimately null here, so it must not be dereferenced
+  // unconditionally when printing.
+  LLVM_DEBUG({
+    llvm::dbgs() << " * Value: " << value << "\n"
+                 << " * Result: ";
+    if (op)
+      llvm::dbgs() << *op;
+    else
+      llvm::dbgs() << "<null>";
+    llvm::dbgs() << "\n";
+  });
+  memory[memIndex] = op;
+}
+
+/// Executes a GetOperand command for operand `index`: stores that operand of
+/// the operation read from the bytecode, or a null Value when the operation
+/// has too few operands.
+void ByteCodeExecutor::executeGetOperand(unsigned index) {
+  Operation *op = read<Operation *>();
+  unsigned slot = read();
+  Value operand;
+  if (index < op->getNumOperands())
+    operand = op->getOperand(index);
+
+  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
+                          << " * Index: " << index << "\n"
+                          << " * Result: " << operand << "\n");
+  memory[slot] = operand.getAsOpaquePointer();
+}
+
+/// Executes a GetResult command for result `index`: stores that result of the
+/// operation read from the bytecode, or a null OpResult when the operation
+/// has too few results.
+void ByteCodeExecutor::executeGetResult(unsigned index) {
+  Operation *op = read<Operation *>();
+  unsigned slot = read();
+  OpResult result;
+  if (index < op->getNumResults())
+    result = op->getResult(index);
+
+  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
+                          << " * Index: " << index << "\n"
+                          << " * Result: " << result << "\n");
+  memory[slot] = result.getAsOpaquePointer();
+}
+
+/// Executes a GetValueType command: stores the type of the Value read from
+/// the bytecode, or a null type when the value is null.
+void ByteCodeExecutor::executeGetValueType() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
+  unsigned slot = read();
+  Value value = read<Value>();
+  Type valueType;
+  if (value)
+    valueType = value.getType();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
+                          << " * Result: " << valueType << "\n");
+  memory[slot] = valueType.getAsOpaquePointer();
+}
+
+/// Executes an IsNotNull command: branches on whether the opaque memory value
+/// read from the bytecode is non-null.
+void ByteCodeExecutor::executeIsNotNull() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
+  const void *candidate = read<const void *>();
+  bool present = candidate != nullptr;
+
+  LLVM_DEBUG(llvm::dbgs() << " * Value: " << candidate << "\n");
+  selectJump(present);
+}
+
+/// Executes a RecordMatch command. Unless the pattern's current benefit marks
+/// it as impossible to match, this appends a MatchResult (fused location,
+/// pattern, benefit and captured values) to `matches`. In both cases execution
+/// continues at the address encoded after the pattern index.
+void ByteCodeExecutor::executeRecordMatch(
+    PatternRewriter &rewriter,
+    SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
+  unsigned patternIndex = read();
+  PatternBenefit benefit = currentPatternBenefits[patternIndex];
+  const ByteCodeField *resumeAt = &code[read<ByteCodeAddr>()];
+
+  if (benefit.isImpossibleToMatch()) {
+    // Skip the rest of this pattern's processing by jumping straight to
+    // `resumeAt` below.
+    LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
+  } else {
+    // Fuse the locations of every operation used in the match. The fused
+    // location is used for operations created during the rewrite that do not
+    // already have an explicit location set.
+    unsigned locCount = read();
+    SmallVector<Location, 4> opLocs;
+    opLocs.reserve(locCount);
+    for (unsigned n = 0; n != locCount; ++n)
+      opLocs.push_back(read<Operation *>()->getLoc());
+    Location fusedLoc = rewriter.getFusedLoc(opLocs);
+
+    LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
+                            << " * Location: " << fusedLoc << "\n");
+    matches.emplace_back(fusedLoc, patterns[patternIndex], benefit);
+    readList<const void *>(matches.back().values);
+  }
+  curCodeIt = resumeAt;
+}
+
+/// Executes a ReplaceOp command: replaces the operation read from the
+/// bytecode with the list of values that follows it.
+void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
+  Operation *target = read<Operation *>();
+  SmallVector<Value, 16> replacements;
+  readList<Value>(replacements);
+
+  LLVM_DEBUG({
+    llvm::dbgs() << " * Operation: " << *target << "\n"
+                 << " * Values: ";
+    llvm::interleaveComma(replacements, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+  rewriter.replaceOp(target, replacements);
+}
+
+/// Executes a SwitchAttribute command: dispatches on the attribute read from
+/// the bytecode against the array of case attributes that follows it.
+void ByteCodeExecutor::executeSwitchAttribute() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
+  // Read in separate statements; the read order is significant.
+  Attribute attr = read<Attribute>();
+  ArrayAttr caseAttrs = read<ArrayAttr>();
+  handleSwitch(attr, caseAttrs);
+}
+
+/// Executes a SwitchOperandCount command: dispatches on the operand count of
+/// the operation read from the bytecode against a dense array of counts.
+void ByteCodeExecutor::executeSwitchOperandCount() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
+  Operation *op = read<Operation *>();
+  DenseIntOrFPElementsAttr caseAttr = read<DenseIntOrFPElementsAttr>();
+  auto cases = caseAttr.getValues<uint32_t>();
+  unsigned numOperands = op->getNumOperands();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+  handleSwitch(numOperands, cases);
+}
+
+/// Executes a SwitchOperationName command. The case names are stored inline
+/// in the bytecode. If case `i` matches the operation's name, jump target
+/// `i + 1` is selected and the remaining case names are skipped. If no case
+/// matches, target 0 is selected.
+void ByteCodeExecutor::executeSwitchOperationName() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
+  Operation *op = read<Operation *>();
+  OperationName value = op->getName();
+  size_t caseCount = read();
+
+  // Printing the inline case names consumes them, so the cursor is saved and
+  // restored around the debug dump.
+  LLVM_DEBUG({
+    const ByteCodeField *savedCodeIt = curCodeIt;
+    llvm::dbgs() << " * Value: " << value << "\n"
+                 << " * Cases: ";
+    llvm::interleaveComma(
+        llvm::map_range(llvm::seq<size_t>(0, caseCount),
+                        [&](size_t) { return read<OperationName>(); }),
+        llvm::dbgs());
+    llvm::dbgs() << "\n";
+    curCodeIt = savedCodeIt;
+  });
+
+  size_t jumpDest = 0;
+  for (size_t i = 0; i != caseCount; ++i) {
+    if (read<OperationName>() == value) {
+      // Step over the case names that were not read.
+      curCodeIt += (caseCount - i - 1);
+      jumpDest = i + 1;
+      break;
+    }
+  }
+  selectJump(jumpDest);
+}
+
+/// Executes a SwitchResultCount command: dispatches on the result count of
+/// the operation read from the bytecode against a dense array of counts.
+void ByteCodeExecutor::executeSwitchResultCount() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
+  Operation *op = read<Operation *>();
+  DenseIntOrFPElementsAttr caseAttr = read<DenseIntOrFPElementsAttr>();
+  auto cases = caseAttr.getValues<uint32_t>();
+  unsigned numResults = op->getNumResults();
+
+  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+  handleSwitch(numResults, cases);
+}
+
+/// Executes a SwitchType command: dispatches on the type read from the
+/// bytecode against the types held by an array of TypeAttr cases.
+void ByteCodeExecutor::executeSwitchType() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
+  Type type = read<Type>();
+  ArrayAttr caseAttr = read<ArrayAttr>();
+  auto caseTypes = caseAttr.getAsValueRange<TypeAttr>();
+  handleSwitch(type, caseTypes);
+}
+
void ByteCodeExecutor::execute(
PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> *matches,
@@ -850,383 +1225,105 @@ void ByteCodeExecutor::execute(
while (true) {
OpCode opCode = static_cast<OpCode>(read());
switch (opCode) {
- case ApplyConstraint: {
- LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
- const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
- ArrayAttr constParams = read<ArrayAttr>();
- SmallVector<PDLValue, 16> args;
- readList<PDLValue>(args);
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
- });
-
- // Invoke the constraint and jump to the proper destination.
- selectJump(succeeded(constraintFn(args, constParams, rewriter)));
+ case ApplyConstraint:
+ executeApplyConstraint(rewriter);
break;
- }
- case ApplyRewrite: {
- LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
- const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
- ArrayAttr constParams = read<ArrayAttr>();
- Operation *root = read<Operation *>();
- SmallVector<PDLValue, 16> args;
- readList<PDLValue>(args);
-
- LLVM_DEBUG({
- llvm::dbgs() << " * Root: " << *root << "\n"
- << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
- });
- rewriteFn(root, args, constParams, rewriter);
+ case ApplyRewrite:
+ executeApplyRewrite(rewriter);
break;
- }
- case AreEqual: {
- LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
- const void *lhs = read<const void *>();
- const void *rhs = read<const void *>();
-
- LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
- selectJump(lhs == rhs);
+ case AreEqual:
+ executeAreEqual();
break;
- }
- case Branch: {
- LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
- curCodeIt = &code[read<ByteCodeAddr>()];
+ case Branch:
+ executeBranch();
break;
- }
- case CheckOperandCount: {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
- Operation *op = read<Operation *>();
- uint32_t expectedCount = read<uint32_t>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
- << " * Expected: " << expectedCount << "\n\n");
- selectJump(op->getNumOperands() == expectedCount);
+ case CheckOperandCount:
+ executeCheckOperandCount();
break;
- }
- case CheckOperationName: {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
- Operation *op = read<Operation *>();
- OperationName expectedName = read<OperationName>();
-
- LLVM_DEBUG(llvm::dbgs()
- << " * Found: \"" << op->getName() << "\"\n"
- << " * Expected: \"" << expectedName << "\"\n\n");
- selectJump(op->getName() == expectedName);
+ case CheckOperationName:
+ executeCheckOperationName();
break;
- }
- case CheckResultCount: {
- LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
- Operation *op = read<Operation *>();
- uint32_t expectedCount = read<uint32_t>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
- << " * Expected: " << expectedCount << "\n\n");
- selectJump(op->getNumResults() == expectedCount);
+ case CheckResultCount:
+ executeCheckResultCount();
break;
- }
- case CreateNative: {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
- const PDLCreateFunction &createFn = createFunctions[read()];
- ByteCodeField resultIndex = read();
- ArrayAttr constParams = read<ArrayAttr>();
- SmallVector<PDLValue, 16> args;
- readList<PDLValue>(args);
-
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
- });
-
- PDLValue result = createFn(args, constParams, rewriter);
- memory[resultIndex] = result.getAsOpaquePointer();
-
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n");
+ case CreateNative:
+ executeCreateNative(rewriter);
break;
- }
- case CreateOperation: {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
- assert(mainRewriteLoc && "expected rewrite loc to be provided when "
- "executing the rewriter bytecode");
-
- unsigned memIndex = read();
- OperationState state(*mainRewriteLoc, read<OperationName>());
- readList<Value>(state.operands);
- for (unsigned i = 0, e = read(); i != e; ++i) {
- Identifier name = read<Identifier>();
- if (Attribute attr = read<Attribute>())
- state.addAttribute(name, attr);
- }
-
- bool hasInferredTypes = false;
- for (unsigned i = 0, e = read(); i != e; ++i) {
- Type resultType = read<Type>();
- hasInferredTypes |= !resultType;
- state.types.push_back(resultType);
- }
-
- // Handle the case where the operation has inferred types.
- if (hasInferredTypes) {
- InferTypeOpInterface::Concept *concept =
- state.name.getAbstractOperation()
- ->getInterface<InferTypeOpInterface>();
-
- // TODO: Handle failure.
- SmallVector<Type, 2> inferredTypes;
- if (failed(concept->inferReturnTypes(
- state.getContext(), state.location, state.operands,
- state.attributes.getDictionary(state.getContext()),
- state.regions, inferredTypes)))
- return;
-
- for (unsigned i = 0, e = state.types.size(); i != e; ++i)
- if (!state.types[i])
- state.types[i] = inferredTypes[i];
- }
- Operation *resultOp = rewriter.createOperation(state);
- memory[memIndex] = resultOp;
-
- LLVM_DEBUG({
- llvm::dbgs() << " * Attributes: "
- << state.attributes.getDictionary(state.getContext())
- << "\n * Operands: ";
- llvm::interleaveComma(state.operands, llvm::dbgs());
- llvm::dbgs() << "\n * Result Types: ";
- llvm::interleaveComma(state.types, llvm::dbgs());
- llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n";
- });
+ case CreateOperation:
+ executeCreateOperation(rewriter, *mainRewriteLoc);
break;
- }
- case EraseOp: {
- LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
- Operation *op = read<Operation *>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n");
- rewriter.eraseOp(op);
+ case EraseOp:
+ executeEraseOp(rewriter);
break;
- }
- case Finalize: {
+ case Finalize:
LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
return;
- }
- case GetAttribute: {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
- unsigned memIndex = read();
- Operation *op = read<Operation *>();
- Identifier attrName = read<Identifier>();
- Attribute attr = op->getAttr(attrName);
-
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Attribute: " << attrName << "\n"
- << " * Result: " << attr << "\n\n");
- memory[memIndex] = attr.getAsOpaquePointer();
+ case GetAttribute:
+ executeGetAttribute();
break;
- }
- case GetAttributeType: {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
- unsigned memIndex = read();
- Attribute attr = read<Attribute>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
- << " * Result: " << attr.getType() << "\n\n");
- memory[memIndex] = attr.getType().getAsOpaquePointer();
+ case GetAttributeType:
+ executeGetAttributeType();
break;
- }
- case GetDefiningOp: {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
- unsigned memIndex = read();
- Value value = read<Value>();
- Operation *op = value ? value.getDefiningOp() : nullptr;
-
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
- << " * Result: " << *op << "\n\n");
- memory[memIndex] = op;
+ case GetDefiningOp:
+ executeGetDefiningOp();
break;
- }
case GetOperand0:
case GetOperand1:
case GetOperand2:
- case GetOperand3:
- case GetOperandN: {
- LLVM_DEBUG({
- llvm::dbgs() << "Executing GetOperand"
- << (opCode == GetOperandN ? Twine("N")
- : Twine(opCode - GetOperand0))
- << ":\n";
- });
- unsigned index =
- opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
- Operation *op = read<Operation *>();
- unsigned memIndex = read();
- Value operand =
- index < op->getNumOperands() ? op->getOperand(index) : Value();
-
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Index: " << index << "\n"
- << " * Result: " << operand << "\n\n");
- memory[memIndex] = operand.getAsOpaquePointer();
+ case GetOperand3: {
+ unsigned index = opCode - GetOperand0;
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
+ executeGetOperand(opCode - GetOperand0);
break;
}
+ case GetOperandN:
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
+ executeGetOperand(read<uint32_t>());
+ break;
case GetResult0:
case GetResult1:
case GetResult2:
- case GetResult3:
- case GetResultN: {
- LLVM_DEBUG({
- llvm::dbgs() << "Executing GetResult"
- << (opCode == GetResultN ? Twine("N")
- : Twine(opCode - GetResult0))
- << ":\n";
- });
- unsigned index =
- opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
- Operation *op = read<Operation *>();
- unsigned memIndex = read();
- OpResult result =
- index < op->getNumResults() ? op->getResult(index) : OpResult();
-
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Index: " << index << "\n"
- << " * Result: " << result << "\n\n");
- memory[memIndex] = result.getAsOpaquePointer();
+ case GetResult3: {
+ unsigned index = opCode - GetResult0;
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
+ executeGetResult(opCode - GetResult0);
break;
}
- case GetValueType: {
- LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
- unsigned memIndex = read();
- Value value = read<Value>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
- << " * Result: " << value.getType() << "\n\n");
- memory[memIndex] = value.getType().getAsOpaquePointer();
+ case GetResultN:
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
+ executeGetResult(read<uint32_t>());
break;
- }
- case IsNotNull: {
- LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
- const void *value = read<const void *>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n");
- selectJump(value != nullptr);
+ case GetValueType:
+ executeGetValueType();
break;
- }
- case RecordMatch: {
- LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
+ case IsNotNull:
+ executeIsNotNull();
+ break;
+ case RecordMatch:
assert(matches &&
"expected matches to be provided when executing the matcher");
- unsigned patternIndex = read();
- PatternBenefit benefit = currentPatternBenefits[patternIndex];
- const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
-
- // If the benefit of the pattern is impossible, skip the processing of the
- // rest of the pattern.
- if (benefit.isImpossibleToMatch()) {
- LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n");
- curCodeIt = dest;
- break;
- }
-
- // Create a fused location containing the locations of each of the
- // operations used in the match. This will be used as the location for
- // created operations during the rewrite that don't already have an
- // explicit location set.
- unsigned numMatchLocs = read();
- SmallVector<Location, 4> matchLocs;
- matchLocs.reserve(numMatchLocs);
- for (unsigned i = 0; i != numMatchLocs; ++i)
- matchLocs.push_back(read<Operation *>()->getLoc());
- Location matchLoc = rewriter.getFusedLoc(matchLocs);
-
- LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
- << " * Location: " << matchLoc << "\n\n");
- matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
- readList<const void *>(matches->back().values);
- curCodeIt = dest;
+ executeRecordMatch(rewriter, *matches);
break;
- }
- case ReplaceOp: {
- LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
- Operation *op = read<Operation *>();
- SmallVector<Value, 16> args;
- readList<Value>(args);
-
- LLVM_DEBUG({
- llvm::dbgs() << " * Operation: " << *op << "\n"
- << " * Values: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n\n";
- });
- rewriter.replaceOp(op, args);
+ case ReplaceOp:
+ executeReplaceOp(rewriter);
break;
- }
- case SwitchAttribute: {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
- Attribute value = read<Attribute>();
- ArrayAttr cases = read<ArrayAttr>();
- handleSwitch(value, cases);
+ case SwitchAttribute:
+ executeSwitchAttribute();
break;
- }
- case SwitchOperandCount: {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
- Operation *op = read<Operation *>();
- auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
- handleSwitch(op->getNumOperands(), cases);
+ case SwitchOperandCount:
+ executeSwitchOperandCount();
break;
- }
- case SwitchOperationName: {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
- OperationName value = read<Operation *>()->getName();
- size_t caseCount = read();
-
- // The operation names are stored in-line, so to print them out for
- // debugging purposes we need to read the array before executing the
- // switch so that we can display all of the possible values.
- LLVM_DEBUG({
- const ByteCodeField *prevCodeIt = curCodeIt;
- llvm::dbgs() << " * Value: " << value << "\n"
- << " * Cases: ";
- llvm::interleaveComma(
- llvm::map_range(llvm::seq<size_t>(0, caseCount),
- [&](size_t i) { return read<OperationName>(); }),
- llvm::dbgs());
- llvm::dbgs() << "\n\n";
- curCodeIt = prevCodeIt;
- });
-
- // Try to find the switch value within any of the cases.
- size_t jumpDest = 0;
- for (size_t i = 0; i != caseCount; ++i) {
- if (read<OperationName>() == value) {
- curCodeIt += (caseCount - i - 1);
- jumpDest = i + 1;
- break;
- }
- }
- selectJump(jumpDest);
+ case SwitchOperationName:
+ executeSwitchOperationName();
break;
- }
- case SwitchResultCount: {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
- Operation *op = read<Operation *>();
- auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
-
- LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
- handleSwitch(op->getNumResults(), cases);
+ case SwitchResultCount:
+ executeSwitchResultCount();
break;
- }
- case SwitchType: {
- LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
- Type value = read<Type>();
- auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
- handleSwitch(value, cases);
+ case SwitchType:
+ executeSwitchType();
break;
}
- }
+ LLVM_DEBUG(llvm::dbgs() << "\n");
}
}
More information about the Mlir-commits mailing list