[Mlir-commits] [mlir] 3eb1647 - Introduced iterative bytecode execution.
Uday Bondhugula
llvmlistbot at llvm.org
Fri Nov 26 04:43:49 PST 2021
Author: Stanislav Funiak
Date: 2021-11-26T18:11:37+05:30
New Revision: 3eb1647af036dc0e8370ed5a8b1ecbb5701f850b
URL: https://github.com/llvm/llvm-project/commit/3eb1647af036dc0e8370ed5a8b1ecbb5701f850b
DIFF: https://github.com/llvm/llvm-project/commit/3eb1647af036dc0e8370ed5a8b1ecbb5701f850b.diff
LOG: Introduced iterative bytecode execution.
This is commit 2 of 4 for the multi-root matching in PDL, discussed in https://llvm.discourse.group/t/rfc-multi-root-pdl-patterns-for-kernel-matching/4148 (topic flagged for review).
This commit implements the features needed for the execution of the new operations pdl_interp.get_accepting_ops, pdl_interp.choose_op:
1. The implementation of the generation and execution of the two ops.
2. The addition of a stack of bytecode positions within the ByteCodeExecutor. This is needed because in pdl_interp.choose_op, we iterate over the values returned by pdl_interp.get_accepting_ops until we reach finalize. When we reach finalize, we need to return to the position marked in the stack.
3. The functionality to extend the lifetime of values that cross the nondeterministic choice. The existing bytecode generator allocates the values to memory positions by representing the liveness of values as a collection of disjoint intervals over the matcher positions. This is akin to register allocation, and substantially reduces the footprint of the bytecode executor. However, because execution "returns" back with the iterative operation pdl_interp.choose_op, any values whose original liveness crosses the nondeterministic choice must have their lifetime extended until finalize.
Testing: pdl-bytecode.mlir test
Reviewed By: rriddle, Mogball
Differential Revision: https://reviews.llvm.org/D108547
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Rewrite/ByteCode.h
mlir/test/Rewrite/pdl-bytecode.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ba8d1455528cf..d02bda793b084 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -446,6 +446,9 @@ class PDLValue {
/// Print this value to the provided output stream.
void print(raw_ostream &os) const;
+ /// Print the specified value kind to an output stream.
+ static void print(raw_ostream &os, Kind kind);
+
private:
/// Find the index of a given type in a range of other types.
template <typename...>
@@ -491,6 +494,11 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
return os;
}
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
+ PDLValue::print(os, kind);
+ return os;
+}
+
//===----------------------------------------------------------------------===//
// PDLResultList
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 4482b5cb219af..39d8bad2bbd26 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -126,6 +126,29 @@ void PDLValue::print(raw_ostream &os) const {
}
}
+void PDLValue::print(raw_ostream &os, Kind kind) {
+ switch (kind) {
+ case Kind::Attribute:
+ os << "Attribute";
+ break;
+ case Kind::Operation:
+ os << "Operation";
+ break;
+ case Kind::Type:
+ os << "Type";
+ break;
+ case Kind::TypeRange:
+ os << "TypeRange";
+ break;
+ case Kind::Value:
+ os << "Value";
+ break;
+ case Kind::ValueRange:
+ os << "ValueRange";
+ break;
+ }
+}
+
//===----------------------------------------------------------------------===//
// PDLPatternModule
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 810bcf67bccb0..380f54ddc6cb1 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -95,14 +95,24 @@ enum OpCode : ByteCodeField {
CheckResultCount,
/// Compare a range of types to a constant range of types.
CheckTypes,
+ /// Continue to the next iteration of a loop.
+ Continue,
/// Create an operation.
CreateOperation,
/// Create a range of types.
CreateTypes,
/// Erase an operation.
EraseOp,
+ /// Extract the op from a range at the specified index.
+ ExtractOp,
+ /// Extract the type from a range at the specified index.
+ ExtractType,
+ /// Extract the value from a range at the specified index.
+ ExtractValue,
/// Terminate a matcher or rewrite sequence.
Finalize,
+ /// Iterate over a range of values.
+ ForEach,
/// Get a specific attribute of an operation.
GetAttribute,
/// Get the type of an attribute.
@@ -125,6 +135,8 @@ enum OpCode : ByteCodeField {
GetResultN,
/// Get a specific result group of an operation.
GetResults,
+ /// Get the users of a value or a range of values.
+ GetUsers,
/// Get the type of a value.
GetValueType,
/// Get the types of a value range.
@@ -158,8 +170,13 @@ enum OpCode : ByteCodeField {
// Generator
namespace {
+struct ByteCodeLiveRange;
struct ByteCodeWriter;
+/// Check if the given class `T` can be converted to an opaque pointer.
+template <typename T, typename... Args>
+using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
+
/// This class represents the main generator for the pattern bytecode.
class Generator {
public:
@@ -168,15 +185,19 @@ class Generator {
SmallVectorImpl<ByteCodeField> &rewriterByteCode,
SmallVectorImpl<PDLByteCodePattern> &patterns,
ByteCodeField &maxValueMemoryIndex,
+ ByteCodeField &maxOpRangeMemoryIndex,
ByteCodeField &maxTypeRangeMemoryIndex,
ByteCodeField &maxValueRangeMemoryIndex,
+ ByteCodeField &maxLoopLevel,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex),
+ maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
- maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
+ maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
+ maxLoopLevel(maxLoopLevel) {
for (auto it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(rewriteFns))
@@ -221,6 +242,7 @@ class Generator {
void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
/// Generate the bytecode for the given operation.
+ void generate(Region *region, ByteCodeWriter &writer);
void generate(Operation *op, ByteCodeWriter &writer);
void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
@@ -232,12 +254,15 @@ class Generator {
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
@@ -245,6 +270,7 @@ class Generator {
void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
@@ -279,17 +305,25 @@ class Generator {
/// `uniquedData`.
DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
+ /// The current level of the foreach loop.
+ ByteCodeField curLoopLevel = 0;
+
/// The current MLIR context.
MLIRContext *ctx;
+ /// Mapping from block to its address.
+ DenseMap<Block *, ByteCodeAddr> blockToAddr;
+
/// Data of the ByteCode class to be populated.
std::vector<const void *> &uniquedData;
SmallVectorImpl<ByteCodeField> &matcherByteCode;
SmallVectorImpl<ByteCodeField> &rewriterByteCode;
SmallVectorImpl<PDLByteCodePattern> &patterns;
ByteCodeField &maxValueMemoryIndex;
+ ByteCodeField &maxOpRangeMemoryIndex;
ByteCodeField &maxTypeRangeMemoryIndex;
ByteCodeField &maxValueRangeMemoryIndex;
+ ByteCodeField &maxLoopLevel;
};
/// This class provides utilities for writing a bytecode stream.
@@ -311,15 +345,20 @@ struct ByteCodeWriter {
bytecode.append({fieldParts[0], fieldParts[1]});
}
+ /// Append a single successor to the bytecode, the exact address will need to
+ /// be resolved later.
+ void append(Block *successor) {
+ // Add back a reference to the successor so that the address can be resolved
+ // later.
+ unresolvedSuccessorRefs[successor].push_back(bytecode.size());
+ append(ByteCodeAddr(0));
+ }
+
/// Append a successor range to the bytecode, the exact address will need to
/// be resolved later.
void append(SuccessorRange successors) {
- // Add back references to the any successors so that the address can be
- // resolved later.
- for (Block *successor : successors) {
- unresolvedSuccessorRefs[successor].push_back(bytecode.size());
- append(ByteCodeAddr(0));
- }
+ for (Block *successor : successors)
+ append(successor);
}
/// Append a range of values that will be read as generic PDLValues.
@@ -336,10 +375,12 @@ struct ByteCodeWriter {
}
/// Append the PDLValue::Kind of the given value.
- void appendPDLValueKind(Value value) {
- // Append the type of the value in addition to the value itself.
+ void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
+
+ /// Append the PDLValue::Kind of the given type.
+ void appendPDLValueKind(Type type) {
PDLValue::Kind kind =
- TypeSwitch<Type, PDLValue::Kind>(value.getType())
+ TypeSwitch<Type, PDLValue::Kind>(type)
.Case<pdl::AttributeType>(
[](Type) { return PDLValue::Kind::Attribute; })
.Case<pdl::OperationType>(
@@ -354,10 +395,6 @@ struct ByteCodeWriter {
bytecode.push_back(static_cast<ByteCodeField>(kind));
}
- /// Check if the given class `T` has an iterator type.
- template <typename T, typename... Args>
- using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
-
/// Append a value that will be stored in a memory slot and not inline within
/// the bytecode.
template <typename T>
@@ -396,25 +433,34 @@ struct ByteCodeWriter {
/// This class represents a live range of PDL Interpreter values, containing
/// information about when values are live within a match/rewrite.
struct ByteCodeLiveRange {
- using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
+ using Set = llvm::IntervalMap<uint64_t, char, 16>;
using Allocator = Set::Allocator;
- ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
+ ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
/// Union this live range with the one provided.
void unionWith(const ByteCodeLiveRange &rhs) {
- for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
- liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+ for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
+ ++it)
+ liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
}
/// Returns true if this range overlaps with the one provided.
bool overlaps(const ByteCodeLiveRange &rhs) const {
- return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
+ return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
+ .valid();
}
/// A map representing the ranges of the match/rewrite that a value is live in
/// the interpreter.
- llvm::IntervalMap<ByteCodeField, char, 16> liveness;
+ ///
+ /// We use std::unique_ptr here, because IntervalMap does not provide a
+ /// correct copy or move constructor. We can eliminate the pointer once
+ /// https://reviews.llvm.org/D113240 lands.
+ std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
+
+ /// The operation range storage index for this range.
+ Optional<unsigned> opRangeIndex;
/// The type range storage index for this range.
Optional<unsigned> typeRangeIndex;
@@ -446,15 +492,8 @@ void Generator::generate(ModuleOp module) {
"unexpected branches in rewriter function");
// Generate code for the matcher function.
- DenseMap<Block *, ByteCodeAddr> blockToAddr;
- llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
- for (Block *block : rpot) {
- // Keep track of where this block begins within the matcher function.
- blockToAddr.try_emplace(block, matcherByteCode.size());
- for (Operation &op : *block)
- generate(&op, matcherByteCodeWriter);
- }
+ generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Resolve successor references in the matcher.
for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
@@ -501,7 +540,7 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// finding the minimal number of overlapping live ranges. This is essentially
// a simplified form of register allocation where we don't necessarily have a
// limited number of registers, but we still want to minimize the number used.
- DenseMap<Operation *, ByteCodeField> opToIndex;
+ DenseMap<Operation *, unsigned> opToIndex;
matcherFunc.getBody().walk([&](Operation *op) {
opToIndex.insert(std::make_pair(op, opToIndex.size()));
});
@@ -516,8 +555,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Walk each of the blocks, computing the def interval that the value is used.
Liveness matcherLiveness(matcherFunc);
- for (Block &block : matcherFunc.getBody()) {
- const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
+ matcherFunc->walk([&](Block *block) {
+ const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
assert(info && "expected liveness info for block");
auto processValue = [&](Value value, Operation *firstUseOrDef) {
// We don't need to process the root op argument, this value is always
@@ -527,7 +566,7 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Set indices for the range of this block that the value is used.
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
- defRangeIt->second.liveness.insert(
+ defRangeIt->second.liveness->insert(
opToIndex[firstUseOrDef],
opToIndex[info->getEndOperation(value, firstUseOrDef)],
/*dummyValue*/ 0);
@@ -535,7 +574,9 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Check to see if this value is a range type.
if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
Type eleType = rangeTy.getElementType();
- if (eleType.isa<pdl::TypeType>())
+ if (eleType.isa<pdl::OperationType>())
+ defRangeIt->second.opRangeIndex = 0;
+ else if (eleType.isa<pdl::TypeType>())
defRangeIt->second.typeRangeIndex = 0;
else if (eleType.isa<pdl::ValueType>())
defRangeIt->second.valueRangeIndex = 0;
@@ -543,18 +584,37 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
};
// Process the live-ins of this block.
- for (Value liveIn : info->in())
- processValue(liveIn, &block.front());
+ for (Value liveIn : info->in()) {
+ // Only process the value if it has been defined in the current region.
+ // Other values that span across pdl_interp.foreach will be added higher
+ // up. This ensures that the we keep them alive for the entire duration
+ // of the loop.
+ if (liveIn.getParentRegion() == block->getParent())
+ processValue(liveIn, &block->front());
+ }
+
+ // Process the block arguments for the entry block (those are not live-in).
+ if (block->isEntryBlock()) {
+ for (Value argument : block->getArguments())
+ processValue(argument, &block->front());
+ }
// Process any new defs within this block.
- for (Operation &op : block)
+ for (Operation &op : *block)
for (Value result : op.getResults())
processValue(result, &op);
- }
+ });
// Greedily allocate memory slots using the computed def live ranges.
std::vector<ByteCodeLiveRange> allocatedIndices;
- ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
+
+ // The number of memory indices currently allocated (and its next value).
+ // Recall that the root gets allocated memory index 0.
+ ByteCodeField numIndices = 1;
+
+ // The number of memory ranges of various types (and their next values).
+ ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
+
for (auto &defIt : valueDefRanges) {
ByteCodeField &memIndex = valueToMemIndex[defIt.first];
ByteCodeLiveRange &defRange = defIt.second;
@@ -566,7 +626,11 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
existingRange.unionWith(defRange);
memIndex = existingIndexIt.index() + 1;
- if (defRange.typeRangeIndex) {
+ if (defRange.opRangeIndex) {
+ if (!existingRange.opRangeIndex)
+ existingRange.opRangeIndex = numOpRanges++;
+ valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
+ } else if (defRange.typeRangeIndex) {
if (!existingRange.typeRangeIndex)
existingRange.typeRangeIndex = numTypeRanges++;
valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
@@ -585,8 +649,11 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
ByteCodeLiveRange &newRange = allocatedIndices.back();
newRange.unionWith(defRange);
- // Allocate an index for type/value ranges.
- if (defRange.typeRangeIndex) {
+ // Allocate an index for op/type/value ranges.
+ if (defRange.opRangeIndex) {
+ newRange.opRangeIndex = numOpRanges;
+ valueToRangeIndex[defIt.first] = numOpRanges++;
+ } else if (defRange.typeRangeIndex) {
newRange.typeRangeIndex = numTypeRanges;
valueToRangeIndex[defIt.first] = numTypeRanges++;
} else if (defRange.valueRangeIndex) {
@@ -599,15 +666,35 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
}
}
+ // Print the index usage and ensure that we did not run out of index space.
+ LLVM_DEBUG({
+ llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
+ << "(down from initial " << valueDefRanges.size() << ").\n";
+ });
+ assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
+ "Ran out of memory for allocated indices");
+
// Update the max number of indices.
if (numIndices > maxValueMemoryIndex)
maxValueMemoryIndex = numIndices;
+ if (numOpRanges > maxOpRangeMemoryIndex)
+ maxOpRangeMemoryIndex = numOpRanges;
if (numTypeRanges > maxTypeRangeMemoryIndex)
maxTypeRangeMemoryIndex = numTypeRanges;
if (numValueRanges > maxValueRangeMemoryIndex)
maxValueRangeMemoryIndex = numValueRanges;
}
+void Generator::generate(Region *region, ByteCodeWriter &writer) {
+ llvm::ReversePostOrderTraversal<Region *> rpot(region);
+ for (Block *block : rpot) {
+ // Keep track of where this block begins within the matcher function.
+ blockToAddr.try_emplace(block, matcherByteCode.size());
+ for (Operation &op : *block)
+ generate(&op, writer);
+ }
+}
+
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
TypeSwitch<Operation *>(op)
.Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
@@ -615,13 +702,15 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
- pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
- pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
- pdl_interp::EraseOp, pdl_interp::FinalizeOp,
- pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
- pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
- pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
- pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
+ pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
+ pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+ pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
+ pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
+ pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
+ pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
+ pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
+ pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
+ pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
@@ -707,6 +796,10 @@ void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
}
+void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
+ assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
+ writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
+}
void Generator::generate(pdl_interp::CreateAttributeOp op,
ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
@@ -736,9 +829,31 @@ void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
writer.append(OpCode::EraseOp, op.operation());
}
+void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
+ OpCode opCode =
+ TypeSwitch<Type, OpCode>(op.result().getType())
+ .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
+ .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
+ .Case([](pdl::TypeType) { return OpCode::ExtractType; })
+ .Default([](Type) -> OpCode {
+ llvm_unreachable("unsupported element type");
+ });
+ writer.append(opCode, op.range(), op.index(), op.result());
+}
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
writer.append(OpCode::Finalize);
}
+void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
+ BlockArgument arg = op.getLoopVariable();
+ writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
+ writer.appendPDLValueKind(arg.getType());
+ writer.append(curLoopLevel, op.successor());
+ ++curLoopLevel;
+ if (curLoopLevel > maxLoopLevel)
+ maxLoopLevel = curLoopLevel;
+ generate(&op.region(), writer);
+ --curLoopLevel;
+}
void Generator::generate(pdl_interp::GetAttributeOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
@@ -793,6 +908,12 @@ void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
writer.append(std::numeric_limits<ByteCodeField>::max());
writer.append(result);
}
+void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
+ Value operations = op.operations();
+ ByteCodeField rangeIndex = getRangeStorageIndex(operations);
+ writer.append(OpCode::GetUsers, operations, rangeIndex);
+ writer.appendPDLValue(op.value());
+}
void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
if (op.getType().isa<pdl::RangeType>()) {
@@ -865,8 +986,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
llvm::StringMap<PDLRewriteFunction> rewriteFns) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
- maxTypeRangeCount, maxValueRangeCount, constraintFns,
- rewriteFns);
+ maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
+ maxLoopLevel, constraintFns, rewriteFns);
generator.generate(module);
// Initialize the external functions.
@@ -880,8 +1001,10 @@ PDLByteCode::PDLByteCode(ModuleOp module,
/// bytecode.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
state.memory.resize(maxValueMemoryIndex, nullptr);
+ state.opRangeMemory.resize(maxOpRangeCount);
state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
+ state.loopIndex.resize(maxLoopLevel, 0);
state.currentPatternBenefits.reserve(patterns.size());
for (const PDLByteCodePattern &pattern : patterns)
state.currentPatternBenefits.push_back(pattern.getBenefit());
@@ -896,20 +1019,23 @@ class ByteCodeExecutor {
public:
ByteCodeExecutor(
const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
+ MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
MutableArrayRef<TypeRange> typeRangeMemory,
std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
MutableArrayRef<ValueRange> valueRangeMemory,
std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
- ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
+ MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
+ ArrayRef<ByteCodeField> code,
ArrayRef<PatternBenefit> currentPatternBenefits,
ArrayRef<PDLByteCodePattern> patterns,
ArrayRef<PDLConstraintFunction> constraintFunctions,
ArrayRef<PDLRewriteFunction> rewriteFunctions)
- : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
+ : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
+ typeRangeMemory(typeRangeMemory),
allocatedTypeRangeMemory(allocatedTypeRangeMemory),
valueRangeMemory(valueRangeMemory),
allocatedValueRangeMemory(allocatedValueRangeMemory),
- uniquedMemory(uniquedMemory), code(code),
+ loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
currentPatternBenefits(currentPatternBenefits), patterns(patterns),
constraintFunctions(constraintFunctions),
rewriteFunctions(rewriteFunctions) {}
@@ -932,10 +1058,15 @@ class ByteCodeExecutor {
void executeCheckOperationName();
void executeCheckResultCount();
void executeCheckTypes();
+ void executeContinue();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
void executeCreateTypes();
void executeEraseOp(PatternRewriter &rewriter);
+ template <typename T, typename Range, PDLValue::Kind kind>
+ void executeExtract();
+ void executeFinalize();
+ void executeForEach();
void executeGetAttribute();
void executeGetAttributeType();
void executeGetDefiningOp();
@@ -943,6 +1074,7 @@ class ByteCodeExecutor {
void executeGetOperands();
void executeGetResult(unsigned index);
void executeGetResults();
+ void executeGetUsers();
void executeGetValueType();
void executeGetValueRangeTypes();
void executeIsNotNull();
@@ -956,6 +1088,16 @@ class ByteCodeExecutor {
void executeSwitchType();
void executeSwitchTypes();
+ /// Pushes a code iterator to the stack.
+ void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
+
+ /// Pops a code iterator from the stack, returning true on success.
+ void popCodeIt() {
+ assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
+ curCodeIt = resumeCodeIt.back();
+ resumeCodeIt.pop_back();
+ }
+
/// Read a value from the bytecode buffer, optionally skipping a certain
/// number of prefix values. These methods always update the buffer to point
/// to the next field after the read data.
@@ -1012,6 +1154,18 @@ class ByteCodeExecutor {
selectJump(size_t(0));
}
+ /// Store a pointer to memory.
+ void storeToMemory(unsigned index, const void *value) {
+ memory[index] = value;
+ }
+
+ /// Store a value to memory as an opaque pointer.
+ template <typename T>
+ std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
+ storeToMemory(unsigned index, T value) {
+ memory[index] = value.getAsOpaquePointer();
+ }
+
/// Internal implementation of reading various data types from the bytecode
/// stream.
template <typename T>
@@ -1076,13 +1230,20 @@ class ByteCodeExecutor {
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
+ /// The stack of bytecode positions at which to resume operation.
+ SmallVector<const ByteCodeField *> resumeCodeIt;
+
/// The current execution memory.
MutableArrayRef<const void *> memory;
+ MutableArrayRef<OwningOpRange> opRangeMemory;
MutableArrayRef<TypeRange> typeRangeMemory;
std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
MutableArrayRef<ValueRange> valueRangeMemory;
std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
+ /// The current loop indices.
+ MutableArrayRef<unsigned> loopIndex;
+
/// References to ByteCode data necessary for execution.
ArrayRef<const void *> uniquedMemory;
ArrayRef<ByteCodeField> code;
@@ -1277,6 +1438,14 @@ void ByteCodeExecutor::executeCheckTypes() {
selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
}
+void ByteCodeExecutor::executeContinue() {
+ ByteCodeField level = read();
+ LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
+ << " * Level: " << level << "\n");
+ ++loopIndex[level];
+ popCodeIt();
+}
+
void ByteCodeExecutor::executeCreateTypes() {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
unsigned memIndex = read();
@@ -1357,6 +1526,65 @@ void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
rewriter.eraseOp(op);
}
+template <typename T, typename Range, PDLValue::Kind kind>
+void ByteCodeExecutor::executeExtract() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
+ Range *range = read<Range *>();
+ unsigned index = read<uint32_t>();
+ unsigned memIndex = read();
+
+ if (!range) {
+ memory[memIndex] = nullptr;
+ return;
+ }
+
+ T result = index < range->size() ? (*range)[index] : T();
+ LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
+ << " * Index: " << index << "\n"
+ << " * Result: " << result << "\n");
+ storeToMemory(memIndex, result);
+}
+
+void ByteCodeExecutor::executeFinalize() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
+}
+
+void ByteCodeExecutor::executeForEach() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
+ // Subtract 1 for the op code.
+ const ByteCodeField *it = curCodeIt - 1;
+ unsigned rangeIndex = read();
+ unsigned memIndex = read();
+ const void *value = nullptr;
+
+ switch (read<PDLValue::Kind>()) {
+ case PDLValue::Kind::Operation: {
+ unsigned &index = loopIndex[read()];
+ ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
+ assert(index <= array.size() && "iterated past the end");
+ if (index < array.size()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
+ value = array[index];
+ break;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ index = 0;
+ selectJump(size_t(0));
+ return;
+ }
+ default:
+ llvm_unreachable("unexpected `ForEach` value kind");
+ }
+
+ // Store the iterate value and the stack address.
+ memory[memIndex] = value;
+ pushCodeIt(it);
+
+ // Skip over the successor (we will enter the body of the loop).
+ read<ByteCodeAddr>();
+}
+
void ByteCodeExecutor::executeGetAttribute() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
unsigned memIndex = read();
@@ -1421,7 +1649,7 @@ template <template <typename> class AttrSizedSegmentsT, typename RangeT>
static void *
executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
ByteCodeField rangeIndex, StringRef attrSizedSegments,
- MutableArrayRef<ValueRange> &valueRangeMemory) {
+ MutableArrayRef<ValueRange> valueRangeMemory) {
// Check for the sentinel index that signals that all values should be
// returned.
if (index == std::numeric_limits<uint32_t>::max()) {
@@ -1509,6 +1737,46 @@ void ByteCodeExecutor::executeGetResults() {
memory[read()] = result;
}
+ /// Executes the `GetUsers` op code.
+ ///
+ /// Reads the destination memory slot and op-range index, then (depending on
+ /// the encoded value kind) either a single `Value` or a `ValueRange *`. It
+ /// fills `opRangeMemory[rangeIndex]` with the users of that value, or of every
+ /// value in the range. `memory[memIndex]` is pointed at the range slot before
+ /// the slot is filled, so a null input leaves it pointing at an empty range.
+ /// Users are not deduplicated: an operation that uses a value several times
+ /// is listed once per use.
+ void ByteCodeExecutor::executeGetUsers() {
+   LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
+   unsigned memIndex = read();
+   unsigned rangeIndex = read();
+   OwningOpRange &range = opRangeMemory[rangeIndex];
+   memory[memIndex] = &range;
+
+   // Clear any previous contents of the slot.
+   range = OwningOpRange();
+   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+     // Read the value.
+     Value value = read<Value>();
+     if (!value)
+       return;
+     LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+
+     // Extract the users of a single value.
+     range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
+     llvm::copy(value.getUsers(), range.begin());
+   } else {
+     // Read a range of values.
+     ValueRange *values = read<ValueRange *>();
+     if (!values)
+       return;
+     LLVM_DEBUG({
+       llvm::dbgs() << " * Values (" << values->size() << "): ";
+       llvm::interleaveComma(*values, llvm::dbgs());
+       llvm::dbgs() << "\n";
+     });
+
+     // Extract all the users of a range of values.
+     SmallVector<Operation *> users;
+     for (Value value : *values)
+       users.append(value.user_begin(), value.user_end());
+     range = OwningOpRange(users.size());
+     llvm::copy(users, range.begin());
+   }
+
+   LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
+ }
+
void ByteCodeExecutor::executeGetValueType() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
unsigned memIndex = read();
@@ -1731,6 +1999,9 @@ void ByteCodeExecutor::execute(
case CheckTypes:
executeCheckTypes();
break;
+ case Continue:
+ executeContinue();
+ break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
@@ -1740,9 +2011,22 @@ void ByteCodeExecutor::execute(
case EraseOp:
executeEraseOp(rewriter);
break;
+ case ExtractOp:
+ executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
+ break;
+ case ExtractType:
+ executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
+ break;
+ case ExtractValue:
+ executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
+ break;
case Finalize:
- LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
+ executeFinalize();
+ LLVM_DEBUG(llvm::dbgs() << "\n");
return;
+ case ForEach:
+ executeForEach();
+ break;
case GetAttribute:
executeGetAttribute();
break;
@@ -1784,6 +2068,9 @@ void ByteCodeExecutor::execute(
case GetResults:
executeGetResults();
break;
+ case GetUsers:
+ executeGetUsers();
+ break;
case GetValueType:
executeGetValueType();
break;
@@ -1834,11 +2121,11 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
// The matcher function always starts at code address 0.
ByteCodeExecutor executor(
- matcherByteCode.data(), state.memory, state.typeRangeMemory,
- state.allocatedTypeRangeMemory, state.valueRangeMemory,
- state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
- state.currentPatternBenefits, patterns, constraintFunctions,
- rewriteFunctions);
+ matcherByteCode.data(), state.memory, state.opRangeMemory,
+ state.typeRangeMemory, state.allocatedTypeRangeMemory,
+ state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
+ uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
+ constraintFunctions, rewriteFunctions);
executor.execute(rewriter, &matches);
// Order the found matches by benefit.
@@ -1857,8 +2144,9 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
ByteCodeExecutor executor(
&rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
- state.typeRangeMemory, state.allocatedTypeRangeMemory,
- state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
+ state.opRangeMemory, state.typeRangeMemory,
+ state.allocatedTypeRangeMemory, state.valueRangeMemory,
+ state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
rewriterByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
executor.execute(rewriter, /*matches=*/nullptr, match.location);
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index 941b782dcd4e7..a26c3b295de96 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -28,6 +28,7 @@ class PDLByteCode;
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
+using OwningOpRange = llvm::OwningArrayRef<Operation *>;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
@@ -79,6 +80,12 @@ class PDLByteCodeMutableState {
/// of the bytecode.
std::vector<const void *> memory;
+ /// A mutable block of memory used during the matching and rewriting phase of
+ /// the bytecode to store ranges of operations. These are always stored by
+ /// owning references, because at no point in the execution of the bytecode
+ /// do we get an indexed range (view) of operations.
+ std::vector<OwningOpRange> opRangeMemory;
+
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of types.
std::vector<TypeRange> typeRangeMemory;
@@ -93,6 +100,11 @@ class PDLByteCodeMutableState {
/// interpreter to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
+ /// The current index of ranges being iterated over for each level of nesting.
+ /// These are always maintained at 0 for the loops that are not active, so we
+ /// do not need to have a separate initialization phase for each loop.
+ std::vector<unsigned> loopIndex;
+
/// The up-to-date benefits of the patterns held by the bytecode. The order
/// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector<PatternBenefit> currentPatternBenefits;
@@ -188,8 +200,12 @@ class PDLByteCode {
ByteCodeField maxValueMemoryIndex = 0;
/// The maximum number of different types of ranges.
+ ByteCodeField maxOpRangeCount = 0;
ByteCodeField maxTypeRangeCount = 0;
ByteCodeField maxValueRangeCount = 0;
+
+ /// The maximum number of nested loops.
+ ByteCodeField maxLoopLevel = 0;
};
} // end namespace detail
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index d630fa2aa14db..1dc7568d633cd 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -514,6 +514,12 @@ module @ir attributes { test.check_types_1 } {
// -----
+//===----------------------------------------------------------------------===//
+// pdl_interp::ContinueOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
@@ -576,12 +582,277 @@ module @ir attributes { test.create_type_1 } {
// Fully tested within the tests for other operations.
+//===----------------------------------------------------------------------===//
+// pdl_interp::ExtractOp
+//===----------------------------------------------------------------------===//
+
+// Extracts the user at index 1 of the root's first result. If it exists, the
+// rewriter creates "test.success" and erases that user.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %val = pdl_interp.get_result 0 of %root
+ %ops = pdl_interp.get_users of %val : !pdl.value
+ %op1 = pdl_interp.extract 1 of %ops : !pdl.operation
+ pdl_interp.is_not_null %op1 : !pdl.operation -> ^success, ^end
+ ^success:
+ pdl_interp.record_match @rewriters::@success(%op1 : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.extract_op
+// CHECK: "test.success"
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// CHECK: "test.op"(%[[OPERAND]])
+// Input: %operand has two users, one of which uses it twice.
+module @ir attributes { test.extract_op } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32)
+ "test.op"(%operand, %operand) : (i32, i32) -> (i32)
+}
+
+// -----
+
+// Extracts the type at index 1 of the root's result types, so only roots with
+// at least two results match. The rewriter creates "test.success" and erases
+// the root.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %vals : !pdl.range<type>
+ %type1 = pdl_interp.extract 1 of %types : !pdl.type
+ pdl_interp.is_not_null %type1 : !pdl.type -> ^success, ^end
+ ^success:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.extract_type
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// CHECK: "test.success"
+// CHECK: "test.op"(%[[OPERAND]])
+// Input: of the two users of %operand, only the first has two results.
+module @ir attributes { test.extract_type } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32, i32)
+ "test.op"(%operand) : (i32) -> (i32)
+}
+
+// -----
+
+// Extracts the value at index 1 of the root's results, so only roots with at
+// least two results match. The rewriter creates "test.success" and erases the
+// root.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %val1 = pdl_interp.extract 1 of %vals : !pdl.value
+ pdl_interp.is_not_null %val1 : !pdl.value -> ^success, ^end
+ ^success:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.extract_value
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// CHECK: "test.success"
+// CHECK: "test.op"(%[[OPERAND]])
+// Input: of the two users of %operand, only the first has two results.
+module @ir attributes { test.extract_value } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32, i32)
+ "test.op"(%operand) : (i32) -> (i32)
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::FinalizeOp
//===----------------------------------------------------------------------===//
// Fully tested within the tests for other operations.
+//===----------------------------------------------------------------------===//
+// pdl_interp::ForEachOp
+//===----------------------------------------------------------------------===//
+
+// Nested loops: for every user of the root's first result, iterate over the
+// users of that user's first result and record a match on each. The rewriter
+// creates "test.success" and erases the matched op.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %val1 = pdl_interp.get_result 0 of %root
+ %ops1 = pdl_interp.get_users of %val1 : !pdl.value
+ pdl_interp.foreach %op1 : !pdl.operation in %ops1 {
+ %val2 = pdl_interp.get_result 0 of %op1
+ %ops2 = pdl_interp.get_users of %val2 : !pdl.value
+ pdl_interp.foreach %op2 : !pdl.operation in %ops2 {
+ pdl_interp.record_match @rewriters::@success(%op2 : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.foreach
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: %[[ROOT:.*]] = "test.op"
+// CHECK: %[[VALA:.*]] = "test.op"(%[[ROOT]])
+// CHECK: %[[VALB:.*]] = "test.op"(%[[ROOT]])
+// Input: %root has two users (%valA, %valB), each of which has two users, so
+// four second-level ops are matched.
+module @ir attributes { test.foreach } {
+ %root = "test.op"() : () -> i32
+ %valA = "test.op"(%root) : (i32) -> (i32)
+ "test.op"(%valA) : (i32) -> (i32)
+ "test.op"(%valA) : (i32) -> (i32)
+ %valB = "test.op"(%root) : (i32) -> (i32)
+ "test.op"(%valB) : (i32) -> (i32)
+ "test.op"(%valB) : (i32) -> (i32)
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetUsersOp
+//===----------------------------------------------------------------------===//
+
+// Records a match for every user of the root's first result. The rewriter
+// creates "test.success" and erases the matched user.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %val = pdl_interp.get_result 0 of %root
+ %ops = pdl_interp.get_users of %val : !pdl.value
+ pdl_interp.foreach %op : !pdl.operation in %ops {
+ pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_users_of_value
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: %[[OPERAND:.*]] = "test.op"
+// Input: %operand has two users, one of which uses it twice.
+module @ir attributes { test.get_users_of_value } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> (i32)
+ "test.op"(%operand, %operand) : (i32, i32) -> (i32)
+}
+
+// -----
+
+// For roots with at least two results, records a match for every user of any
+// result (get_users on a value range). The rewriter creates "test.success"
+// and erases the matched user.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end
+ ^next:
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %ops = pdl_interp.get_users of %vals : !pdl.range<value>
+ pdl_interp.foreach %op : !pdl.operation in %ops {
+ pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_all_users_of_range
+// CHECK: "test.success"
+// CHECK: "test.success"
+// CHECK: %[[OPERANDS:.*]]:2 = "test.op"
+// Input: each of the two results of the first op has exactly one user.
+module @ir attributes { test.get_all_users_of_range } {
+ %operands:2 = "test.op"() : () -> (i32, i32)
+ "test.op"(%operands#0) : (i32) -> (i32)
+ "test.op"(%operands#1) : (i32) -> (i32)
+}
+
+// -----
+
+// For roots with at least two results, records a match only for users of the
+// first result (extract 0 is applied before get_users). The rewriter creates
+// "test.success" and erases the matched user.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end
+ ^next:
+ %vals = pdl_interp.get_results of %root : !pdl.range<value>
+ %val = pdl_interp.extract 0 of %vals : !pdl.value
+ %ops = pdl_interp.get_users of %val : !pdl.value
+ pdl_interp.foreach %op : !pdl.operation in %ops {
+ pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%matched : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %matched
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_first_users_of_range
+// CHECK: "test.success"
+// CHECK: %[[OPERANDS:.*]]:2 = "test.op"
+// CHECK: "test.op"
+// Input: each of the two results of the first op has exactly one user; only
+// one "test.success" is expected.
+module @ir attributes { test.get_first_users_of_range } {
+ %operands:2 = "test.op"() : () -> (i32, i32)
+ "test.op"(%operands#0) : (i32) -> (i32)
+ "test.op"(%operands#1) : (i32) -> (i32)
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetAttributeOp
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list