[Mlir-commits] [mlir] 85ab413 - [mlir][PDL] Add support for variadic operands and results in the PDL byte code
River Riddle
llvmlistbot at llvm.org
Tue Mar 16 13:20:37 PDT 2021
Author: River Riddle
Date: 2021-03-16T13:20:19-07:00
New Revision: 85ab413b53aeb135eb58dab066afcbf20bef0cf8
URL: https://github.com/llvm/llvm-project/commit/85ab413b53aeb135eb58dab066afcbf20bef0cf8
DIFF: https://github.com/llvm/llvm-project/commit/85ab413b53aeb135eb58dab066afcbf20bef0cf8.diff
LOG: [mlir][PDL] Add support for variadic operands and results in the PDL byte code
Supporting ranges in the byte code requires additional complexity, given that a range can't be easily represented as an opaque void *, as is possible with the existing bytecode value types (Attribute, Type, Value, etc.). To enable representing a range with void *, an auxiliary storage is used for the actual range itself, with the pointer being passed around in the normal byte code memory. For type ranges, a TypeRange is stored. For value ranges, a ValueRange is stored. The above problem represents a majority of the complexity involved in this revision; the rest is adapting/adding byte code operations to support the changes made to the PDL interpreter in the parent revision.
After this revision, PDL will have initial end-to-end support for variadic operands/results.
Differential Revision: https://reviews.llvm.org/D95723
Added:
Modified:
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/IR/TypeRange.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Rewrite/ByteCode.h
mlir/lib/Rewrite/PatternApplicator.cpp
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index e35208747ade..ff9b3dda0f48 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -963,7 +963,7 @@ def PDLInterp_SwitchOperandCountOp
let builders = [
OpBuilder<(ins "Value":$operation, "ArrayRef<int32_t>":$counts,
- "Block *":$defaultDest, "BlockRange":$dests), [{
+ "Block *":$defaultDest, "BlockRange":$dests), [{
build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts),
defaultDest, dests);
}]>];
@@ -1033,7 +1033,7 @@ def PDLInterp_SwitchResultCountOp
let builders = [
OpBuilder<(ins "Value":$operation, "ArrayRef<int32_t>":$counts,
- "Block *":$defaultDest, "BlockRange":$dests), [{
+ "Block *":$defaultDest, "BlockRange":$dests), [{
build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts),
defaultDest, dests);
}]>];
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 56da9b870948..c797f5329bd5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -238,63 +238,92 @@ struct OpRewritePattern : public RewritePattern {
/// Storage type of byte-code interpreter values. These are passed to constraint
/// functions as arguments.
class PDLValue {
- /// The internal implementation type when the value is an Attribute,
- /// Operation*, or Type. See `impl` below for more details.
- using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>;
-
public:
- PDLValue(const PDLValue &other) : impl(other.impl) {}
- PDLValue(std::nullptr_t = nullptr) : impl() {}
- PDLValue(Attribute value) : impl(value) {}
- PDLValue(Operation *value) : impl(value) {}
- PDLValue(Type value) : impl(value) {}
- PDLValue(Value value) : impl(value) {}
+ /// The underlying kind of a PDL value.
+ enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
+
+ /// Construct a new PDL value.
+ PDLValue(const PDLValue &other) = default;
+ PDLValue(std::nullptr_t = nullptr) : value(nullptr), kind(Kind::Attribute) {}
+ PDLValue(Attribute value)
+ : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
+ PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
+ PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
+ PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
+ PDLValue(Value value)
+ : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
+ PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
/// Returns true if the type of the held value is `T`.
- template <typename T>
- std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const {
- return impl.is<Value>();
- }
- template <typename T>
- std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const {
- auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
- return attrOpTypeImpl && attrOpTypeImpl.is<T>();
+ template <typename T> bool isa() const {
+ assert(value && "isa<> used on a null value");
+ return kind == getKindOf<T>();
}
/// Attempt to dynamically cast this value to type `T`, returns null if this
/// value is not an instance of `T`.
- template <typename T>
- std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const {
- return impl.dyn_cast<T>();
- }
- template <typename T>
- std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const {
- auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
- return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>();
+ template <typename T,
+ typename ResultT = std::conditional_t<
+ std::is_convertible<T, bool>::value, T, Optional<T>>>
+ ResultT dyn_cast() const {
+ return isa<T>() ? castImpl<T>() : ResultT();
}
/// Cast this value to type `T`, asserts if this value is not an instance of
/// `T`.
- template <typename T>
- std::enable_if_t<std::is_same<T, Value>::value, T> cast() const {
- return impl.get<T>();
- }
- template <typename T>
- std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const {
- return impl.get<AttrOpTypeImplT>().get<T>();
+ template <typename T> T cast() const {
+ assert(isa<T>() && "expected value to be of type `T`");
+ return castImpl<T>();
}
/// Get an opaque pointer to the value.
- void *getAsOpaquePointer() { return impl.getOpaqueValue(); }
+ const void *getAsOpaquePointer() const { return value; }
+
+ /// Return if this value is null or not.
+ explicit operator bool() const { return value; }
+
+ /// Return the kind of this value.
+ Kind getKind() const { return kind; }
/// Print this value to the provided output stream.
- void print(raw_ostream &os);
+ void print(raw_ostream &os) const;
private:
- /// The internal opaque representation of a PDLValue. We use a nested
- /// PointerUnion structure here because `Value` only has 1 low bit
- /// available, where as the remaining types all have 3.
- llvm::PointerUnion<AttrOpTypeImplT, Value> impl;
+ /// Find the index of a given type in a range of other types.
+ template <typename...> struct index_of_t;
+ template <typename T, typename... R>
+ struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
+ template <typename T, typename F, typename... R>
+ struct index_of_t<T, F, R...>
+ : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
+
+ /// Return the kind used for the given T.
+ template <typename T> static Kind getKindOf() {
+ return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
+ TypeRange, Value, ValueRange>::value);
+ }
+
+ /// The internal implementation of `cast`, that returns the underlying value
+ /// as the given type `T`.
+ template <typename T>
+ std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
+ castImpl() const {
+ return T::getFromOpaquePointer(value);
+ }
+ template <typename T>
+ std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
+ castImpl() const {
+ return *reinterpret_cast<T *>(const_cast<void *>(value));
+ }
+ template <typename T>
+ std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
+ return reinterpret_cast<T>(const_cast<void *>(value));
+ }
+
+ /// The internal opaque representation of a PDLValue.
+ const void *value;
+ /// The kind of the opaque value.
+ Kind kind;
};
inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
@@ -319,14 +348,66 @@ class PDLResultList {
/// Push a new Type onto the result list.
void push_back(Type value) { results.push_back(value); }
+ /// Push a new TypeRange onto the result list.
+ void push_back(TypeRange value) {
+ // The lifetime of a TypeRange can't be guaranteed, so we'll need to
+ // allocate a storage for it.
+ llvm::OwningArrayRef<Type> storage(value.size());
+ llvm::copy(value, storage.begin());
+ allocatedTypeRanges.emplace_back(std::move(storage));
+ typeRanges.push_back(allocatedTypeRanges.back());
+ results.push_back(&typeRanges.back());
+ }
+ void push_back(ValueTypeRange<OperandRange> value) {
+ typeRanges.push_back(value);
+ results.push_back(&typeRanges.back());
+ }
+ void push_back(ValueTypeRange<ResultRange> value) {
+ typeRanges.push_back(value);
+ results.push_back(&typeRanges.back());
+ }
+
/// Push a new Value onto the result list.
void push_back(Value value) { results.push_back(value); }
+ /// Push a new ValueRange onto the result list.
+ void push_back(ValueRange value) {
+ // The lifetime of a ValueRange can't be guaranteed, so we'll need to
+ // allocate a storage for it.
+ llvm::OwningArrayRef<Value> storage(value.size());
+ llvm::copy(value, storage.begin());
+ allocatedValueRanges.emplace_back(std::move(storage));
+ valueRanges.push_back(allocatedValueRanges.back());
+ results.push_back(&valueRanges.back());
+ }
+ void push_back(OperandRange value) {
+ valueRanges.push_back(value);
+ results.push_back(&valueRanges.back());
+ }
+ void push_back(ResultRange value) {
+ valueRanges.push_back(value);
+ results.push_back(&valueRanges.back());
+ }
+
protected:
- PDLResultList() = default;
+ /// Create a new result list with the expected number of results.
+ PDLResultList(unsigned maxNumResults) {
+ // For now just reserve enough space for all of the results. We could do
+ // separate counts per range type, but it isn't really worth it unless there
+ // are a "large" number of results.
+ typeRanges.reserve(maxNumResults);
+ valueRanges.reserve(maxNumResults);
+ }
/// The PDL results held by this list.
SmallVector<PDLValue> results;
+ /// Memory used to store ranges held by the list.
+ SmallVector<TypeRange> typeRanges;
+ SmallVector<ValueRange> valueRanges;
+ /// Memory allocated to store ranges in the result list whose lifetime was
+ /// generated in the native function.
+ SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
+ SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index fe11fde58793..4fb40e127f9f 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -82,6 +82,12 @@ inline ::llvm::hash_code hash_value(TypeRange arg) {
return ::llvm::hash_combine_range(arg.begin(), arg.end());
}
+/// Emit a type range to the given output stream.
+inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
+ llvm::interleaveComma(types, os);
+ return os;
+}
+
//===----------------------------------------------------------------------===//
// ValueTypeRange
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 034698d85cb1..354d5f31bf74 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -73,22 +73,31 @@ void RewritePattern::anchor() {}
// PDLValue
//===----------------------------------------------------------------------===//
-void PDLValue::print(raw_ostream &os) {
- if (!impl) {
- os << "<Null-PDLValue>";
+void PDLValue::print(raw_ostream &os) const {
+ if (!value) {
+ os << "<NULL-PDLValue>";
return;
}
- if (Value val = impl.dyn_cast<Value>()) {
- os << val;
- return;
+ switch (kind) {
+ case Kind::Attribute:
+ os << cast<Attribute>();
+ break;
+ case Kind::Operation:
+ os << *cast<Operation *>();
+ break;
+ case Kind::Type:
+ os << cast<Type>();
+ break;
+ case Kind::TypeRange:
+ llvm::interleaveComma(cast<TypeRange>(), os);
+ break;
+ case Kind::Value:
+ os << cast<Value>();
+ break;
+ case Kind::ValueRange:
+ llvm::interleaveComma(cast<ValueRange>(), os);
+ break;
}
- AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>();
- if (Attribute attr = aotImpl.dyn_cast<Attribute>())
- os << attr;
- else if (Operation *op = aotImpl.dyn_cast<Operation *>())
- os << *op;
- else
- os << aotImpl.get<Type>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ef96e25c7be3..ea17f99deb9c 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -20,6 +20,9 @@
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <numeric>
#define DEBUG_TYPE "pdl-bytecode"
@@ -60,6 +63,14 @@ void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
currentPatternBenefits[patternIndex] = benefit;
}
+/// Cleanup any allocated state after a full match/rewrite has been completed.
+/// This method should be called irregardless of whether the match+rewrite was a
+/// success or not.
+void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
+ allocatedTypeRangeMemory.clear();
+ allocatedValueRangeMemory.clear();
+}
+
//===----------------------------------------------------------------------===//
// Bytecode OpCodes
//===----------------------------------------------------------------------===//
@@ -72,6 +83,8 @@ enum OpCode : ByteCodeField {
ApplyRewrite,
/// Check if two generic values are equal.
AreEqual,
+ /// Check if two ranges are equal.
+ AreRangesEqual,
/// Unconditional branch.
Branch,
/// Compare the operand count of an operation with a constant.
@@ -80,8 +93,12 @@ enum OpCode : ByteCodeField {
CheckOperationName,
/// Compare the result count of an operation with a constant.
CheckResultCount,
+ /// Compare a range of types to a constant range of types.
+ CheckTypes,
/// Create an operation.
CreateOperation,
+ /// Create a range of types.
+ CreateTypes,
/// Erase an operation.
EraseOp,
/// Terminate a matcher or rewrite sequence.
@@ -98,14 +115,20 @@ enum OpCode : ByteCodeField {
GetOperand2,
GetOperand3,
GetOperandN,
+ /// Get a specific operand group of an operation.
+ GetOperands,
/// Get a specific result of an operation.
GetResult0,
GetResult1,
GetResult2,
GetResult3,
GetResultN,
+ /// Get a specific result group of an operation.
+ GetResults,
/// Get the type of a value.
GetValueType,
+ /// Get the types of a value range.
+ GetValueRangeTypes,
/// Check if a generic value is not null.
IsNotNull,
/// Record a successful pattern match.
@@ -122,9 +145,9 @@ enum OpCode : ByteCodeField {
SwitchResultCount,
/// Compare a type with a set of constants.
SwitchType,
+ /// Compare a range of types with a set of constants.
+ SwitchTypes,
};
-
-enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -145,11 +168,15 @@ class Generator {
SmallVectorImpl<ByteCodeField> &rewriterByteCode,
SmallVectorImpl<PDLByteCodePattern> &patterns,
ByteCodeField &maxValueMemoryIndex,
+ ByteCodeField &maxTypeRangeMemoryIndex,
+ ByteCodeField &maxValueRangeMemoryIndex,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
- maxValueMemoryIndex(maxValueMemoryIndex) {
+ maxValueMemoryIndex(maxValueMemoryIndex),
+ maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
+ maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
for (auto it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(rewriteFns))
@@ -166,6 +193,13 @@ class Generator {
return valueToMemIndex[value];
}
+ /// Return the range memory index used to store the given range value.
+ ByteCodeField &getRangeStorageIndex(Value value) {
+ assert(valueToRangeIndex.count(value) &&
+ "expected range index to be assigned");
+ return valueToRangeIndex[value];
+ }
+
/// Return an index to use when referring to the given data that is uniqued in
/// the MLIR context.
template <typename T>
@@ -197,16 +231,20 @@ class Generator {
void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
@@ -214,6 +252,7 @@ class Generator {
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
@@ -221,6 +260,9 @@ class Generator {
/// Mapping from value to its corresponding memory index.
DenseMap<Value, ByteCodeField> valueToMemIndex;
+ /// Mapping from a range value to its corresponding range storage index.
+ DenseMap<Value, ByteCodeField> valueToRangeIndex;
+
/// Mapping from the name of an externally registered rewrite to its index in
/// the bytecode registry.
llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
@@ -246,6 +288,8 @@ class Generator {
SmallVectorImpl<ByteCodeField> &rewriterByteCode;
SmallVectorImpl<PDLByteCodePattern> &patterns;
ByteCodeField &maxValueMemoryIndex;
+ ByteCodeField &maxTypeRangeMemoryIndex;
+ ByteCodeField &maxValueRangeMemoryIndex;
};
/// This class provides utilities for writing a bytecode stream.
@@ -281,19 +325,33 @@ struct ByteCodeWriter {
/// Append a range of values that will be read as generic PDLValues.
void appendPDLValueList(OperandRange values) {
bytecode.push_back(values.size());
- for (Value value : values) {
- // Append the type of the value in addition to the value itself.
- PDLValueKind kind =
- TypeSwitch<Type, PDLValueKind>(value.getType())
- .Case<pdl::AttributeType>(
- [](Type) { return PDLValueKind::Attribute; })
- .Case<pdl::OperationType>(
- [](Type) { return PDLValueKind::Operation; })
- .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
- .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
- bytecode.push_back(static_cast<ByteCodeField>(kind));
- append(value);
- }
+ for (Value value : values)
+ appendPDLValue(value);
+ }
+
+ /// Append a value as a PDLValue.
+ void appendPDLValue(Value value) {
+ appendPDLValueKind(value);
+ append(value);
+ }
+
+ /// Append the PDLValue::Kind of the given value.
+ void appendPDLValueKind(Value value) {
+ // Append the type of the value in addition to the value itself.
+ PDLValue::Kind kind =
+ TypeSwitch<Type, PDLValue::Kind>(value.getType())
+ .Case<pdl::AttributeType>(
+ [](Type) { return PDLValue::Kind::Attribute; })
+ .Case<pdl::OperationType>(
+ [](Type) { return PDLValue::Kind::Operation; })
+ .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
+ if (rangeTy.getElementType().isa<pdl::TypeType>())
+ return PDLValue::Kind::TypeRange;
+ return PDLValue::Kind::ValueRange;
+ })
+ .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
+ .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
+ bytecode.push_back(static_cast<ByteCodeField>(kind));
}
/// Check if the given class `T` has an iterator type.
@@ -334,6 +392,36 @@ struct ByteCodeWriter {
/// The main generator producing PDL.
Generator &generator;
};
+
+/// This class represents a live range of PDL Interpreter values, containing
+/// information about when values are live within a match/rewrite.
+struct ByteCodeLiveRange {
+ using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
+ using Allocator = Set::Allocator;
+
+ ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
+
+ /// Union this live range with the one provided.
+ void unionWith(const ByteCodeLiveRange &rhs) {
+ for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
+ liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+ }
+
+ /// Returns true if this range overlaps with the one provided.
+ bool overlaps(const ByteCodeLiveRange &rhs) const {
+ return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
+ }
+
+ /// A map representing the ranges of the match/rewrite that a value is live in
+ /// the interpreter.
+ llvm::IntervalMap<ByteCodeField, char, 16> liveness;
+
+ /// The type range storage index for this range.
+ Optional<unsigned> typeRangeIndex;
+
+ /// The value range storage index for this range.
+ Optional<unsigned> valueRangeIndex;
+};
} // end anonymous namespace
void Generator::generate(ModuleOp module) {
@@ -381,15 +469,30 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Rewriters use simplistic allocation scheme that simply assigns an index to
// each result.
for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
- ByteCodeField index = 0;
+ ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
+ auto processRewriterValue = [&](Value val) {
+ valueToMemIndex.try_emplace(val, index++);
+ if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
+ Type elementTy = rangeType.getElementType();
+ if (elementTy.isa<pdl::TypeType>())
+ valueToRangeIndex.try_emplace(val, typeRangeIndex++);
+ else if (elementTy.isa<pdl::ValueType>())
+ valueToRangeIndex.try_emplace(val, valueRangeIndex++);
+ }
+ };
+
for (BlockArgument arg : rewriterFunc.getArguments())
- valueToMemIndex.try_emplace(arg, index++);
+ processRewriterValue(arg);
rewriterFunc.getBody().walk([&](Operation *op) {
for (Value result : op->getResults())
- valueToMemIndex.try_emplace(result, index++);
+ processRewriterValue(result);
});
if (index > maxValueMemoryIndex)
maxValueMemoryIndex = index;
+ if (typeRangeIndex > maxTypeRangeMemoryIndex)
+ maxTypeRangeMemoryIndex = typeRangeIndex;
+ if (valueRangeIndex > maxValueRangeMemoryIndex)
+ maxValueRangeMemoryIndex = valueRangeIndex;
}
// The matcher function uses a more sophisticated numbering that tries to
@@ -404,9 +507,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
});
// Liveness info for each of the defs within the matcher.
- using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
- LivenessSet::Allocator allocator;
- DenseMap<Value, LivenessSet> valueDefRanges;
+ ByteCodeLiveRange::Allocator allocator;
+ DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
// Assign the root operation being matched to slot 0.
BlockArgument rootOpArg = matcherFunc.getArgument(0);
@@ -425,10 +527,19 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
// Set indices for the range of this block that the value is used.
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
- defRangeIt->second.insert(
+ defRangeIt->second.liveness.insert(
opToIndex[firstUseOrDef],
opToIndex[info->getEndOperation(value, firstUseOrDef)],
/*dummyValue*/ 0);
+
+ // Check to see if this value is a range type.
+ if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
+ Type eleType = rangeTy.getElementType();
+ if (eleType.isa<pdl::TypeType>())
+ defRangeIt->second.typeRangeIndex = 0;
+ else if (eleType.isa<pdl::ValueType>())
+ defRangeIt->second.valueRangeIndex = 0;
+ }
};
// Process the live-ins of this block.
@@ -442,37 +553,59 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
}
// Greedily allocate memory slots using the computed def live ranges.
- std::vector<LivenessSet> allocatedIndices;
+ std::vector<ByteCodeLiveRange> allocatedIndices;
+ ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
for (auto &defIt : valueDefRanges) {
ByteCodeField &memIndex = valueToMemIndex[defIt.first];
- LivenessSet &defSet = defIt.second;
+ ByteCodeLiveRange &defRange = defIt.second;
// Try to allocate to an existing index.
for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
- LivenessSet &existingIndex = existingIndexIt.value();
- llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
- defIt.second, existingIndex);
- if (overlaps.valid())
- continue;
- // Union the range of the def within the existing index.
- for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
- existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
- memIndex = existingIndexIt.index() + 1;
+ ByteCodeLiveRange &existingRange = existingIndexIt.value();
+ if (!defRange.overlaps(existingRange)) {
+ existingRange.unionWith(defRange);
+ memIndex = existingIndexIt.index() + 1;
+
+ if (defRange.typeRangeIndex) {
+ if (!existingRange.typeRangeIndex)
+ existingRange.typeRangeIndex = numTypeRanges++;
+ valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
+ } else if (defRange.valueRangeIndex) {
+ if (!existingRange.valueRangeIndex)
+ existingRange.valueRangeIndex = numValueRanges++;
+ valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
+ }
+ break;
+ }
}
// If no existing index could be used, add a new one.
if (memIndex == 0) {
allocatedIndices.emplace_back(allocator);
- for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
- allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
+ ByteCodeLiveRange &newRange = allocatedIndices.back();
+ newRange.unionWith(defRange);
+
+ // Allocate an index for type/value ranges.
+ if (defRange.typeRangeIndex) {
+ newRange.typeRangeIndex = numTypeRanges;
+ valueToRangeIndex[defIt.first] = numTypeRanges++;
+ } else if (defRange.valueRangeIndex) {
+ newRange.valueRangeIndex = numValueRanges;
+ valueToRangeIndex[defIt.first] = numValueRanges++;
+ }
+
memIndex = allocatedIndices.size();
+ ++numIndices;
}
}
// Update the max number of indices.
- ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
- if (numMatcherIndices > maxValueMemoryIndex)
- maxValueMemoryIndex = numMatcherIndices;
+ if (numIndices > maxValueMemoryIndex)
+ maxValueMemoryIndex = numIndices;
+ if (numTypeRanges > maxTypeRangeMemoryIndex)
+ maxTypeRangeMemoryIndex = numTypeRanges;
+ if (numValueRanges > maxValueRangeMemoryIndex)
+ maxValueRangeMemoryIndex = numValueRanges;
}
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
@@ -481,17 +614,19 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::AreEqualOp, pdl_interp::BranchOp,
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
- pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
- pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+ pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
+ pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
+ pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
pdl_interp::EraseOp, pdl_interp::FinalizeOp,
pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
- pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
+ pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
+ pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
- pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
- pdl_interp::SwitchResultCountOp>(
+ pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
+ pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
[&](auto interpOp) { this->generate(interpOp, writer); })
.Default([](Operation *) {
llvm_unreachable("unknown `pdl_interp` operation");
@@ -515,16 +650,31 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
op.constParamsAttr());
writer.appendPDLValueList(op.args());
+ ResultRange results = op.results();
+ writer.append(ByteCodeField(results.size()));
+ for (Value result : results) {
+ // In debug mode we also record the expected kind of the result, so that we
+ // can provide extra verification of the native rewrite function.
#ifndef NDEBUG
- // In debug mode we also append the number of results so that we can assert
- // that the native creation function gave us the correct number of results.
- writer.append(ByteCodeField(op.results().size()));
+ writer.appendPDLValueKind(result);
#endif
- for (Value result : op.results())
+
+ // Range results also need to append the range storage index.
+ if (result.getType().isa<pdl::RangeType>())
+ writer.append(getRangeStorageIndex(result));
writer.append(result);
+ }
}
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
- writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
+ Value lhs = op.lhs();
+ if (lhs.getType().isa<pdl::RangeType>()) {
+ writer.append(OpCode::AreRangesEqual);
+ writer.appendPDLValueKind(lhs);
+ writer.append(op.lhs(), op.rhs(), op.getSuccessors());
+ return;
+ }
+
+ writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
}
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
@@ -537,6 +687,7 @@ void Generator::generate(pdl_interp::CheckAttributeOp op,
void Generator::generate(pdl_interp::CheckOperandCountOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
+ static_cast<ByteCodeField>(op.compareAtLeast()),
op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckOperationNameOp op,
@@ -547,11 +698,15 @@ void Generator::generate(pdl_interp::CheckOperationNameOp op,
void Generator::generate(pdl_interp::CheckResultCountOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
+ static_cast<ByteCodeField>(op.compareAtLeast()),
op.getSuccessors());
}
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
}
+void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
+}
void Generator::generate(pdl_interp::CreateAttributeOp op,
ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
@@ -560,7 +715,8 @@ void Generator::generate(pdl_interp::CreateAttributeOp op,
void Generator::generate(pdl_interp::CreateOperationOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CreateOperation, op.operation(),
- OperationName(op.name(), ctx), op.operands());
+ OperationName(op.name(), ctx));
+ writer.appendPDLValueList(op.operands());
// Add the attributes.
OperandRange attributes = op.attributes();
@@ -570,12 +726,16 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
std::get<1>(it));
}
- writer.append(op.types());
+ writer.appendPDLValueList(op.types());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
getMemIndex(op.result()) = getMemIndex(op.value());
}
+void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::CreateTypes, op.result(),
+ getRangeStorageIndex(op.result()), op.value());
+}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
writer.append(OpCode::EraseOp, op.operation());
}
@@ -593,7 +753,8 @@ void Generator::generate(pdl_interp::GetAttributeTypeOp op,
}
void Generator::generate(pdl_interp::GetDefiningOpOp op,
ByteCodeWriter &writer) {
- writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
+ writer.append(OpCode::GetDefiningOp, op.operation());
+ writer.appendPDLValue(op.value());
}
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
uint32_t index = op.index();
@@ -603,6 +764,18 @@ void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetOperandN, index);
writer.append(op.operation(), op.value());
}
+void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
+ Value result = op.value();
+ Optional<uint32_t> index = op.index();
+ writer.append(OpCode::GetOperands,
+ index.getValueOr(std::numeric_limits<uint32_t>::max()),
+ op.operation());
+ if (result.getType().isa<pdl::RangeType>())
+ writer.append(getRangeStorageIndex(result));
+ else
+ writer.append(std::numeric_limits<ByteCodeField>::max());
+ writer.append(result);
+}
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
uint32_t index = op.index();
if (index < 4)
@@ -611,10 +784,29 @@ void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetResultN, index);
writer.append(op.operation(), op.value());
}
+void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
+ Value result = op.value();
+ Optional<uint32_t> index = op.index();
+ writer.append(OpCode::GetResults,
+ index.getValueOr(std::numeric_limits<uint32_t>::max()),
+ op.operation());
+ if (result.getType().isa<pdl::RangeType>())
+ writer.append(getRangeStorageIndex(result));
+ else
+ writer.append(std::numeric_limits<ByteCodeField>::max());
+ writer.append(result);
+}
void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
- writer.append(OpCode::GetValueType, op.result(), op.value());
+ if (op.getType().isa<pdl::RangeType>()) {
+ Value result = op.result();
+ writer.append(OpCode::GetValueRangeTypes, result,
+ getRangeStorageIndex(result), op.value());
+ } else {
+ writer.append(OpCode::GetValueType, op.result(), op.value());
+ }
}
+
void Generator::generate(pdl_interp::InferredTypesOp op,
ByteCodeWriter &writer) {
// InferType maps to a null type as a marker for inferring result types.
@@ -628,11 +820,12 @@ void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
patterns.emplace_back(PDLByteCodePattern::create(
op, rewriterToAddr[op.rewriter().getLeafReference()]));
writer.append(OpCode::RecordMatch, patternIndex,
- SuccessorRange(op.getOperation()), op.matchedOps(),
- op.inputs());
+ SuccessorRange(op.getOperation()), op.matchedOps());
+ writer.appendPDLValueList(op.inputs());
}
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
- writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
+ writer.append(OpCode::ReplaceOp, op.operation());
+ writer.appendPDLValueList(op.replValues());
}
void Generator::generate(pdl_interp::SwitchAttributeOp op,
ByteCodeWriter &writer) {
@@ -661,6 +854,10 @@ void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
op.getSuccessors());
}
+void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
+ op.getSuccessors());
+}
//===----------------------------------------------------------------------===//
// PDLByteCode
@@ -671,7 +868,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
llvm::StringMap<PDLRewriteFunction> rewriteFns) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
- constraintFns, rewriteFns);
+ maxTypeRangeCount, maxValueRangeCount, constraintFns,
+ rewriteFns);
generator.generate(module);
// Initialize the external functions.
@@ -685,6 +883,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
/// bytecode.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
state.memory.resize(maxValueMemoryIndex, nullptr);
+ state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
+ state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
state.currentPatternBenefits.reserve(patterns.size());
for (const PDLByteCodePattern &pattern : patterns)
state.currentPatternBenefits.push_back(pattern.getBenefit());
@@ -697,17 +897,24 @@ namespace {
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
- ByteCodeExecutor(const ByteCodeField *curCodeIt,
- MutableArrayRef<const void *> memory,
- ArrayRef<const void *> uniquedMemory,
- ArrayRef<ByteCodeField> code,
- ArrayRef<PatternBenefit> currentPatternBenefits,
- ArrayRef<PDLByteCodePattern> patterns,
- ArrayRef<PDLConstraintFunction> constraintFunctions,
- ArrayRef<PDLRewriteFunction> rewriteFunctions)
- : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
- code(code), currentPatternBenefits(currentPatternBenefits),
- patterns(patterns), constraintFunctions(constraintFunctions),
+ ByteCodeExecutor(
+ const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
+ MutableArrayRef<TypeRange> typeRangeMemory,
+ std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
+ MutableArrayRef<ValueRange> valueRangeMemory,
+ std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
+ ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
+ ArrayRef<PatternBenefit> currentPatternBenefits,
+ ArrayRef<PDLByteCodePattern> patterns,
+ ArrayRef<PDLConstraintFunction> constraintFunctions,
+ ArrayRef<PDLRewriteFunction> rewriteFunctions)
+ : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
+ allocatedTypeRangeMemory(allocatedTypeRangeMemory),
+ valueRangeMemory(valueRangeMemory),
+ allocatedValueRangeMemory(allocatedValueRangeMemory),
+ uniquedMemory(uniquedMemory), code(code),
+ currentPatternBenefits(currentPatternBenefits), patterns(patterns),
+ constraintFunctions(constraintFunctions),
rewriteFunctions(rewriteFunctions) {}
/// Start executing the code at the current bytecode index. `matches` is an
@@ -722,19 +929,25 @@ class ByteCodeExecutor {
void executeApplyConstraint(PatternRewriter &rewriter);
void executeApplyRewrite(PatternRewriter &rewriter);
void executeAreEqual();
+ void executeAreRangesEqual();
void executeBranch();
void executeCheckOperandCount();
void executeCheckOperationName();
void executeCheckResultCount();
+ void executeCheckTypes();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
+ void executeCreateTypes();
void executeEraseOp(PatternRewriter &rewriter);
void executeGetAttribute();
void executeGetAttributeType();
void executeGetDefiningOp();
void executeGetOperand(unsigned index);
+ void executeGetOperands();
void executeGetResult(unsigned index);
+ void executeGetResults();
void executeGetValueType();
+ void executeGetValueRangeTypes();
void executeIsNotNull();
void executeRecordMatch(PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> &matches);
@@ -744,6 +957,7 @@ class ByteCodeExecutor {
void executeSwitchOperationName();
void executeSwitchResultCount();
void executeSwitchType();
+ void executeSwitchTypes();
/// Read a value from the bytecode buffer, optionally skipping a certain
/// number of prefix values. These methods always update the buffer to point
@@ -763,6 +977,19 @@ class ByteCodeExecutor {
list.push_back(read<ValueT>());
}
+ /// Read a list of values from the bytecode buffer. The values may be encoded
+ /// as either Value or ValueRange elements.
+ void readValueList(SmallVectorImpl<Value> &list) {
+ for (unsigned i = 0, e = read(); i != e; ++i) {
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+ list.push_back(read<Value>());
+ } else {
+ ValueRange *values = read<ValueRange *>();
+ list.append(values->begin(), values->end());
+ }
+ }
+ }
+
/// Jump to a specific successor based on a predicate value.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index.
@@ -771,8 +998,8 @@ class ByteCodeExecutor {
}
/// Handle a switch operation with the provided value and cases.
- template <typename T, typename RangeT>
- void handleSwitch(const T &value, RangeT &&cases) {
+ template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
+ void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
LLVM_DEBUG({
llvm::dbgs() << " * Value: " << value << "\n"
<< " * Cases: ";
@@ -783,7 +1010,7 @@ class ByteCodeExecutor {
// Check to see if the attribute value is within the case list. Jump to
// the correct successor index based on the result.
for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
- if (*it == value)
+ if (cmp(*it, value))
return selectJump(size_t((it - cases.begin()) + 1));
selectJump(size_t(0));
}
@@ -795,7 +1022,9 @@ class ByteCodeExecutor {
size_t index = *curCodeIt++;
// If this type is an SSA value, it can only be stored in non-const memory.
- if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
+ if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
+ Value>::value ||
+ index < memory.size())
return memory[index];
// Otherwise, if this index is not inbounds it is uniqued.
@@ -813,17 +1042,21 @@ class ByteCodeExecutor {
}
template <typename T>
std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
- switch (static_cast<PDLValueKind>(read())) {
- case PDLValueKind::Attribute:
+ switch (read<PDLValue::Kind>()) {
+ case PDLValue::Kind::Attribute:
return read<Attribute>();
- case PDLValueKind::Operation:
+ case PDLValue::Kind::Operation:
return read<Operation *>();
- case PDLValueKind::Type:
+ case PDLValue::Kind::Type:
return read<Type>();
- case PDLValueKind::Value:
+ case PDLValue::Kind::Value:
return read<Value>();
+ case PDLValue::Kind::TypeRange:
+ return read<TypeRange *>();
+ case PDLValue::Kind::ValueRange:
+ return read<ValueRange *>();
}
- llvm_unreachable("unhandled PDLValueKind");
+ llvm_unreachable("unhandled PDLValue::Kind");
}
template <typename T>
std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
@@ -838,12 +1071,20 @@ class ByteCodeExecutor {
std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
return *curCodeIt++;
}
+ template <typename T>
+ std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
+ return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
+ }
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
/// The current execution memory.
MutableArrayRef<const void *> memory;
+ MutableArrayRef<TypeRange> typeRangeMemory;
+ std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
+ MutableArrayRef<ValueRange> valueRangeMemory;
+ std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
/// References to ByteCode data necessary for execution.
ArrayRef<const void *> uniquedMemory;
@@ -859,8 +1100,21 @@ class ByteCodeExecutor {
/// overexposing access to information specific solely to the ByteCode.
class ByteCodeRewriteResultList : public PDLResultList {
public:
+ ByteCodeRewriteResultList(unsigned maxNumResults)
+ : PDLResultList(maxNumResults) {}
+
/// Return the list of PDL results.
MutableArrayRef<PDLValue> getResults() { return results; }
+
+ /// Return the type ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+ return allocatedTypeRanges;
+ }
+
+ /// Return the value ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+ return allocatedValueRanges;
+ }
};
} // end anonymous namespace
@@ -893,21 +1147,46 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
llvm::interleaveComma(args, llvm::dbgs());
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
});
- ByteCodeRewriteResultList results;
+
+ // Execute the rewrite function.
+ ByteCodeField numResults = read();
+ ByteCodeRewriteResultList results(numResults);
rewriteFn(args, constParams, rewriter, results);
- // Store the results in the bytecode memory.
-#ifndef NDEBUG
- ByteCodeField expectedNumberOfResults = read();
- assert(results.getResults().size() == expectedNumberOfResults &&
+ assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
-#endif
// Store the results in the bytecode memory.
for (PDLValue &result : results.getResults()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
- memory[read()] = result.getAsOpaquePointer();
+
+// In debug mode we also verify the expected kind of the result.
+#ifndef NDEBUG
+ assert(result.getKind() == read<PDLValue::Kind>() &&
+ "native PDL rewrite function returned an unexpected type of result");
+#endif
+
+ // If the result is a range, we need to copy it over to the bytecode's
+ // range memory.
+ if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
+ unsigned rangeIndex = read();
+ typeRangeMemory[rangeIndex] = *typeRange;
+ memory[read()] = &typeRangeMemory[rangeIndex];
+ } else if (Optional<ValueRange> valueRange =
+ result.dyn_cast<ValueRange>()) {
+ unsigned rangeIndex = read();
+ valueRangeMemory[rangeIndex] = *valueRange;
+ memory[read()] = &valueRangeMemory[rangeIndex];
+ } else {
+ memory[read()] = result.getAsOpaquePointer();
+ }
}
+
+ // Copy over any underlying storage allocated for result ranges.
+ for (auto &it : results.getAllocatedTypeRanges())
+ allocatedTypeRangeMemory.push_back(std::move(it));
+ for (auto &it : results.getAllocatedValueRanges())
+ allocatedValueRangeMemory.push_back(std::move(it));
}
void ByteCodeExecutor::executeAreEqual() {
@@ -919,6 +1198,32 @@ void ByteCodeExecutor::executeAreEqual() {
selectJump(lhs == rhs);
}
+void ByteCodeExecutor::executeAreRangesEqual() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
+ PDLValue::Kind valueKind = read<PDLValue::Kind>();
+ const void *lhs = read<const void *>();
+ const void *rhs = read<const void *>();
+
+ switch (valueKind) {
+ case PDLValue::Kind::TypeRange: {
+ const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
+ const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
+ LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ selectJump(*lhsRange == *rhsRange);
+ break;
+ }
+ case PDLValue::Kind::ValueRange: {
+ const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
+ const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
+ LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+ selectJump(*lhsRange == *rhsRange);
+ break;
+ }
+ default:
+ llvm_unreachable("unexpected `AreRangesEqual` value kind");
+ }
+}
+
void ByteCodeExecutor::executeBranch() {
LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
curCodeIt = &code[read<ByteCodeAddr>()];
@@ -928,10 +1233,16 @@ void ByteCodeExecutor::executeCheckOperandCount() {
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
+ bool compareAtLeast = read();
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
- << " * Expected: " << expectedCount << "\n");
- selectJump(op->getNumOperands() == expectedCount);
+ << " * Expected: " << expectedCount << "\n"
+ << " * Comparator: "
+ << (compareAtLeast ? ">=" : "==") << "\n");
+ if (compareAtLeast)
+ selectJump(op->getNumOperands() >= expectedCount);
+ else
+ selectJump(op->getNumOperands() == expectedCount);
}
void ByteCodeExecutor::executeCheckOperationName() {
@@ -948,10 +1259,44 @@ void ByteCodeExecutor::executeCheckResultCount() {
LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
Operation *op = read<Operation *>();
uint32_t expectedCount = read<uint32_t>();
+ bool compareAtLeast = read();
LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
- << " * Expected: " << expectedCount << "\n");
- selectJump(op->getNumResults() == expectedCount);
+ << " * Expected: " << expectedCount << "\n"
+ << " * Comparator: "
+ << (compareAtLeast ? ">=" : "==") << "\n");
+ if (compareAtLeast)
+ selectJump(op->getNumResults() >= expectedCount);
+ else
+ selectJump(op->getNumResults() == expectedCount);
+}
+
+void ByteCodeExecutor::executeCheckTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+ TypeRange *lhs = read<TypeRange *>();
+ Attribute rhs = read<Attribute>();
+ LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+
+ selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
+}
+
+void ByteCodeExecutor::executeCreateTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
+
+ LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
+
+ // Allocate a buffer for this type range.
+ llvm::OwningArrayRef<Type> storage(typesAttr.size());
+ llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
+ allocatedTypeRangeMemory.emplace_back(std::move(storage));
+
+ // Assign this to the range slot and use the range as the value for the
+ // memory index.
+ typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
+ memory[memIndex] = &typeRangeMemory[rangeIndex];
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@@ -960,22 +1305,26 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
- readList<Value>(state.operands);
+ readValueList(state.operands);
for (unsigned i = 0, e = read(); i != e; ++i) {
Identifier name = read<Identifier>();
if (Attribute attr = read<Attribute>())
state.addAttribute(name, attr);
}
- bool hasInferredTypes = false;
for (unsigned i = 0, e = read(); i != e; ++i) {
- Type resultType = read<Type>();
- hasInferredTypes |= !resultType;
- state.types.push_back(resultType);
- }
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+ state.types.push_back(read<Type>());
+ continue;
+ }
+
+ // If we find a null range, this signals that the types are inferred.
+ if (TypeRange *resultTypes = read<TypeRange *>()) {
+ state.types.append(resultTypes->begin(), resultTypes->end());
+ continue;
+ }
- // Handle the case where the operation has inferred types.
- if (hasInferredTypes) {
+ // Handle the case where the operation has inferred types.
InferTypeOpInterface::Concept *concept =
state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
@@ -986,7 +1335,9 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
state.attributes.getDictionary(state.getContext()), state.regions,
state.types)))
return;
+ break;
}
+
Operation *resultOp = rewriter.createOperation(state);
memory[memIndex] = resultOp;
@@ -1036,11 +1387,21 @@ void ByteCodeExecutor::executeGetAttributeType() {
void ByteCodeExecutor::executeGetDefiningOp() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
unsigned memIndex = read();
- Value value = read<Value>();
- Operation *op = value ? value.getDefiningOp() : nullptr;
+ Operation *op = nullptr;
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+ Value value = read<Value>();
+ if (value)
+ op = value.getDefiningOp();
+ LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
+ } else {
+ ValueRange *values = read<ValueRange *>();
+ if (values && !values->empty()) {
+ op = values->front().getDefiningOp();
+ }
+ LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
+ }
- LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
- << " * Result: " << *op << "\n");
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
memory[memIndex] = op;
}
@@ -1056,6 +1417,75 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) {
memory[memIndex] = operand.getAsOpaquePointer();
}
+/// This function is the internal implementation of `GetResults` and
+/// `GetOperands` that provides support for extracting a value range from the
+/// given operation.
+template <template <typename> class AttrSizedSegmentsT, typename RangeT>
+static void *
+executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
+ ByteCodeField rangeIndex, StringRef attrSizedSegments,
+ MutableArrayRef<ValueRange> &valueRangeMemory) {
+ // Check for the sentinel index that signals that all values should be
+ // returned.
+ if (index == std::numeric_limits<uint32_t>::max()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
+ // `values` is already the full value range.
+
+ // Otherwise, check to see if this operation uses AttrSizedSegments.
+ } else if (op->hasTrait<AttrSizedSegmentsT>()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " * Extracting values from `" << attrSizedSegments << "`\n");
+
+ auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
+ if (!segmentAttr || segmentAttr.getNumElements() <= index)
+ return nullptr;
+
+ auto segments = segmentAttr.getValues<int32_t>();
+ unsigned startIndex =
+ std::accumulate(segments.begin(), segments.begin() + index, 0);
+ values = values.slice(startIndex, *std::next(segments.begin(), index));
+
+ LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
+ << *std::next(segments.begin(), index) << "]\n");
+
+ // Otherwise, assume this is the last operand group of the operation.
+ // FIXME: We currently don't support operations with
+ // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
+ // have a way to detect its presence.
+ } else if (values.size() >= index) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " * Treating values as trailing variadic range\n");
+ values = values.drop_front(index);
+
+ // If we couldn't detect a way to compute the values, bail out.
+ } else {
+ return nullptr;
+ }
+
+ // If the range index is valid, we are returning a range.
+ if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
+ valueRangeMemory[rangeIndex] = values;
+ return &valueRangeMemory[rangeIndex];
+ }
+
+ // If a range index wasn't provided, the range is required to be non-variadic.
+ return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
+}
+
+void ByteCodeExecutor::executeGetOperands() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
+ unsigned index = read<uint32_t>();
+ Operation *op = read<Operation *>();
+ ByteCodeField rangeIndex = read();
+
+ void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
+ op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
+ valueRangeMemory);
+ if (!result)
+ LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
+ memory[read()] = result;
+}
+
void ByteCodeExecutor::executeGetResult(unsigned index) {
Operation *op = read<Operation *>();
unsigned memIndex = read();
@@ -1068,6 +1498,20 @@ void ByteCodeExecutor::executeGetResult(unsigned index) {
memory[memIndex] = result.getAsOpaquePointer();
}
+void ByteCodeExecutor::executeGetResults() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
+ unsigned index = read<uint32_t>();
+ Operation *op = read<Operation *>();
+ ByteCodeField rangeIndex = read();
+
+ void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
+ op->getResults(), op, index, rangeIndex, "result_segment_sizes",
+ valueRangeMemory);
+ if (!result)
+ LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
+ memory[read()] = result;
+}
+
void ByteCodeExecutor::executeGetValueType() {
LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
unsigned memIndex = read();
@@ -1079,6 +1523,28 @@ void ByteCodeExecutor::executeGetValueType() {
memory[memIndex] = type.getAsOpaquePointer();
}
+void ByteCodeExecutor::executeGetValueRangeTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ ValueRange *values = read<ValueRange *>();
+ if (!values) {
+ LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
+ memory[memIndex] = nullptr;
+ return;
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << " * Values (" << values->size() << "): ";
+ llvm::interleaveComma(*values, llvm::dbgs());
+ llvm::dbgs() << "\n * Result: ";
+ llvm::interleaveComma(values->getType(), llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ typeRangeMemory[rangeIndex] = values->getType();
+ memory[memIndex] = &typeRangeMemory[rangeIndex];
+}
+
void ByteCodeExecutor::executeIsNotNull() {
LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
const void *value = read<const void *>();
@@ -1117,7 +1583,30 @@ void ByteCodeExecutor::executeRecordMatch(
LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
<< " * Location: " << matchLoc << "\n");
matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
- readList<const void *>(matches.back().values);
+ PDLByteCode::MatchResult &match = matches.back();
+
+ // Record all of the inputs to the match. If any of the inputs are ranges, we
+ // will also need to remap the range pointer to memory stored in the match
+ // state.
+ unsigned numInputs = read();
+ match.values.reserve(numInputs);
+ match.typeRangeValues.reserve(numInputs);
+ match.valueRangeValues.reserve(numInputs);
+ for (unsigned i = 0; i < numInputs; ++i) {
+ switch (read<PDLValue::Kind>()) {
+ case PDLValue::Kind::TypeRange:
+ match.typeRangeValues.push_back(*read<TypeRange *>());
+ match.values.push_back(&match.typeRangeValues.back());
+ break;
+ case PDLValue::Kind::ValueRange:
+ match.valueRangeValues.push_back(*read<ValueRange *>());
+ match.values.push_back(&match.valueRangeValues.back());
+ break;
+ default:
+ match.values.push_back(read<const void *>());
+ break;
+ }
+ }
curCodeIt = dest;
}
@@ -1125,7 +1614,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
- readList<Value>(args);
+ readValueList(args);
LLVM_DEBUG({
llvm::dbgs() << " * Operation: " << *op << "\n"
@@ -1198,6 +1687,19 @@ void ByteCodeExecutor::executeSwitchType() {
handleSwitch(value, cases);
}
+void ByteCodeExecutor::executeSwitchTypes() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
+ TypeRange *value = read<TypeRange *>();
+ auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
+ if (!value) {
+ LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
+ return selectJump(size_t(0));
+ }
+ handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
+ return value == caseValue.getAsValueRange<TypeAttr>();
+ });
+}
+
void ByteCodeExecutor::execute(
PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> *matches,
@@ -1214,6 +1716,9 @@ void ByteCodeExecutor::execute(
case AreEqual:
executeAreEqual();
break;
+ case AreRangesEqual:
+ executeAreRangesEqual();
+ break;
case Branch:
executeBranch();
break;
@@ -1226,9 +1731,15 @@ void ByteCodeExecutor::execute(
case CheckResultCount:
executeCheckResultCount();
break;
+ case CheckTypes:
+ executeCheckTypes();
+ break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
+ case CreateTypes:
+ executeCreateTypes();
+ break;
case EraseOp:
executeEraseOp(rewriter);
break;
@@ -1257,6 +1768,9 @@ void ByteCodeExecutor::execute(
LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
executeGetOperand(read<uint32_t>());
break;
+ case GetOperands:
+ executeGetOperands();
+ break;
case GetResult0:
case GetResult1:
case GetResult2:
@@ -1270,9 +1784,15 @@ void ByteCodeExecutor::execute(
LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
executeGetResult(read<uint32_t>());
break;
+ case GetResults:
+ executeGetResults();
+ break;
case GetValueType:
executeGetValueType();
break;
+ case GetValueRangeTypes:
+ executeGetValueRangeTypes();
+ break;
case IsNotNull:
executeIsNotNull();
break;
@@ -1299,6 +1819,9 @@ void ByteCodeExecutor::execute(
case SwitchType:
executeSwitchType();
break;
+ case SwitchTypes:
+ executeSwitchTypes();
+ break;
}
LLVM_DEBUG(llvm::dbgs() << "\n");
}
@@ -1313,9 +1836,12 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
state.memory[0] = op;
// The matcher function always starts at code address 0.
- ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
- matcherByteCode, state.currentPatternBenefits,
- patterns, constraintFunctions, rewriteFunctions);
+ ByteCodeExecutor executor(
+ matcherByteCode.data(), state.memory, state.typeRangeMemory,
+ state.allocatedTypeRangeMemory, state.valueRangeMemory,
+ state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
+ state.currentPatternBenefits, patterns, constraintFunctions,
+ rewriteFunctions);
executor.execute(rewriter, &matches);
// Order the found matches by benefit.
@@ -1332,9 +1858,11 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
// memory buffer.
llvm::copy(match.values, state.memory.begin());
- ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
- state.memory, uniquedData, rewriterByteCode,
- state.currentPatternBenefits, patterns,
- constraintFunctions, rewriteFunctions);
+ ByteCodeExecutor executor(
+ &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
+ state.typeRangeMemory, state.allocatedTypeRangeMemory,
+ state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
+ rewriterByteCode, state.currentPatternBenefits, patterns,
+ constraintFunctions, rewriteFunctions);
executor.execute(rewriter, /*matches=*/nullptr, match.location);
}
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index f6a3bcbe54f9..c6f41be768de 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -25,8 +25,7 @@ namespace detail {
class PDLByteCode;
/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
-/// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of
-/// indices into the bytecode. Correctness is checked with static asserts.
+/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
@@ -62,14 +61,16 @@ class PDLByteCodePattern : public Pattern {
/// threads/drivers.
class PDLByteCodeMutableState {
public:
- /// Initialize the state from a bytecode instance.
- void initialize(PDLByteCode &bytecode);
-
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
+ /// Cleanup any allocated state after a match/rewrite has been completed. This
+ /// method should be called regardless of whether the match+rewrite was a
+ /// success or not.
+ void cleanupAfterMatchAndRewrite();
+
private:
/// Allow access to data fields.
friend class PDLByteCode;
@@ -78,6 +79,20 @@ class PDLByteCodeMutableState {
/// of the bytecode.
std::vector<const void *> memory;
+ /// A mutable block of memory used during the matching and rewriting phase of
+ /// the bytecode to store ranges of types.
+ std::vector<TypeRange> typeRangeMemory;
+ /// A set of type ranges that have been allocated by the byte code interpreter
+ /// to provide a guaranteed lifetime.
+ std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
+
+ /// A mutable block of memory used during the matching and rewriting phase of
+ /// the bytecode to store ranges of values.
+ std::vector<ValueRange> valueRangeMemory;
+ /// A set of value ranges that have been allocated by the byte code
+ /// interpreter to provide a guaranteed lifetime.
+ std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
+
/// The up-to-date benefits of the patterns held by the bytecode. The order
/// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector<PatternBenefit> currentPatternBenefits;
@@ -98,11 +113,19 @@ class PDLByteCode {
MatchResult(Location loc, const PDLByteCodePattern &pattern,
PatternBenefit benefit)
: location(loc), pattern(&pattern), benefit(benefit) {}
+ MatchResult(const MatchResult &) = delete;
+ MatchResult &operator=(const MatchResult &) = delete;
+ MatchResult(MatchResult &&other) = default;
+ MatchResult &operator=(MatchResult &&) = default;
/// The location of operations to be replaced.
Location location;
/// Memory values defined in the matcher that are passed to the rewriter.
- SmallVector<const void *, 4> values;
+ SmallVector<const void *> values;
+ /// Memory used for the range input values.
+ SmallVector<TypeRange, 0> typeRangeValues;
+ SmallVector<ValueRange, 0> valueRangeValues;
+
/// The originating pattern that was matched. This is always non-null, but
/// represented with a pointer to allow for assignment.
const PDLByteCodePattern *pattern;
@@ -163,6 +186,10 @@ class PDLByteCode {
/// The maximum memory index used by a value.
ByteCodeField maxValueMemoryIndex = 0;
+
+ /// The maximum number of different types of ranges.
+ ByteCodeField maxTypeRangeCount = 0;
+ ByteCodeField maxValueRangeCount = 0;
};
} // end namespace detail
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 6f5e1f299f26..5032f0203257 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -129,29 +129,40 @@ LogicalResult PatternApplicator::matchAndRewrite(
// Process the patterns that match the specific operation type, and any
// operation type in an interleaved fashion.
- auto opIt = opPatterns.begin(), opE = opPatterns.end();
- auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
- auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
- while (true) {
+ unsigned opIt = 0, opE = opPatterns.size();
+ unsigned anyIt = 0, anyE = anyOpPatterns.size();
+ unsigned pdlIt = 0, pdlE = pdlMatches.size();
+ LogicalResult result = failure();
+ do {
// Find the next pattern with the highest benefit.
const Pattern *bestPattern = nullptr;
+ unsigned *bestPatternIt = &opIt;
const PDLByteCode::MatchResult *pdlMatch = nullptr;
+
/// Operation specific patterns.
- if (opIt != opE)
- bestPattern = *(opIt++);
+ if (opIt < opE)
+ bestPattern = opPatterns[opIt];
/// Operation agnostic patterns.
- if (anyIt != anyE &&
- (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
- bestPattern = *(anyIt++);
+ if (anyIt < anyE &&
+ (!bestPattern ||
+ bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
+ bestPatternIt = &anyIt;
+ bestPattern = anyOpPatterns[anyIt];
+ }
/// PDL patterns.
- if (pdlIt != pdlE &&
- (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
- pdlMatch = pdlIt;
- bestPattern = (pdlIt++)->pattern;
+ if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
+ pdlMatches[pdlIt].benefit)) {
+ bestPatternIt = &pdlIt;
+ pdlMatch = &pdlMatches[pdlIt];
+ bestPattern = pdlMatch->pattern;
}
if (!bestPattern)
break;
+ // Update the pattern iterator on failure so that this pattern isn't
+ // attempted again.
+ ++(*bestPatternIt);
+
// Check that the pattern can be applied.
if (canApply && !canApply(*bestPattern))
continue;
@@ -160,19 +171,25 @@ LogicalResult PatternApplicator::matchAndRewrite(
// benefit, so if we match we can immediately rewrite. For PDL patterns, the
// match has already been performed, we just need to rewrite.
rewriter.setInsertionPoint(op);
- LogicalResult result = success();
if (pdlMatch) {
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+ result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
+
} else {
- result = static_cast<const RewritePattern *>(bestPattern)
- ->matchAndRewrite(op, rewriter);
+ const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+ result = pattern->matchAndRewrite(op, rewriter);
+ if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
+ result = failure();
}
- if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
- return success();
+ if (succeeded(result))
+ break;
// Perform any necessary cleanups.
if (onFailure)
onFailure(*bestPattern);
- }
- return failure();
+ } while (true);
+
+ if (mutableByteCodeState)
+ mutableByteCodeState->cleanupAfterMatchAndRewrite();
+ return result;
}
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index b0acd328147a..d630fa2aa14d 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -40,6 +40,38 @@ module @ir attributes { test.apply_constraint_1 } {
// -----
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ pdl_interp.apply_constraint "multi_entity_var_constraint"(%results, %types : !pdl.range<value>, !pdl.range<type>) -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.replaced_by_pattern"
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_constraint_2
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"
+module @ir attributes { test.apply_constraint_2 } {
+ "test.failure_op"() { test_attr } : () -> ()
+ "test.success_op"() : () -> (i32, i64)
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::ApplyRewriteOp
//===----------------------------------------------------------------------===//
@@ -103,6 +135,68 @@ module @ir attributes { test.apply_rewrite_2 } {
// -----
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %operands, %types = pdl_interp.apply_rewrite "var_creator"(%root : !pdl.operation) : !pdl.range<value>, !pdl.range<type>
+ %op = pdl_interp.create_operation "test.success"(%operands : !pdl.range<value>) -> (%types : !pdl.range<type>)
+ pdl_interp.replace %root with (%operands : !pdl.range<value>)
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_rewrite_3
+// CHECK: %[[OPERAND:.*]] = "test.producer"
+// CHECK: "test.success"(%[[OPERAND]]) : (i32) -> i32
+// CHECK: "test.consumer"(%[[OPERAND]])
+module @ir attributes { test.apply_rewrite_3 } {
+ %first_operand = "test.producer"() : () -> (i32)
+ %operand = "test.op"(%first_operand) : (i32) -> (i32)
+ "test.consumer"(%operand) : (i32) -> ()
+}
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %type = pdl_interp.apply_rewrite "type_creator" : !pdl.type
+ %newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_rewrite_4
+// CHECK: "test.success"() : () -> f32
+module @ir attributes { test.apply_rewrite_4 } {
+ "test.op"() : () -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::AreEqualOp
//===----------------------------------------------------------------------===//
@@ -137,6 +231,40 @@ module @ir attributes { test.are_equal_1 } {
// -----
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %const_types = pdl_interp.create_types [i32, i64]
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %result_types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ pdl_interp.are_equal %result_types, %const_types : !pdl.range<type> -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.are_equal_2
+// CHECK: "test.not_equal"
+// CHECK: "test.success"
+// CHECK-NOT: "test.op"
+module @ir attributes { test.are_equal_2 } {
+ "test.not_equal"() : () -> (i32)
+ "test.op"() : () -> (i32, i64)
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::BranchOp
//===----------------------------------------------------------------------===//
@@ -211,7 +339,10 @@ module @ir attributes { test.check_attribute_1 } {
module @patterns {
func @matcher(%root : !pdl.operation) {
- pdl_interp.check_operand_count of %root is 1 -> ^pat, ^end
+ pdl_interp.check_operand_count of %root is at_least 1 -> ^exact_check, ^end
+
+ ^exact_check:
+ pdl_interp.check_operand_count of %root is 2 -> ^pat, ^end
^pat:
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
@@ -234,7 +365,7 @@ module @patterns {
// CHECK: "test.success"
module @ir attributes { test.check_operand_count_1 } {
%operand = "test.op"() : () -> i32
- "test.op"(%operand) : (i32) -> ()
+ "test.op"(%operand, %operand) : (i32, i32) -> ()
}
// -----
@@ -277,7 +408,10 @@ module @ir attributes { test.check_operation_name_1 } {
module @patterns {
func @matcher(%root : !pdl.operation) {
- pdl_interp.check_result_count of %root is 1 -> ^pat, ^end
+ pdl_interp.check_result_count of %root is at_least 1 -> ^exact_check, ^end
+
+ ^exact_check:
+ pdl_interp.check_result_count of %root is 2 -> ^pat, ^end
^pat:
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
@@ -296,9 +430,12 @@ module @patterns {
}
// CHECK-LABEL: test.check_result_count_1
+// CHECK: "test.op"() : () -> i32
// CHECK: "test.success"() : () -> ()
+// CHECK-NOT: "test.op"() : () -> (i32, i32)
module @ir attributes { test.check_result_count_1 } {
"test.op"() : () -> i32
+ "test.op"() : () -> (i32, i32)
}
// -----
@@ -340,6 +477,43 @@ module @ir attributes { test.check_type_1 } {
// -----
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypesOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %result_types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ pdl_interp.check_types %result_types are [i32] -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.check_types_1
+// CHECK: "test.op"() : () -> (i32, i64)
+// CHECK: "test.success"
+// CHECK-NOT: "test.op"() : () -> i32
+module @ir attributes { test.check_types_1 } {
+ "test.op"() : () -> (i32, i64)
+ "test.op"() : () -> i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
@@ -390,6 +564,12 @@ module @ir attributes { test.create_type_1 } {
// -----
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypesOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
//===----------------------------------------------------------------------===//
// pdl_interp::EraseOp
//===----------------------------------------------------------------------===//
@@ -465,6 +645,110 @@ module @ir attributes { test.get_defining_op_1 } {
// Fully tested within the tests for other operations.
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandsOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end
+
+ ^pat1:
+ %operands = pdl_interp.get_operands 0 of %root : !pdl.range<value>
+ %full_operands = pdl_interp.get_operands of %root : !pdl.range<value>
+ pdl_interp.are_equal %operands, %full_operands : !pdl.range<value> -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_operands_1
+// CHECK: "test.success"
+module @ir attributes { test.get_operands_1 } {
+ %inputs:2 = "test.producer"() : () -> (i32, i32)
+ "test.op"(%inputs#0, %inputs#1) : (i32, i32) -> ()
+}
+
+// -----
+
+// Test all of the various combinations related to `AttrSizedOperandSegments`.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.attr_sized_operands" -> ^pat1, ^end
+
+ ^pat1:
+ %operands_0 = pdl_interp.get_operands 0 of %root : !pdl.range<value>
+ pdl_interp.is_not_null %operands_0 : !pdl.range<value> -> ^pat2, ^end
+
+ ^pat2:
+ %operands_0_single = pdl_interp.get_operands 0 of %root : !pdl.value
+ pdl_interp.is_not_null %operands_0_single : !pdl.value -> ^end, ^pat3
+
+ ^pat3:
+ %operands_1 = pdl_interp.get_operands 1 of %root : !pdl.range<value>
+ pdl_interp.is_not_null %operands_1 : !pdl.range<value> -> ^pat4, ^end
+
+ ^pat4:
+ %operands_1_single = pdl_interp.get_operands 1 of %root : !pdl.value
+ pdl_interp.is_not_null %operands_1_single : !pdl.value -> ^end, ^pat5
+
+ ^pat5:
+ %operands_2 = pdl_interp.get_operands 2 of %root : !pdl.range<value>
+ pdl_interp.is_not_null %operands_2 : !pdl.range<value> -> ^pat6, ^end
+
+ ^pat6:
+ %operands_2_single = pdl_interp.get_operands 2 of %root : !pdl.value
+ pdl_interp.is_not_null %operands_2_single : !pdl.value -> ^pat7, ^end
+
+ ^pat7:
+ %invalid_operands = pdl_interp.get_operands 50 of %root : !pdl.value
+ pdl_interp.is_not_null %invalid_operands : !pdl.value -> ^end, ^pat8
+
+ ^pat8:
+ pdl_interp.record_match @rewriters::@success(%root, %operands_0, %operands_1, %operands_2, %operands_2_single : !pdl.operation, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.value) : benefit(1), loc([%root]) -> ^end
+
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root: !pdl.operation, %operands_0: !pdl.range<value>, %operands_1: !pdl.range<value>, %operands_2: !pdl.range<value>, %operands_2_single: !pdl.value) {
+ %op0 = pdl_interp.create_operation "test.success"(%operands_0 : !pdl.range<value>)
+ %op1 = pdl_interp.create_operation "test.success"(%operands_1 : !pdl.range<value>)
+ %op2 = pdl_interp.create_operation "test.success"(%operands_2 : !pdl.range<value>)
+ %op3 = pdl_interp.create_operation "test.success"(%operands_2_single : !pdl.value)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_operands_2
+// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32)
+// CHECK-NEXT: "test.success"() : () -> ()
+// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> ()
+// CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> ()
+// CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> ()
+module @ir attributes { test.get_operands_2 } {
+ %inputs:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32)
+ "test.attr_sized_operands"(%inputs#0, %inputs#1, %inputs#2, %inputs#3, %inputs#4) {operand_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : (i32, i32, i32, i32, i32) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetResultOp
//===----------------------------------------------------------------------===//
@@ -506,6 +790,119 @@ module @ir attributes { test.get_result_1 } {
// -----
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultsOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end
+
+ ^pat1:
+ %results = pdl_interp.get_results 0 of %root : !pdl.range<value>
+ %full_results = pdl_interp.get_results of %root : !pdl.range<value>
+ pdl_interp.are_equal %results, %full_results : !pdl.range<value> -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_results_1
+// CHECK: "test.success"
+module @ir attributes { test.get_results_1 } {
+ %a:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32)
+}
+
+// -----
+
+// Test all of the various combinations related to `AttrSizedResultSegments`.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.attr_sized_results" -> ^pat1, ^end
+
+ ^pat1:
+ %results_0 = pdl_interp.get_results 0 of %root : !pdl.range<value>
+ pdl_interp.is_not_null %results_0 : !pdl.range<value> -> ^pat2, ^end
+
+ ^pat2:
+ %results_0_single = pdl_interp.get_results 0 of %root : !pdl.value
+ pdl_interp.is_not_null %results_0_single : !pdl.value -> ^end, ^pat3
+
+ ^pat3:
+ %results_1 = pdl_interp.get_results 1 of %root : !pdl.range<value>
+ pdl_interp.is_not_null %results_1 : !pdl.range<value> -> ^pat4, ^end
+
+ ^pat4:
+ %results_1_single = pdl_interp.get_results 1 of %root : !pdl.value
+ pdl_interp.is_not_null %results_1_single : !pdl.value -> ^end, ^pat5
+
+ ^pat5:
+ %results_2 = pdl_interp.get_results 2 of %root : !pdl.range<value>
+ pdl_interp.is_not_null %results_2 : !pdl.range<value> -> ^pat6, ^end
+
+ ^pat6:
+ %results_2_single = pdl_interp.get_results 2 of %root : !pdl.value
+ pdl_interp.is_not_null %results_2_single : !pdl.value -> ^pat7, ^end
+
+ ^pat7:
+ %invalid_results = pdl_interp.get_results 50 of %root : !pdl.value
+ pdl_interp.is_not_null %invalid_results : !pdl.value -> ^end, ^pat8
+
+ ^pat8:
+ pdl_interp.record_match @rewriters::@success(%root, %results_0, %results_1, %results_2, %results_2_single : !pdl.operation, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.value) : benefit(1), loc([%root]) -> ^end
+
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root: !pdl.operation, %results_0: !pdl.range<value>, %results_1: !pdl.range<value>, %results_2: !pdl.range<value>, %results_2_single: !pdl.value) {
+ %results_0_types = pdl_interp.get_value_type of %results_0 : !pdl.range<type>
+ %results_1_types = pdl_interp.get_value_type of %results_1 : !pdl.range<type>
+ %results_2_types = pdl_interp.get_value_type of %results_2 : !pdl.range<type>
+ %results_2_single_types = pdl_interp.get_value_type of %results_2_single : !pdl.type
+
+ %op0 = pdl_interp.create_operation "test.success" -> (%results_0_types : !pdl.range<type>)
+ %op1 = pdl_interp.create_operation "test.success" -> (%results_1_types : !pdl.range<type>)
+ %op2 = pdl_interp.create_operation "test.success" -> (%results_2_types : !pdl.range<type>)
+ %op3 = pdl_interp.create_operation "test.success" -> (%results_2_single_types : !pdl.type)
+
+ %new_results_0 = pdl_interp.get_results of %op0 : !pdl.range<value>
+ %new_results_1 = pdl_interp.get_results of %op1 : !pdl.range<value>
+ %new_results_2 = pdl_interp.get_results of %op2 : !pdl.range<value>
+
+ pdl_interp.replace %root with (%new_results_0, %new_results_1, %new_results_2 : !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_results_2
+// CHECK: "test.success"() : () -> ()
+// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32)
+// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32
+// CHECK: %[[RESULTS_2_SINGLE:.*]] = "test.success"() : () -> i32
+// CHECK: "test.consumer"(%[[RESULTS_1]]#0, %[[RESULTS_1]]#1, %[[RESULTS_1]]#2, %[[RESULTS_1]]#3, %[[RESULTS_2]]) : (i32, i32, i32, i32, i32) -> ()
+module @ir attributes { test.get_results_2 } {
+ %results:5 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : () -> (i32, i32, i32, i32, i32)
+ "test.consumer"(%results#0, %results#1, %results#2, %results#3, %results#4) : (i32, i32, i32, i32, i32) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//
@@ -564,6 +961,43 @@ module @ir attributes { test.record_match_1 } {
// -----
+// Check that ranges are properly forwarded to the result.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end
+
+ ^pat1:
+ %operands = pdl_interp.get_operands of %root : !pdl.range<value>
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ pdl_interp.record_match @rewriters::@success(%operands, %types, %root : !pdl.range<value>, !pdl.range<type>, !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%operands: !pdl.range<value>, %types: !pdl.range<type>, %root: !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"(%operands : !pdl.range<value>) -> (%types : !pdl.range<type>)
+ %results = pdl_interp.get_results of %op : !pdl.range<value>
+ pdl_interp.replace %root with (%results : !pdl.range<value>)
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.record_match_2
+// CHECK: %[[OPERAND:.*]] = "test.producer"() : () -> i32
+// CHECK: %[[RESULTS:.*]]:2 = "test.success"(%[[OPERAND]]) : (i32) -> (i64, i32)
+// CHECK: "test.consumer"(%[[RESULTS]]#0, %[[RESULTS]]#1) : (i64, i32) -> ()
+module @ir attributes { test.record_match_2 } {
+ %input = "test.producer"() : () -> i32
+ %results:2 = "test.op"(%input) : (i32) -> (i64, i32)
+ "test.consumer"(%results#0, %results#1) : (i64, i32) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::ReplaceOp
//===----------------------------------------------------------------------===//
@@ -780,3 +1214,40 @@ module @patterns {
module @ir attributes { test.switch_type_1 } {
"test.op"() { test_attr = 10 : i32 } : () -> ()
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypesOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ pdl_interp.switch_types %types to [[i64, i64], [i32]](^pat2, ^end) -> ^end
+
+ ^pat2:
+ pdl_interp.switch_types %types to [[i32], [i64, i32]](^end, ^end) -> ^pat3
+
+ ^pat3:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.switch_types_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_types_1 } {
+ %results:2 = "test.op"() : () -> (i64, i64)
+}
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index e60022ba94cc..bc45d7b083aa 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -24,6 +24,18 @@ static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
PatternRewriter &rewriter) {
return customSingleEntityConstraint(values[1], constantParams, rewriter);
}
+static LogicalResult
+customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
+ ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
+ return failure();
+ ValueRange operandValues = values[0].cast<ValueRange>();
+ TypeRange typeValues = values[1].cast<TypeRange>();
+ if (operandValues.size() != 2 || typeValues.size() != 2)
+ return failure();
+ return success();
+}
// Custom creator invoked from PDL.
static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
@@ -31,6 +43,19 @@ static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
results.push_back(rewriter.createOperation(
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
}
+static void customVariadicResultCreate(ArrayRef<PDLValue> args,
+ ArrayAttr constantParams,
+ PatternRewriter &rewriter,
+ PDLResultList &results) {
+ Operation *root = args[0].cast<Operation *>();
+ results.push_back(root->getOperands());
+ results.push_back(root->getOperands().getTypes());
+}
+static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+ PatternRewriter &rewriter,
+ PDLResultList &results) {
+ results.push_back(rewriter.getF32Type());
+}
/// Custom rewriter invoked from PDL.
static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
@@ -63,7 +88,12 @@ struct TestPDLByteCodePass
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("single_entity_constraint",
customSingleEntityConstraint);
+ pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
+ customMultiEntityVariadicConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
+ pdlPattern.registerRewriteFunction("var_creator",
+ customVariadicResultCreate);
+ pdlPattern.registerRewriteFunction("type_creator", customCreateType);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
OwningRewritePatternList patternList(std::move(pdlPattern));
More information about the Mlir-commits
mailing list