[Mlir-commits] [mlir] 85ab413 - [mlir][PDL] Add support for variadic operands and results in the PDL byte code

River Riddle llvmlistbot at llvm.org
Tue Mar 16 13:20:37 PDT 2021


Author: River Riddle
Date: 2021-03-16T13:20:19-07:00
New Revision: 85ab413b53aeb135eb58dab066afcbf20bef0cf8

URL: https://github.com/llvm/llvm-project/commit/85ab413b53aeb135eb58dab066afcbf20bef0cf8
DIFF: https://github.com/llvm/llvm-project/commit/85ab413b53aeb135eb58dab066afcbf20bef0cf8.diff

LOG: [mlir][PDL] Add support for variadic operands and results in the PDL byte code

Supporting ranges in the byte code requires additional complexity, given that a range can't easily be represented as an opaque void *, as is possible with the existing bytecode value types (Attribute, Type, Value, etc.). To enable representing a range with void *, an auxiliary storage is used for the actual range itself, with the pointer being passed around in the normal byte code memory. For type ranges, a TypeRange is stored. For value ranges, a ValueRange is stored. The above problem represents the majority of the complexity involved in this revision; the rest is adapting/adding byte code operations to support the changes made to the PDL interpreter in the parent revision.

After this revision, PDL will have initial end-to-end support for variadic operands/results.

Differential Revision: https://reviews.llvm.org/D95723

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/IR/TypeRange.h
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Rewrite/ByteCode.h
    mlir/lib/Rewrite/PatternApplicator.cpp
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index e35208747ade..ff9b3dda0f48 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -963,7 +963,7 @@ def PDLInterp_SwitchOperandCountOp
 
   let builders = [
     OpBuilder<(ins "Value":$operation, "ArrayRef<int32_t>":$counts,
-      "Block *":$defaultDest, "BlockRange":$dests), [{
+                   "Block *":$defaultDest, "BlockRange":$dests), [{
     build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts),
           defaultDest, dests);
   }]>];
@@ -1033,7 +1033,7 @@ def PDLInterp_SwitchResultCountOp
 
   let builders = [
     OpBuilder<(ins "Value":$operation, "ArrayRef<int32_t>":$counts,
-      "Block *":$defaultDest, "BlockRange":$dests), [{
+                   "Block *":$defaultDest, "BlockRange":$dests), [{
     build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts),
           defaultDest, dests);
   }]>];

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 56da9b870948..c797f5329bd5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -238,63 +238,92 @@ struct OpRewritePattern : public RewritePattern {
 /// Storage type of byte-code interpreter values. These are passed to constraint
 /// functions as arguments.
 class PDLValue {
-  /// The internal implementation type when the value is an Attribute,
-  /// Operation*, or Type. See `impl` below for more details.
-  using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>;
-
 public:
-  PDLValue(const PDLValue &other) : impl(other.impl) {}
-  PDLValue(std::nullptr_t = nullptr) : impl() {}
-  PDLValue(Attribute value) : impl(value) {}
-  PDLValue(Operation *value) : impl(value) {}
-  PDLValue(Type value) : impl(value) {}
-  PDLValue(Value value) : impl(value) {}
+  /// The underlying kind of a PDL value.
+  enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
+
+  /// Construct a new PDL value.
+  PDLValue(const PDLValue &other) = default;
+  PDLValue(std::nullptr_t = nullptr) : value(nullptr), kind(Kind::Attribute) {}
+  PDLValue(Attribute value)
+      : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
+  PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
+  PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
+  PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
+  PDLValue(Value value)
+      : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
+  PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
 
   /// Returns true if the type of the held value is `T`.
-  template <typename T>
-  std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const {
-    return impl.is<Value>();
-  }
-  template <typename T>
-  std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const {
-    auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
-    return attrOpTypeImpl && attrOpTypeImpl.is<T>();
+  template <typename T> bool isa() const {
+    assert(value && "isa<> used on a null value");
+    return kind == getKindOf<T>();
   }
 
   /// Attempt to dynamically cast this value to type `T`, returns null if this
   /// value is not an instance of `T`.
-  template <typename T>
-  std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const {
-    return impl.dyn_cast<T>();
-  }
-  template <typename T>
-  std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const {
-    auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
-    return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>();
+  template <typename T,
+            typename ResultT = std::conditional_t<
+                std::is_convertible<T, bool>::value, T, Optional<T>>>
+  ResultT dyn_cast() const {
+    return isa<T>() ? castImpl<T>() : ResultT();
   }
 
   /// Cast this value to type `T`, asserts if this value is not an instance of
   /// `T`.
-  template <typename T>
-  std::enable_if_t<std::is_same<T, Value>::value, T> cast() const {
-    return impl.get<T>();
-  }
-  template <typename T>
-  std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const {
-    return impl.get<AttrOpTypeImplT>().get<T>();
+  template <typename T> T cast() const {
+    assert(isa<T>() && "expected value to be of type `T`");
+    return castImpl<T>();
   }
 
   /// Get an opaque pointer to the value.
-  void *getAsOpaquePointer() { return impl.getOpaqueValue(); }
+  const void *getAsOpaquePointer() const { return value; }
+
+  /// Return if this value is null or not.
+  explicit operator bool() const { return value; }
+
+  /// Return the kind of this value.
+  Kind getKind() const { return kind; }
 
   /// Print this value to the provided output stream.
-  void print(raw_ostream &os);
+  void print(raw_ostream &os) const;
 
 private:
-  /// The internal opaque representation of a PDLValue. We use a nested
-  /// PointerUnion structure here because `Value` only has 1 low bit
-  /// available, where as the remaining types all have 3.
-  llvm::PointerUnion<AttrOpTypeImplT, Value> impl;
+  /// Find the index of a given type in a range of other types.
+  template <typename...> struct index_of_t;
+  template <typename T, typename... R>
+  struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
+  template <typename T, typename F, typename... R>
+  struct index_of_t<T, F, R...>
+      : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
+
+  /// Return the kind used for the given T.
+  template <typename T> static Kind getKindOf() {
+    return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
+                                        TypeRange, Value, ValueRange>::value);
+  }
+
+  /// The internal implementation of `cast`, that returns the underlying value
+  /// as the given type `T`.
+  template <typename T>
+  std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
+  castImpl() const {
+    return T::getFromOpaquePointer(value);
+  }
+  template <typename T>
+  std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
+  castImpl() const {
+    return *reinterpret_cast<T *>(const_cast<void *>(value));
+  }
+  template <typename T>
+  std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
+    return reinterpret_cast<T>(const_cast<void *>(value));
+  }
+
+  /// The internal opaque representation of a PDLValue.
+  const void *value;
+  /// The kind of the opaque value.
+  Kind kind;
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
@@ -319,14 +348,66 @@ class PDLResultList {
   /// Push a new Type onto the result list.
   void push_back(Type value) { results.push_back(value); }
 
+  /// Push a new TypeRange onto the result list.
+  void push_back(TypeRange value) {
+    // The lifetime of a TypeRange can't be guaranteed, so we'll need to
+    // allocate a storage for it.
+    llvm::OwningArrayRef<Type> storage(value.size());
+    llvm::copy(value, storage.begin());
+    allocatedTypeRanges.emplace_back(std::move(storage));
+    typeRanges.push_back(allocatedTypeRanges.back());
+    results.push_back(&typeRanges.back());
+  }
+  void push_back(ValueTypeRange<OperandRange> value) {
+    typeRanges.push_back(value);
+    results.push_back(&typeRanges.back());
+  }
+  void push_back(ValueTypeRange<ResultRange> value) {
+    typeRanges.push_back(value);
+    results.push_back(&typeRanges.back());
+  }
+
   /// Push a new Value onto the result list.
   void push_back(Value value) { results.push_back(value); }
 
+  /// Push a new ValueRange onto the result list.
+  void push_back(ValueRange value) {
+    // The lifetime of a ValueRange can't be guaranteed, so we'll need to
+    // allocate a storage for it.
+    llvm::OwningArrayRef<Value> storage(value.size());
+    llvm::copy(value, storage.begin());
+    allocatedValueRanges.emplace_back(std::move(storage));
+    valueRanges.push_back(allocatedValueRanges.back());
+    results.push_back(&valueRanges.back());
+  }
+  void push_back(OperandRange value) {
+    valueRanges.push_back(value);
+    results.push_back(&valueRanges.back());
+  }
+  void push_back(ResultRange value) {
+    valueRanges.push_back(value);
+    results.push_back(&valueRanges.back());
+  }
+
 protected:
-  PDLResultList() = default;
+  /// Create a new result list with the expected number of results.
+  PDLResultList(unsigned maxNumResults) {
+    // For now just reserve enough space for all of the results. We could do
+    // separate counts per range type, but it isn't really worth it unless there
+    // are a "large" number of results.
+    typeRanges.reserve(maxNumResults);
+    valueRanges.reserve(maxNumResults);
+  }
 
   /// The PDL results held by this list.
   SmallVector<PDLValue> results;
+  /// Memory used to store ranges held by the list.
+  SmallVector<TypeRange> typeRanges;
+  SmallVector<ValueRange> valueRanges;
+  /// Memory allocated to own copies of ranges pushed by the native function,
+  /// whose original lifetime can't be guaranteed.
+  SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
+  SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index fe11fde58793..4fb40e127f9f 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -82,6 +82,12 @@ inline ::llvm::hash_code hash_value(TypeRange arg) {
   return ::llvm::hash_combine_range(arg.begin(), arg.end());
 }
 
+/// Emit a type range to the given output stream.
+inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
+  llvm::interleaveComma(types, os);
+  return os;
+}
+
 //===----------------------------------------------------------------------===//
 // ValueTypeRange
 

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 034698d85cb1..354d5f31bf74 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -73,22 +73,31 @@ void RewritePattern::anchor() {}
 // PDLValue
 //===----------------------------------------------------------------------===//
 
-void PDLValue::print(raw_ostream &os) {
-  if (!impl) {
-    os << "<Null-PDLValue>";
+void PDLValue::print(raw_ostream &os) const {
+  if (!value) {
+    os << "<NULL-PDLValue>";
     return;
   }
-  if (Value val = impl.dyn_cast<Value>()) {
-    os << val;
-    return;
+  switch (kind) {
+  case Kind::Attribute:
+    os << cast<Attribute>();
+    break;
+  case Kind::Operation:
+    os << *cast<Operation *>();
+    break;
+  case Kind::Type:
+    os << cast<Type>();
+    break;
+  case Kind::TypeRange:
+    llvm::interleaveComma(cast<TypeRange>(), os);
+    break;
+  case Kind::Value:
+    os << cast<Value>();
+    break;
+  case Kind::ValueRange:
+    llvm::interleaveComma(cast<ValueRange>(), os);
+    break;
   }
-  AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>();
-  if (Attribute attr = aotImpl.dyn_cast<Attribute>())
-    os << attr;
-  else if (Operation *op = aotImpl.dyn_cast<Operation *>())
-    os << *op;
-  else
-    os << aotImpl.get<Type>();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ef96e25c7be3..ea17f99deb9c 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -20,6 +20,9 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <numeric>
 
 #define DEBUG_TYPE "pdl-bytecode"
 
@@ -60,6 +63,14 @@ void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
   currentPatternBenefits[patternIndex] = benefit;
 }
 
+/// Cleanup any allocated state after a full match/rewrite has been completed.
+/// This method should be called regardless of whether the match+rewrite was a
+/// success or not.
+void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
+  allocatedTypeRangeMemory.clear();
+  allocatedValueRangeMemory.clear();
+}
+
 //===----------------------------------------------------------------------===//
 // Bytecode OpCodes
 //===----------------------------------------------------------------------===//
@@ -72,6 +83,8 @@ enum OpCode : ByteCodeField {
   ApplyRewrite,
   /// Check if two generic values are equal.
   AreEqual,
+  /// Check if two ranges are equal.
+  AreRangesEqual,
   /// Unconditional branch.
   Branch,
   /// Compare the operand count of an operation with a constant.
@@ -80,8 +93,12 @@ enum OpCode : ByteCodeField {
   CheckOperationName,
   /// Compare the result count of an operation with a constant.
   CheckResultCount,
+  /// Compare a range of types to a constant range of types.
+  CheckTypes,
   /// Create an operation.
   CreateOperation,
+  /// Create a range of types.
+  CreateTypes,
   /// Erase an operation.
   EraseOp,
   /// Terminate a matcher or rewrite sequence.
@@ -98,14 +115,20 @@ enum OpCode : ByteCodeField {
   GetOperand2,
   GetOperand3,
   GetOperandN,
+  /// Get a specific operand group of an operation.
+  GetOperands,
   /// Get a specific result of an operation.
   GetResult0,
   GetResult1,
   GetResult2,
   GetResult3,
   GetResultN,
+  /// Get a specific result group of an operation.
+  GetResults,
   /// Get the type of a value.
   GetValueType,
+  /// Get the types of a value range.
+  GetValueRangeTypes,
   /// Check if a generic value is not null.
   IsNotNull,
   /// Record a successful pattern match.
@@ -122,9 +145,9 @@ enum OpCode : ByteCodeField {
   SwitchResultCount,
   /// Compare a type with a set of constants.
   SwitchType,
+  /// Compare a range of types with a set of constants.
+  SwitchTypes,
 };
-
-enum class PDLValueKind { Attribute, Operation, Type, Value };
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -145,11 +168,15 @@ class Generator {
             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
             SmallVectorImpl<PDLByteCodePattern> &patterns,
             ByteCodeField &maxValueMemoryIndex,
+            ByteCodeField &maxTypeRangeMemoryIndex,
+            ByteCodeField &maxValueRangeMemoryIndex,
             llvm::StringMap<PDLConstraintFunction> &constraintFns,
             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
         rewriterByteCode(rewriterByteCode), patterns(patterns),
-        maxValueMemoryIndex(maxValueMemoryIndex) {
+        maxValueMemoryIndex(maxValueMemoryIndex),
+        maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
+        maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
     for (auto it : llvm::enumerate(constraintFns))
       constraintToMemIndex.try_emplace(it.value().first(), it.index());
     for (auto it : llvm::enumerate(rewriteFns))
@@ -166,6 +193,13 @@ class Generator {
     return valueToMemIndex[value];
   }
 
+  /// Return the range memory index used to store the given range value.
+  ByteCodeField &getRangeStorageIndex(Value value) {
+    assert(valueToRangeIndex.count(value) &&
+           "expected range index to be assigned");
+    return valueToRangeIndex[value];
+  }
+
   /// Return an index to use when referring to the given data that is uniqued in
   /// the MLIR context.
   template <typename T>
@@ -197,16 +231,20 @@ class Generator {
   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
@@ -214,6 +252,7 @@ class Generator {
   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
@@ -221,6 +260,9 @@ class Generator {
   /// Mapping from value to its corresponding memory index.
   DenseMap<Value, ByteCodeField> valueToMemIndex;
 
+  /// Mapping from a range value to its corresponding range storage index.
+  DenseMap<Value, ByteCodeField> valueToRangeIndex;
+
   /// Mapping from the name of an externally registered rewrite to its index in
   /// the bytecode registry.
   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
@@ -246,6 +288,8 @@ class Generator {
   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
   SmallVectorImpl<PDLByteCodePattern> &patterns;
   ByteCodeField &maxValueMemoryIndex;
+  ByteCodeField &maxTypeRangeMemoryIndex;
+  ByteCodeField &maxValueRangeMemoryIndex;
 };
 
 /// This class provides utilities for writing a bytecode stream.
@@ -281,19 +325,33 @@ struct ByteCodeWriter {
   /// Append a range of values that will be read as generic PDLValues.
   void appendPDLValueList(OperandRange values) {
     bytecode.push_back(values.size());
-    for (Value value : values) {
-      // Append the type of the value in addition to the value itself.
-      PDLValueKind kind =
-          TypeSwitch<Type, PDLValueKind>(value.getType())
-              .Case<pdl::AttributeType>(
-                  [](Type) { return PDLValueKind::Attribute; })
-              .Case<pdl::OperationType>(
-                  [](Type) { return PDLValueKind::Operation; })
-              .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
-              .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
-      bytecode.push_back(static_cast<ByteCodeField>(kind));
-      append(value);
-    }
+    for (Value value : values)
+      appendPDLValue(value);
+  }
+
+  /// Append a value as a PDLValue.
+  void appendPDLValue(Value value) {
+    appendPDLValueKind(value);
+    append(value);
+  }
+
+  /// Append the PDLValue::Kind of the given value.
+  void appendPDLValueKind(Value value) {
+    // Append the type of the value in addition to the value itself.
+    PDLValue::Kind kind =
+        TypeSwitch<Type, PDLValue::Kind>(value.getType())
+            .Case<pdl::AttributeType>(
+                [](Type) { return PDLValue::Kind::Attribute; })
+            .Case<pdl::OperationType>(
+                [](Type) { return PDLValue::Kind::Operation; })
+            .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
+              if (rangeTy.getElementType().isa<pdl::TypeType>())
+                return PDLValue::Kind::TypeRange;
+              return PDLValue::Kind::ValueRange;
+            })
+            .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
+            .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
+    bytecode.push_back(static_cast<ByteCodeField>(kind));
   }
 
   /// Check if the given class `T` has an iterator type.
@@ -334,6 +392,36 @@ struct ByteCodeWriter {
   /// The main generator producing PDL.
   Generator &generator;
 };
+
+/// This class represents a live range of PDL Interpreter values, containing
+/// information about when values are live within a match/rewrite.
+struct ByteCodeLiveRange {
+  using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
+  using Allocator = Set::Allocator;
+
+  ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
+
+  /// Union this live range with the one provided.
+  void unionWith(const ByteCodeLiveRange &rhs) {
+    for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
+      liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+  }
+
+  /// Returns true if this range overlaps with the one provided.
+  bool overlaps(const ByteCodeLiveRange &rhs) const {
+    return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
+  }
+
+  /// A map representing the ranges of the match/rewrite that a value is live in
+  /// the interpreter.
+  llvm::IntervalMap<ByteCodeField, char, 16> liveness;
+
+  /// The type range storage index for this range.
+  Optional<unsigned> typeRangeIndex;
+
+  /// The value range storage index for this range.
+  Optional<unsigned> valueRangeIndex;
+};
 } // end anonymous namespace
 
 void Generator::generate(ModuleOp module) {
@@ -381,15 +469,30 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
   // Rewriters use simplistic allocation scheme that simply assigns an index to
   // each result.
   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
-    ByteCodeField index = 0;
+    ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
+    auto processRewriterValue = [&](Value val) {
+      valueToMemIndex.try_emplace(val, index++);
+      if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
+        Type elementTy = rangeType.getElementType();
+        if (elementTy.isa<pdl::TypeType>())
+          valueToRangeIndex.try_emplace(val, typeRangeIndex++);
+        else if (elementTy.isa<pdl::ValueType>())
+          valueToRangeIndex.try_emplace(val, valueRangeIndex++);
+      }
+    };
+
     for (BlockArgument arg : rewriterFunc.getArguments())
-      valueToMemIndex.try_emplace(arg, index++);
+      processRewriterValue(arg);
     rewriterFunc.getBody().walk([&](Operation *op) {
       for (Value result : op->getResults())
-        valueToMemIndex.try_emplace(result, index++);
+        processRewriterValue(result);
     });
     if (index > maxValueMemoryIndex)
       maxValueMemoryIndex = index;
+    if (typeRangeIndex > maxTypeRangeMemoryIndex)
+      maxTypeRangeMemoryIndex = typeRangeIndex;
+    if (valueRangeIndex > maxValueRangeMemoryIndex)
+      maxValueRangeMemoryIndex = valueRangeIndex;
   }
 
   // The matcher function uses a more sophisticated numbering that tries to
@@ -404,9 +507,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
   });
 
   // Liveness info for each of the defs within the matcher.
-  using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
-  LivenessSet::Allocator allocator;
-  DenseMap<Value, LivenessSet> valueDefRanges;
+  ByteCodeLiveRange::Allocator allocator;
+  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
 
   // Assign the root operation being matched to slot 0.
   BlockArgument rootOpArg = matcherFunc.getArgument(0);
@@ -425,10 +527,19 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
 
       // Set indices for the range of this block that the value is used.
       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
-      defRangeIt->second.insert(
+      defRangeIt->second.liveness.insert(
           opToIndex[firstUseOrDef],
           opToIndex[info->getEndOperation(value, firstUseOrDef)],
           /*dummyValue*/ 0);
+
+      // Check to see if this value is a range type.
+      if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
+        Type eleType = rangeTy.getElementType();
+        if (eleType.isa<pdl::TypeType>())
+          defRangeIt->second.typeRangeIndex = 0;
+        else if (eleType.isa<pdl::ValueType>())
+          defRangeIt->second.valueRangeIndex = 0;
+      }
     };
 
     // Process the live-ins of this block.
@@ -442,37 +553,59 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
   }
 
   // Greedily allocate memory slots using the computed def live ranges.
-  std::vector<LivenessSet> allocatedIndices;
+  std::vector<ByteCodeLiveRange> allocatedIndices;
+  ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
   for (auto &defIt : valueDefRanges) {
     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
-    LivenessSet &defSet = defIt.second;
+    ByteCodeLiveRange &defRange = defIt.second;
 
     // Try to allocate to an existing index.
     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
-      LivenessSet &existingIndex = existingIndexIt.value();
-      llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
-          defIt.second, existingIndex);
-      if (overlaps.valid())
-        continue;
-      // Union the range of the def within the existing index.
-      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
-        existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
-      memIndex = existingIndexIt.index() + 1;
+      ByteCodeLiveRange &existingRange = existingIndexIt.value();
+      if (!defRange.overlaps(existingRange)) {
+        existingRange.unionWith(defRange);
+        memIndex = existingIndexIt.index() + 1;
+
+        if (defRange.typeRangeIndex) {
+          if (!existingRange.typeRangeIndex)
+            existingRange.typeRangeIndex = numTypeRanges++;
+          valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
+        } else if (defRange.valueRangeIndex) {
+          if (!existingRange.valueRangeIndex)
+            existingRange.valueRangeIndex = numValueRanges++;
+          valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
+        }
+        break;
+      }
     }
 
     // If no existing index could be used, add a new one.
     if (memIndex == 0) {
       allocatedIndices.emplace_back(allocator);
-      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
-        allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      ByteCodeLiveRange &newRange = allocatedIndices.back();
+      newRange.unionWith(defRange);
+
+      // Allocate an index for type/value ranges.
+      if (defRange.typeRangeIndex) {
+        newRange.typeRangeIndex = numTypeRanges;
+        valueToRangeIndex[defIt.first] = numTypeRanges++;
+      } else if (defRange.valueRangeIndex) {
+        newRange.valueRangeIndex = numValueRanges;
+        valueToRangeIndex[defIt.first] = numValueRanges++;
+      }
+
       memIndex = allocatedIndices.size();
+      ++numIndices;
     }
   }
 
   // Update the max number of indices.
-  ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
-  if (numMatcherIndices > maxValueMemoryIndex)
-    maxValueMemoryIndex = numMatcherIndices;
+  if (numIndices > maxValueMemoryIndex)
+    maxValueMemoryIndex = numIndices;
+  if (numTypeRanges > maxTypeRangeMemoryIndex)
+    maxTypeRangeMemoryIndex = numTypeRanges;
+  if (numValueRanges > maxValueRangeMemoryIndex)
+    maxValueRangeMemoryIndex = numValueRanges;
 }
 
 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
@@ -481,17 +614,19 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
-            pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
-            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+            pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
+            pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
+            pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
-            pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
+            pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
+            pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
-            pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
-            pdl_interp::SwitchResultCountOp>(
+            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
+            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
           [&](auto interpOp) { this->generate(interpOp, writer); })
       .Default([](Operation *) {
         llvm_unreachable("unknown `pdl_interp` operation");
@@ -515,16 +650,31 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
                 op.constParamsAttr());
   writer.appendPDLValueList(op.args());
 
+  ResultRange results = op.results();
+  writer.append(ByteCodeField(results.size()));
+  for (Value result : results) {
+    // In debug mode we also record the expected kind of the result, so that we
+    // can provide extra verification of the native rewrite function.
 #ifndef NDEBUG
-  // In debug mode we also append the number of results so that we can assert
-  // that the native creation function gave us the correct number of results.
-  writer.append(ByteCodeField(op.results().size()));
+    writer.appendPDLValueKind(result);
 #endif
-  for (Value result : op.results())
+
+    // Range results also need to append the range storage index.
+    if (result.getType().isa<pdl::RangeType>())
+      writer.append(getRangeStorageIndex(result));
     writer.append(result);
+  }
 }
 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
-  writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
+  Value lhs = op.lhs();
+  if (lhs.getType().isa<pdl::RangeType>()) {
+    writer.append(OpCode::AreRangesEqual);
+    writer.appendPDLValueKind(lhs);
+    writer.append(lhs, op.rhs(), op.getSuccessors());
+    return;
+  }
+  // Non-range values are compared directly with the generic AreEqual opcode.
+  writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
 }
 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
@@ -537,6 +687,7 @@ void Generator::generate(pdl_interp::CheckAttributeOp op,
 void Generator::generate(pdl_interp::CheckOperandCountOp op,
                          ByteCodeWriter &writer) {
   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
+                static_cast<ByteCodeField>(op.compareAtLeast()),
                 op.getSuccessors());
 }
 void Generator::generate(pdl_interp::CheckOperationNameOp op,
@@ -547,11 +698,15 @@ void Generator::generate(pdl_interp::CheckOperationNameOp op,
 void Generator::generate(pdl_interp::CheckResultCountOp op,
                          ByteCodeWriter &writer) {
   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
+                static_cast<ByteCodeField>(op.compareAtLeast()),
                 op.getSuccessors());
 }
 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
 }
+void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
+}
 void Generator::generate(pdl_interp::CreateAttributeOp op,
                          ByteCodeWriter &writer) {
   // Simply repoint the memory index of the result to the constant.
@@ -560,7 +715,8 @@ void Generator::generate(pdl_interp::CreateAttributeOp op,
 void Generator::generate(pdl_interp::CreateOperationOp op,
                          ByteCodeWriter &writer) {
   writer.append(OpCode::CreateOperation, op.operation(),
-                OperationName(op.name(), ctx), op.operands());
+                OperationName(op.name(), ctx));
+  writer.appendPDLValueList(op.operands());
 
   // Add the attributes.
   OperandRange attributes = op.attributes();
@@ -570,12 +726,16 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
         std::get<1>(it));
   }
-  writer.append(op.types());
+  writer.appendPDLValueList(op.types());
 }
 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
   // Simply repoint the memory index of the result to the constant.
   getMemIndex(op.result()) = getMemIndex(op.value());
 }
+void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::CreateTypes, op.result(),
+                getRangeStorageIndex(op.result()), op.value());
+}
 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
   writer.append(OpCode::EraseOp, op.operation());
 }
@@ -593,7 +753,8 @@ void Generator::generate(pdl_interp::GetAttributeTypeOp op,
 }
 void Generator::generate(pdl_interp::GetDefiningOpOp op,
                          ByteCodeWriter &writer) {
-  writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
+  writer.append(OpCode::GetDefiningOp, op.operation());
+  writer.appendPDLValue(op.value());
 }
 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
   uint32_t index = op.index();
@@ -603,6 +764,18 @@ void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
     writer.append(OpCode::GetOperandN, index);
   writer.append(op.operation(), op.value());
 }
+void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
+  Value result = op.value();
+  Optional<uint32_t> index = op.index();
+  writer.append(OpCode::GetOperands,
+                index.getValueOr(std::numeric_limits<uint32_t>::max()),
+                op.operation());
+  if (result.getType().isa<pdl::RangeType>())
+    writer.append(getRangeStorageIndex(result));
+  else
+    writer.append(std::numeric_limits<ByteCodeField>::max());
+  writer.append(result);
+}
 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
   uint32_t index = op.index();
   if (index < 4)
@@ -611,10 +784,29 @@ void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
     writer.append(OpCode::GetResultN, index);
   writer.append(op.operation(), op.value());
 }
+void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
+  Value result = op.value();
+  Optional<uint32_t> index = op.index();
+  writer.append(OpCode::GetResults,
+                index.getValueOr(std::numeric_limits<uint32_t>::max()),
+                op.operation());
+  if (result.getType().isa<pdl::RangeType>())
+    writer.append(getRangeStorageIndex(result));
+  else
+    writer.append(std::numeric_limits<ByteCodeField>::max());
+  writer.append(result);
+}
 void Generator::generate(pdl_interp::GetValueTypeOp op,
                          ByteCodeWriter &writer) {
-  writer.append(OpCode::GetValueType, op.result(), op.value());
+  if (op.getType().isa<pdl::RangeType>()) {
+    Value result = op.result();
+    writer.append(OpCode::GetValueRangeTypes, result,
+                  getRangeStorageIndex(result), op.value());
+  } else {
+    writer.append(OpCode::GetValueType, op.result(), op.value());
+  }
 }
+
 void Generator::generate(pdl_interp::InferredTypesOp op,
                          ByteCodeWriter &writer) {
   // InferType maps to a null type as a marker for inferring result types.
@@ -628,11 +820,12 @@ void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
   patterns.emplace_back(PDLByteCodePattern::create(
       op, rewriterToAddr[op.rewriter().getLeafReference()]));
   writer.append(OpCode::RecordMatch, patternIndex,
-                SuccessorRange(op.getOperation()), op.matchedOps(),
-                op.inputs());
+                SuccessorRange(op.getOperation()), op.matchedOps());
+  writer.appendPDLValueList(op.inputs());
 }
 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
-  writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
+  writer.append(OpCode::ReplaceOp, op.operation());
+  writer.appendPDLValueList(op.replValues());
 }
 void Generator::generate(pdl_interp::SwitchAttributeOp op,
                          ByteCodeWriter &writer) {
@@ -661,6 +854,10 @@ void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
                 op.getSuccessors());
 }
+void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
 
 //===----------------------------------------------------------------------===//
 // PDLByteCode
@@ -671,7 +868,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
   Generator generator(module.getContext(), uniquedData, matcherByteCode,
                       rewriterByteCode, patterns, maxValueMemoryIndex,
-                      constraintFns, rewriteFns);
+                      maxTypeRangeCount, maxValueRangeCount, constraintFns,
+                      rewriteFns);
   generator.generate(module);
 
   // Initialize the external functions.
@@ -685,6 +883,8 @@ PDLByteCode::PDLByteCode(ModuleOp module,
 /// bytecode.
 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
   state.memory.resize(maxValueMemoryIndex, nullptr);
+  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
+  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
   state.currentPatternBenefits.reserve(patterns.size());
   for (const PDLByteCodePattern &pattern : patterns)
     state.currentPatternBenefits.push_back(pattern.getBenefit());
@@ -697,17 +897,24 @@ namespace {
 /// This class provides support for executing a bytecode stream.
 class ByteCodeExecutor {
 public:
-  ByteCodeExecutor(const ByteCodeField *curCodeIt,
-                   MutableArrayRef<const void *> memory,
-                   ArrayRef<const void *> uniquedMemory,
-                   ArrayRef<ByteCodeField> code,
-                   ArrayRef<PatternBenefit> currentPatternBenefits,
-                   ArrayRef<PDLByteCodePattern> patterns,
-                   ArrayRef<PDLConstraintFunction> constraintFunctions,
-                   ArrayRef<PDLRewriteFunction> rewriteFunctions)
-      : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
-        code(code), currentPatternBenefits(currentPatternBenefits),
-        patterns(patterns), constraintFunctions(constraintFunctions),
+  ByteCodeExecutor(
+      const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
+      MutableArrayRef<TypeRange> typeRangeMemory,
+      std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
+      MutableArrayRef<ValueRange> valueRangeMemory,
+      std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
+      ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
+      ArrayRef<PatternBenefit> currentPatternBenefits,
+      ArrayRef<PDLByteCodePattern> patterns,
+      ArrayRef<PDLConstraintFunction> constraintFunctions,
+      ArrayRef<PDLRewriteFunction> rewriteFunctions)
+      : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
+        allocatedTypeRangeMemory(allocatedTypeRangeMemory),
+        valueRangeMemory(valueRangeMemory),
+        allocatedValueRangeMemory(allocatedValueRangeMemory),
+        uniquedMemory(uniquedMemory), code(code),
+        currentPatternBenefits(currentPatternBenefits), patterns(patterns),
+        constraintFunctions(constraintFunctions),
         rewriteFunctions(rewriteFunctions) {}
 
   /// Start executing the code at the current bytecode index. `matches` is an
@@ -722,19 +929,25 @@ class ByteCodeExecutor {
   void executeApplyConstraint(PatternRewriter &rewriter);
   void executeApplyRewrite(PatternRewriter &rewriter);
   void executeAreEqual();
+  void executeAreRangesEqual();
   void executeBranch();
   void executeCheckOperandCount();
   void executeCheckOperationName();
   void executeCheckResultCount();
+  void executeCheckTypes();
   void executeCreateOperation(PatternRewriter &rewriter,
                               Location mainRewriteLoc);
+  void executeCreateTypes();
   void executeEraseOp(PatternRewriter &rewriter);
   void executeGetAttribute();
   void executeGetAttributeType();
   void executeGetDefiningOp();
   void executeGetOperand(unsigned index);
+  void executeGetOperands();
   void executeGetResult(unsigned index);
+  void executeGetResults();
   void executeGetValueType();
+  void executeGetValueRangeTypes();
   void executeIsNotNull();
   void executeRecordMatch(PatternRewriter &rewriter,
                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
@@ -744,6 +957,7 @@ class ByteCodeExecutor {
   void executeSwitchOperationName();
   void executeSwitchResultCount();
   void executeSwitchType();
+  void executeSwitchTypes();
 
   /// Read a value from the bytecode buffer, optionally skipping a certain
   /// number of prefix values. These methods always update the buffer to point
@@ -763,6 +977,19 @@ class ByteCodeExecutor {
       list.push_back(read<ValueT>());
   }
 
+  /// Read a list of values from the bytecode buffer. Each element is a Value or
+  /// ValueRange prefixed by its PDLValue::Kind; ranges are flattened into list.
+  void readValueList(SmallVectorImpl<Value> &list) {
+    for (unsigned i = 0, e = read(); i != e; ++i) {
+      if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+        list.push_back(read<Value>());
+      } else {
+        ValueRange *values = read<ValueRange *>();
+        list.append(values->begin(), values->end());
+      }
+    }
+  }
+
   /// Jump to a specific successor based on a predicate value.
   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
   /// Jump to a specific successor based on a destination index.
@@ -771,8 +998,8 @@ class ByteCodeExecutor {
   }
 
   /// Handle a switch operation with the provided value and cases.
-  template <typename T, typename RangeT>
-  void handleSwitch(const T &value, RangeT &&cases) {
+  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
+  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
     LLVM_DEBUG({
       llvm::dbgs() << "  * Value: " << value << "\n"
                    << "  * Cases: ";
@@ -783,7 +1010,7 @@ class ByteCodeExecutor {
     // Check to see if the attribute value is within the case list. Jump to
     // the correct successor index based on the result.
     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
-      if (*it == value)
+      if (cmp(*it, value))
         return selectJump(size_t((it - cases.begin()) + 1));
     selectJump(size_t(0));
   }
@@ -795,7 +1022,9 @@ class ByteCodeExecutor {
     size_t index = *curCodeIt++;
 
     // If this type is an SSA value, it can only be stored in non-const memory.
-    if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
+    if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
+                        Value>::value ||
+        index < memory.size())
       return memory[index];
 
     // Otherwise, if this index is not inbounds it is uniqued.
@@ -813,17 +1042,21 @@ class ByteCodeExecutor {
   }
   template <typename T>
   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
-    switch (static_cast<PDLValueKind>(read())) {
-    case PDLValueKind::Attribute:
+    switch (read<PDLValue::Kind>()) {
+    case PDLValue::Kind::Attribute:
       return read<Attribute>();
-    case PDLValueKind::Operation:
+    case PDLValue::Kind::Operation:
       return read<Operation *>();
-    case PDLValueKind::Type:
+    case PDLValue::Kind::Type:
       return read<Type>();
-    case PDLValueKind::Value:
+    case PDLValue::Kind::Value:
       return read<Value>();
+    case PDLValue::Kind::TypeRange:
+      return read<TypeRange *>();
+    case PDLValue::Kind::ValueRange:
+      return read<ValueRange *>();
     }
-    llvm_unreachable("unhandled PDLValueKind");
+    llvm_unreachable("unhandled PDLValue::Kind");
   }
   template <typename T>
   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
@@ -838,12 +1071,20 @@ class ByteCodeExecutor {
   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
     return *curCodeIt++;
   }
+  template <typename T>
+  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
+    return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
+  }
 
   /// The underlying bytecode buffer.
   const ByteCodeField *curCodeIt;
 
   /// The current execution memory.
   MutableArrayRef<const void *> memory;
+  MutableArrayRef<TypeRange> typeRangeMemory;
+  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
+  MutableArrayRef<ValueRange> valueRangeMemory;
+  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
 
   /// References to ByteCode data necessary for execution.
   ArrayRef<const void *> uniquedMemory;
@@ -859,8 +1100,21 @@ class ByteCodeExecutor {
 /// overexposing access to information specific solely to the ByteCode.
 class ByteCodeRewriteResultList : public PDLResultList {
 public:
+  explicit ByteCodeRewriteResultList(unsigned maxNumResults)
+      : PDLResultList(maxNumResults) {}
+
   /// Return the list of PDL results.
   MutableArrayRef<PDLValue> getResults() { return results; }
+
+  /// Return the type ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+    return allocatedTypeRanges;
+  }
+
+  /// Return the value ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+    return allocatedValueRanges;
+  }
 };
 } // end anonymous namespace
 
@@ -893,21 +1147,46 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
     llvm::interleaveComma(args, llvm::dbgs());
     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
   });
-  ByteCodeRewriteResultList results;
+
+  // Execute the rewrite function.
+  ByteCodeField numResults = read();
+  ByteCodeRewriteResultList results(numResults);
   rewriteFn(args, constParams, rewriter, results);
 
-  // Store the results in the bytecode memory.
-#ifndef NDEBUG
-  ByteCodeField expectedNumberOfResults = read();
-  assert(results.getResults().size() == expectedNumberOfResults &&
+  assert(results.getResults().size() == numResults &&
          "native PDL rewrite function returned unexpected number of results");
-#endif
 
   // Store the results in the bytecode memory.
   for (PDLValue &result : results.getResults()) {
     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
-    memory[read()] = result.getAsOpaquePointer();
+
+// In debug mode we also verify the expected kind of the result.
+#ifndef NDEBUG
+    assert(result.getKind() == read<PDLValue::Kind>() &&
+           "native PDL rewrite function returned an unexpected type of result");
+#endif
+
+    // If the result is a range, we need to copy it over to the bytecodes
+    // range memory.
+    if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
+      unsigned rangeIndex = read();
+      typeRangeMemory[rangeIndex] = *typeRange;
+      memory[read()] = &typeRangeMemory[rangeIndex];
+    } else if (Optional<ValueRange> valueRange =
+                   result.dyn_cast<ValueRange>()) {
+      unsigned rangeIndex = read();
+      valueRangeMemory[rangeIndex] = *valueRange;
+      memory[read()] = &valueRangeMemory[rangeIndex];
+    } else {
+      memory[read()] = result.getAsOpaquePointer();
+    }
   }
+
+  // Copy over any underlying storage allocated for result ranges.
+  for (auto &it : results.getAllocatedTypeRanges())
+    allocatedTypeRangeMemory.push_back(std::move(it));
+  for (auto &it : results.getAllocatedValueRanges())
+    allocatedValueRangeMemory.push_back(std::move(it));
 }
 
 void ByteCodeExecutor::executeAreEqual() {
@@ -919,6 +1198,32 @@ void ByteCodeExecutor::executeAreEqual() {
   selectJump(lhs == rhs);
 }
 
+void ByteCodeExecutor::executeAreRangesEqual() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
+  PDLValue::Kind valueKind = read<PDLValue::Kind>();
+  const void *lhs = read<const void *>();
+  const void *rhs = read<const void *>();
+  // Both operands are opaque pointers to ranges of kind `valueKind`.
+  switch (valueKind) {
+  case PDLValue::Kind::TypeRange: {
+    const auto *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
+    const auto *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
+    LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
+    selectJump(*lhsRange == *rhsRange);
+    break;
+  }
+  case PDLValue::Kind::ValueRange: {
+    const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
+    const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
+    LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
+    selectJump(*lhsRange == *rhsRange);
+    break;
+  }
+  default:
+    llvm_unreachable("unexpected `AreRangesEqual` value kind");
+  }
+}
+
 void ByteCodeExecutor::executeBranch() {
   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
   curCodeIt = &code[read<ByteCodeAddr>()];
@@ -928,10 +1233,16 @@ void ByteCodeExecutor::executeCheckOperandCount() {
   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
   Operation *op = read<Operation *>();
   uint32_t expectedCount = read<uint32_t>();
+  bool compareAtLeast = read();
 
   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
-                          << "  * Expected: " << expectedCount << "\n");
-  selectJump(op->getNumOperands() == expectedCount);
+                          << "  * Expected: " << expectedCount << "\n"
+                          << "  * Comparator: "
+                          << (compareAtLeast ? ">=" : "==") << "\n");
+  if (compareAtLeast)
+    selectJump(op->getNumOperands() >= expectedCount);
+  else
+    selectJump(op->getNumOperands() == expectedCount);
 }
 
 void ByteCodeExecutor::executeCheckOperationName() {
@@ -948,10 +1259,44 @@ void ByteCodeExecutor::executeCheckResultCount() {
   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
   Operation *op = read<Operation *>();
   uint32_t expectedCount = read<uint32_t>();
+  bool compareAtLeast = read();
 
   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
-                          << "  * Expected: " << expectedCount << "\n");
-  selectJump(op->getNumResults() == expectedCount);
+                          << "  * Expected: " << expectedCount << "\n"
+                          << "  * Comparator: "
+                          << (compareAtLeast ? ">=" : "==") << "\n");
+  if (compareAtLeast)
+    selectJump(op->getNumResults() >= expectedCount);
+  else
+    selectJump(op->getNumResults() == expectedCount);
+}
+
+void ByteCodeExecutor::executeCheckTypes() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CheckTypes:\n");
+  TypeRange *lhs = read<TypeRange *>();
+  Attribute rhs = read<Attribute>();
+  LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
+  // A null type range (e.g. types of a null value range) never matches.
+  selectJump(lhs && *lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
+}
+
+void ByteCodeExecutor::executeCreateTypes() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
+  unsigned memIndex = read();
+  unsigned rangeIndex = read();
+  ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
+
+  LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
+
+  // Allocate a buffer for this type range.
+  llvm::OwningArrayRef<Type> storage(typesAttr.size());
+  llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
+  allocatedTypeRangeMemory.emplace_back(std::move(storage));
+
+  // Assign this to the range slot and use the range as the value for the
+  // memory index.
+  typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
+  memory[memIndex] = &typeRangeMemory[rangeIndex];
+}
 
 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@@ -960,22 +1305,26 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
 
   unsigned memIndex = read();
   OperationState state(mainRewriteLoc, read<OperationName>());
-  readList<Value>(state.operands);
+  readValueList(state.operands);
   for (unsigned i = 0, e = read(); i != e; ++i) {
     Identifier name = read<Identifier>();
     if (Attribute attr = read<Attribute>())
       state.addAttribute(name, attr);
   }
 
-  bool hasInferredTypes = false;
   for (unsigned i = 0, e = read(); i != e; ++i) {
-    Type resultType = read<Type>();
-    hasInferredTypes |= !resultType;
-    state.types.push_back(resultType);
-  }
+    if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+      state.types.push_back(read<Type>());
+      continue;
+    }
+
+    // If we find a null range, this signals that the types are infered.
+    if (TypeRange *resultTypes = read<TypeRange *>()) {
+      state.types.append(resultTypes->begin(), resultTypes->end());
+      continue;
+    }
 
-  // Handle the case where the operation has inferred types.
-  if (hasInferredTypes) {
+    // Handle the case where the operation has inferred types.
     InferTypeOpInterface::Concept *concept =
         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
 
@@ -986,7 +1335,9 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
             state.attributes.getDictionary(state.getContext()), state.regions,
             state.types)))
       return;
+    break;
   }
+
   Operation *resultOp = rewriter.createOperation(state);
   memory[memIndex] = resultOp;
 
@@ -1036,11 +1387,21 @@ void ByteCodeExecutor::executeGetAttributeType() {
 void ByteCodeExecutor::executeGetDefiningOp() {
   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
   unsigned memIndex = read();
-  Value value = read<Value>();
-  Operation *op = value ? value.getDefiningOp() : nullptr;
+  Operation *op = nullptr;
+  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
+    Value value = read<Value>();
+    if (value)
+      op = value.getDefiningOp();
+    LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
+  } else {
+    ValueRange *values = read<ValueRange *>();
+    if (values && !values->empty()) {
+      op = values->front().getDefiningOp();
+    }
+    LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
+  }
 
-  LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
-                          << "  * Result: " << *op << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
   memory[memIndex] = op;
 }
 
@@ -1056,6 +1417,75 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) {
   memory[memIndex] = operand.getAsOpaquePointer();
 }
 
+/// Shared implementation of `GetResults` and `GetOperands`: extract the value
+/// group at `index` from `values`, returning either a pointer to a range slot
+/// in `valueRangeMemory` or a single Value, or null if it can't be computed.
+template <template <typename> class AttrSizedSegmentsT, typename RangeT>
+static void *
+executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
+                          ByteCodeField rangeIndex, StringRef attrSizedSegments,
+                          MutableArrayRef<ValueRange> valueRangeMemory) {
+  // Check for the sentinel index that signals that all values should be
+  // returned.
+  if (index == std::numeric_limits<uint32_t>::max()) {
+    LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
+    // `values` is already the full value range.
+
+    // Otherwise, check to see if this operation uses AttrSizedSegments.
+  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "  * Extracting values from `" << attrSizedSegments << "`\n");
+
+    auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
+    if (!segmentAttr || segmentAttr.getNumElements() <= index)
+      return nullptr;
+
+    auto segments = segmentAttr.getValues<int32_t>();
+    unsigned startIndex =
+        std::accumulate(segments.begin(), segments.begin() + index, 0);
+    values = values.slice(startIndex, *std::next(segments.begin(), index));
+
+    LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
+                            << *std::next(segments.begin(), index) << "]\n");
+
+    // Otherwise, assume the requested group is the trailing variadic group.
+    // FIXME: We currently don't support operations with
+    // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
+    // have a way to detect its presence.
+  } else if (values.size() >= index) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "  * Treating values as trailing variadic range\n");
+    values = values.drop_front(index);
+
+    // If we couldn't detect a way to compute the values, bail out.
+  } else {
+    return nullptr;
+  }
+
+  // If the range index is valid, we are returning a range.
+  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
+    valueRangeMemory[rangeIndex] = values;
+    return &valueRangeMemory[rangeIndex];
+  }
+
+  // If a range index wasn't provided, the range is required to be non-variadic.
+  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
+}
+
+void ByteCodeExecutor::executeGetOperands() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
+  unsigned index = read<uint32_t>();
+  Operation *op = read<Operation *>();
+  ByteCodeField rangeIndex = read();
+
+  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
+      op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
+      valueRangeMemory);
+  if (!result)
+    LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
+  memory[read()] = result;
+}
+
 void ByteCodeExecutor::executeGetResult(unsigned index) {
   Operation *op = read<Operation *>();
   unsigned memIndex = read();
@@ -1068,6 +1498,20 @@ void ByteCodeExecutor::executeGetResult(unsigned index) {
   memory[memIndex] = result.getAsOpaquePointer();
 }
 
+void ByteCodeExecutor::executeGetResults() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
+  unsigned index = read<uint32_t>();
+  Operation *op = read<Operation *>();
+  ByteCodeField rangeIndex = read();
+
+  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
+      op->getResults(), op, index, rangeIndex, "result_segment_sizes",
+      valueRangeMemory);
+  if (!result)
+    LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
+  memory[read()] = result;
+}
+
 void ByteCodeExecutor::executeGetValueType() {
   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
   unsigned memIndex = read();
@@ -1079,6 +1523,28 @@ void ByteCodeExecutor::executeGetValueType() {
   memory[memIndex] = type.getAsOpaquePointer();
 }
 
+void ByteCodeExecutor::executeGetValueRangeTypes() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
+  unsigned memIndex = read();
+  unsigned rangeIndex = read();
+  ValueRange *values = read<ValueRange *>();
+  if (!values) {
+    LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
+    memory[memIndex] = nullptr;
+    return;
+  }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "  * Values (" << values->size() << "): ";
+    llvm::interleaveComma(*values, llvm::dbgs());
+    llvm::dbgs() << "\n  * Result: ";
+    llvm::interleaveComma(values->getType(), llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+  typeRangeMemory[rangeIndex] = values->getType();
+  memory[memIndex] = &typeRangeMemory[rangeIndex];
+}
+
 void ByteCodeExecutor::executeIsNotNull() {
   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
   const void *value = read<const void *>();
@@ -1117,7 +1583,30 @@ void ByteCodeExecutor::executeRecordMatch(
   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
                           << "  * Location: " << matchLoc << "\n");
   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
-  readList<const void *>(matches.back().values);
+  PDLByteCode::MatchResult &match = matches.back();
+
+  // Record all of the inputs to the match. If any of the inputs are ranges, we
+  // will also need to remap the range pointer to memory stored in the match
+  // state.
+  unsigned numInputs = read();
+  match.values.reserve(numInputs);
+  match.typeRangeValues.reserve(numInputs);
+  match.valueRangeValues.reserve(numInputs);
+  for (unsigned i = 0; i < numInputs; ++i) {
+    switch (read<PDLValue::Kind>()) {
+    case PDLValue::Kind::TypeRange:
+      match.typeRangeValues.push_back(*read<TypeRange *>());
+      match.values.push_back(&match.typeRangeValues.back());
+      break;
+    case PDLValue::Kind::ValueRange:
+      match.valueRangeValues.push_back(*read<ValueRange *>());
+      match.values.push_back(&match.valueRangeValues.back());
+      break;
+    default:
+      match.values.push_back(read<const void *>());
+      break;
+    }
+  }
   curCodeIt = dest;
 }
 
@@ -1125,7 +1614,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
   Operation *op = read<Operation *>();
   SmallVector<Value, 16> args;
-  readList<Value>(args);
+  readValueList(args);
 
   LLVM_DEBUG({
     llvm::dbgs() << "  * Operation: " << *op << "\n"
@@ -1198,6 +1687,19 @@ void ByteCodeExecutor::executeSwitchType() {
   handleSwitch(value, cases);
 }
 
+/// Execute a SwitchTypes: compare the input type range against each case (an
+/// ArrayAttr of TypeAttrs) via `handleSwitch`. A null type range selects jump
+/// index 0, presumably the default destination.
+void ByteCodeExecutor::executeSwitchTypes() {
+  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
+  TypeRange *value = read<TypeRange *>();
+  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
+  if (!value) {
+    // Use the same "  * " bullet prefix as the other executors' trace output.
+    LLVM_DEBUG(llvm::dbgs() << "  * Types: <NULL>\n");
+    return selectJump(size_t(0));
+  }
+  // The lambda parameter is named `types` so it does not shadow `value`.
+  handleSwitch(*value, cases,
+               [](ArrayAttr caseValue, const TypeRange &types) {
+                 return types == caseValue.getAsValueRange<TypeAttr>();
+               });
+}
+
 void ByteCodeExecutor::execute(
     PatternRewriter &rewriter,
     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
@@ -1214,6 +1716,9 @@ void ByteCodeExecutor::execute(
     case AreEqual:
       executeAreEqual();
       break;
+    case AreRangesEqual:
+      executeAreRangesEqual();
+      break;
     case Branch:
       executeBranch();
       break;
@@ -1226,9 +1731,15 @@ void ByteCodeExecutor::execute(
     case CheckResultCount:
       executeCheckResultCount();
       break;
+    case CheckTypes:
+      executeCheckTypes();
+      break;
     case CreateOperation:
       executeCreateOperation(rewriter, *mainRewriteLoc);
       break;
+    case CreateTypes:
+      executeCreateTypes();
+      break;
     case EraseOp:
       executeEraseOp(rewriter);
       break;
@@ -1257,6 +1768,9 @@ void ByteCodeExecutor::execute(
       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
       executeGetOperand(read<uint32_t>());
       break;
+    case GetOperands:
+      executeGetOperands();
+      break;
     case GetResult0:
     case GetResult1:
     case GetResult2:
@@ -1270,9 +1784,15 @@ void ByteCodeExecutor::execute(
       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
       executeGetResult(read<uint32_t>());
       break;
+    case GetResults:
+      executeGetResults();
+      break;
     case GetValueType:
       executeGetValueType();
       break;
+    case GetValueRangeTypes:
+      executeGetValueRangeTypes();
+      break;
     case IsNotNull:
       executeIsNotNull();
       break;
@@ -1299,6 +1819,9 @@ void ByteCodeExecutor::execute(
     case SwitchType:
       executeSwitchType();
       break;
+    case SwitchTypes:
+      executeSwitchTypes();
+      break;
     }
     LLVM_DEBUG(llvm::dbgs() << "\n");
   }
@@ -1313,9 +1836,12 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
   state.memory[0] = op;
 
   // The matcher function always starts at code address 0.
-  ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
-                            matcherByteCode, state.currentPatternBenefits,
-                            patterns, constraintFunctions, rewriteFunctions);
+  ByteCodeExecutor executor(
+      matcherByteCode.data(), state.memory, state.typeRangeMemory,
+      state.allocatedTypeRangeMemory, state.valueRangeMemory,
+      state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
+      state.currentPatternBenefits, patterns, constraintFunctions,
+      rewriteFunctions);
   executor.execute(rewriter, &matches);
 
   // Order the found matches by benefit.
@@ -1332,9 +1858,11 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
   // memory buffer.
   llvm::copy(match.values, state.memory.begin());
 
-  ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
-                            state.memory, uniquedData, rewriterByteCode,
-                            state.currentPatternBenefits, patterns,
-                            constraintFunctions, rewriteFunctions);
+  ByteCodeExecutor executor(
+      &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
+      state.typeRangeMemory, state.allocatedTypeRangeMemory,
+      state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
+      rewriterByteCode, state.currentPatternBenefits, patterns,
+      constraintFunctions, rewriteFunctions);
   executor.execute(rewriter, /*matches=*/nullptr, match.location);
 }

diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index f6a3bcbe54f9..c6f41be768de 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -25,8 +25,7 @@ namespace detail {
 class PDLByteCode;
 
 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode
-/// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of
-/// indices into the bytecode. Correctness is checked with static asserts.
+/// entries. ByteCodeAddr refers to size of indices into the bytecode.
 using ByteCodeField = uint16_t;
 using ByteCodeAddr = uint32_t;
 
@@ -62,14 +61,16 @@ class PDLByteCodePattern : public Pattern {
 /// threads/drivers.
 class PDLByteCodeMutableState {
 public:
-  /// Initialize the state from a bytecode instance.
-  void initialize(PDLByteCode &bytecode);
-
   /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
   /// to the position of the pattern within the range returned by
   /// `PDLByteCode::getPatterns`.
   void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
 
+  /// Cleanup any allocated state after a match/rewrite has been completed. This
+  /// method should be called regardless of whether the match+rewrite was a
+  /// success or not.
+  void cleanupAfterMatchAndRewrite();
+
 private:
   /// Allow access to data fields.
   friend class PDLByteCode;
@@ -78,6 +79,20 @@ class PDLByteCodeMutableState {
   /// of the bytecode.
   std::vector<const void *> memory;
 
+  /// A mutable block of memory used during the matching and rewriting phase of
+  /// the bytecode to store ranges of types.
+  std::vector<TypeRange> typeRangeMemory;
+  /// A set of type ranges that have been allocated by the byte code interpreter
+  /// to provide a guaranteed lifetime.
+  std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
+
+  /// A mutable block of memory used during the matching and rewriting phase of
+  /// the bytecode to store ranges of values.
+  std::vector<ValueRange> valueRangeMemory;
+  /// A set of value ranges that have been allocated by the byte code
+  /// interpreter to provide a guaranteed lifetime.
+  std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
+
   /// The up-to-date benefits of the patterns held by the bytecode. The order
   /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
   std::vector<PatternBenefit> currentPatternBenefits;
@@ -98,11 +113,19 @@ class PDLByteCode {
     MatchResult(Location loc, const PDLByteCodePattern &pattern,
                 PatternBenefit benefit)
         : location(loc), pattern(&pattern), benefit(benefit) {}
+    MatchResult(const MatchResult &) = delete;
+    MatchResult &operator=(const MatchResult &) = delete;
+    MatchResult(MatchResult &&other) = default;
+    MatchResult &operator=(MatchResult &&) = default;
 
     /// The location of operations to be replaced.
     Location location;
     /// Memory values defined in the matcher that are passed to the rewriter.
-    SmallVector<const void *, 4> values;
+    SmallVector<const void *> values;
+    /// Memory used for the range input values.
+    SmallVector<TypeRange, 0> typeRangeValues;
+    SmallVector<ValueRange, 0> valueRangeValues;
+
     /// The originating pattern that was matched. This is always non-null, but
     /// represented with a pointer to allow for assignment.
     const PDLByteCodePattern *pattern;
@@ -163,6 +186,10 @@ class PDLByteCode {
 
   /// The maximum memory index used by a value.
   ByteCodeField maxValueMemoryIndex = 0;
+
+  /// The maximum number of different types of ranges.
+  ByteCodeField maxTypeRangeCount = 0;
+  ByteCodeField maxValueRangeCount = 0;
 };
 
 } // end namespace detail

diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 6f5e1f299f26..5032f0203257 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -129,29 +129,40 @@ LogicalResult PatternApplicator::matchAndRewrite(
 
   // Process the patterns for that match the specific operation type, and any
   // operation type in an interleaved fashion.
-  auto opIt = opPatterns.begin(), opE = opPatterns.end();
-  auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
-  auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
-  while (true) {
+  unsigned opIt = 0, opE = opPatterns.size();
+  unsigned anyIt = 0, anyE = anyOpPatterns.size();
+  unsigned pdlIt = 0, pdlE = pdlMatches.size();
+  LogicalResult result = failure();
+  do {
     // Find the next pattern with the highest benefit.
     const Pattern *bestPattern = nullptr;
+    unsigned *bestPatternIt = &opIt;
     const PDLByteCode::MatchResult *pdlMatch = nullptr;
+
     /// Operation specific patterns.
-    if (opIt != opE)
-      bestPattern = *(opIt++);
+    if (opIt < opE)
+      bestPattern = opPatterns[opIt];
     /// Operation agnostic patterns.
-    if (anyIt != anyE &&
-        (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
-      bestPattern = *(anyIt++);
+    if (anyIt < anyE &&
+        (!bestPattern ||
+         bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
+      bestPatternIt = &anyIt;
+      bestPattern = anyOpPatterns[anyIt];
+    }
     /// PDL patterns.
-    if (pdlIt != pdlE &&
-        (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
-      pdlMatch = pdlIt;
-      bestPattern = (pdlIt++)->pattern;
+    if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
+                                             pdlMatches[pdlIt].benefit)) {
+      bestPatternIt = &pdlIt;
+      pdlMatch = &pdlMatches[pdlIt];
+      bestPattern = pdlMatch->pattern;
     }
     if (!bestPattern)
       break;
 
+    // Update the pattern iterator on failure so that this pattern isn't
+    // attempted again.
+    ++(*bestPatternIt);
+
     // Check that the pattern can be applied.
     if (canApply && !canApply(*bestPattern))
       continue;
@@ -160,19 +171,25 @@ LogicalResult PatternApplicator::matchAndRewrite(
     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
     // match has already been performed, we just need to rewrite.
     rewriter.setInsertionPoint(op);
-    LogicalResult result = success();
     if (pdlMatch) {
       bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+      result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
+
     } else {
-      result = static_cast<const RewritePattern *>(bestPattern)
-                   ->matchAndRewrite(op, rewriter);
+      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+      result = pattern->matchAndRewrite(op, rewriter);
+      if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
+        result = failure();
     }
-    if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
-      return success();
+    if (succeeded(result))
+      break;
 
     // Perform any necessary cleanups.
     if (onFailure)
       onFailure(*bestPattern);
-  }
-  return failure();
+  } while (true);
+
+  if (mutableByteCodeState)
+    mutableByteCodeState->cleanupAfterMatchAndRewrite();
+  return result;
 }

diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index b0acd328147a..d630fa2aa14d 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -40,6 +40,38 @@ module @ir attributes { test.apply_constraint_1 } {
 
 // -----
 
+// Check that value and type ranges can be passed to a native constraint. The
+// constraint is expected to reject `test.failure_op` and accept
+// `test.success_op`, so `test.failure_op` must remain in the output and the
+// replacement op must follow it.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %results = pdl_interp.get_results of %root : !pdl.range<value>
+    %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+    pdl_interp.apply_constraint "multi_entity_var_constraint"(%results, %types : !pdl.range<value>, !pdl.range<type>) -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern"
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_2
+// CHECK: "test.failure_op"
+// CHECK: "test.replaced_by_pattern"
+module @ir attributes { test.apply_constraint_2 } {
+  "test.failure_op"() { test_attr } : () -> ()
+  "test.success_op"() : () -> (i32, i64)
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::ApplyRewriteOp
 //===----------------------------------------------------------------------===//
@@ -103,6 +135,68 @@ module @ir attributes { test.apply_rewrite_2 } {
 
 // -----
 
+// Check that a native rewrite can return value and type ranges, and that the
+// returned ranges can be used to create an operation and to replace the root.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %operands, %types = pdl_interp.apply_rewrite "var_creator"(%root : !pdl.operation) : !pdl.range<value>, !pdl.range<type>
+      %op = pdl_interp.create_operation "test.success"(%operands : !pdl.range<value>) -> (%types : !pdl.range<type>)
+      pdl_interp.replace %root with (%operands : !pdl.range<value>)
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_rewrite_3
+// CHECK: %[[OPERAND:.*]] = "test.producer"
+// CHECK: "test.success"(%[[OPERAND]]) : (i32) -> i32
+// CHECK: "test.consumer"(%[[OPERAND]])
+module @ir attributes { test.apply_rewrite_3 } {
+  %first_operand = "test.producer"() : () -> (i32)
+  %operand = "test.op"(%first_operand) : (i32) -> (i32)
+  "test.consumer"(%operand) : (i32) -> ()
+}
+
+// -----
+
+// Check that a native rewrite called without arguments can return a single
+// type, which is then used as the result type of a created operation.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %type = pdl_interp.apply_rewrite "type_creator" : !pdl.type
+      %newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_rewrite_4
+// CHECK: "test.success"() : () -> f32
+module @ir attributes { test.apply_rewrite_4 } {
+  "test.op"() : () -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::AreEqualOp
 //===----------------------------------------------------------------------===//
@@ -137,6 +231,40 @@ module @ir attributes { test.are_equal_1 } {
 
 // -----
 
+// Check `are_equal` on type ranges. Only the op whose result types are exactly
+// [i32, i64] is replaced; the op with a single i32 result is left untouched.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %const_types = pdl_interp.create_types [i32, i64]
+    %results = pdl_interp.get_results of %root : !pdl.range<value>
+    %result_types = pdl_interp.get_value_type of %results : !pdl.range<type>
+    pdl_interp.are_equal %result_types, %const_types : !pdl.range<type> -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.are_equal_2
+// CHECK: "test.not_equal"
+// CHECK: "test.success"
+// CHECK-NOT: "test.op"
+module @ir attributes { test.are_equal_2 } {
+  "test.not_equal"() : () -> (i32)
+  "test.op"() : () -> (i32, i64)
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::BranchOp
 //===----------------------------------------------------------------------===//
@@ -211,7 +339,10 @@ module @ir attributes { test.check_attribute_1 } {
 
 module @patterns {
   func @matcher(%root : !pdl.operation) {
-    pdl_interp.check_operand_count of %root is 1 -> ^pat, ^end
+    pdl_interp.check_operand_count of %root is at_least 1 -> ^exact_check, ^end
+
+  ^exact_check:
+    pdl_interp.check_operand_count of %root is 2 -> ^pat, ^end
 
   ^pat:
     pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
@@ -234,7 +365,7 @@ module @patterns {
 // CHECK: "test.success"
 module @ir attributes { test.check_operand_count_1 } {
   %operand = "test.op"() : () -> i32
-  "test.op"(%operand) : (i32) -> ()
+  "test.op"(%operand, %operand) : (i32, i32) -> ()
 }
 
 // -----
@@ -277,7 +408,10 @@ module @ir attributes { test.check_operation_name_1 } {
 
 module @patterns {
   func @matcher(%root : !pdl.operation) {
-    pdl_interp.check_result_count of %root is 1 -> ^pat, ^end
+    pdl_interp.check_result_count of %root is at_least 1 -> ^exact_check, ^end
+
+  ^exact_check:
+    pdl_interp.check_result_count of %root is 2 -> ^pat, ^end
 
   ^pat:
     pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
@@ -296,9 +430,12 @@ module @patterns {
 }
 
 // CHECK-LABEL: test.check_result_count_1
+// CHECK: "test.op"() : () -> i32
 // CHECK: "test.success"() : () -> ()
+// CHECK-NOT: "test.op"() : () -> (i32, i32)
 module @ir attributes { test.check_result_count_1 } {
   "test.op"() : () -> i32
+  "test.op"() : () -> (i32, i32)
 }
 
 // -----
@@ -340,6 +477,43 @@ module @ir attributes { test.check_type_1 } {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypesOp
+//===----------------------------------------------------------------------===//
+
+// Check that `check_types` requires the whole type range to match. The op
+// with results (i32) is replaced; the op with results (i32, i64) is kept.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    %results = pdl_interp.get_results of %root : !pdl.range<value>
+    %result_types = pdl_interp.get_value_type of %results : !pdl.range<type>
+    pdl_interp.check_types %result_types are [i32] -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.check_types_1
+// CHECK: "test.op"() : () -> (i32, i64)
+// CHECK: "test.success"
+// CHECK-NOT: "test.op"() : () -> i32
+module @ir attributes { test.check_types_1 } {
+  "test.op"() : () -> (i32, i64)
+  "test.op"() : () -> i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::CreateAttributeOp
 //===----------------------------------------------------------------------===//
@@ -390,6 +564,12 @@ module @ir attributes { test.create_type_1 } {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypesOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::EraseOp
 //===----------------------------------------------------------------------===//
@@ -465,6 +645,110 @@ module @ir attributes { test.get_defining_op_1 } {
 
 // Fully tested within the tests for other operations.
 
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandsOp
+//===----------------------------------------------------------------------===//
+
+// Check that, for this two-operand op, `get_operands 0` returns the same range
+// as `get_operands` without an index.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end
+
+  ^pat1:
+    %operands = pdl_interp.get_operands 0 of %root : !pdl.range<value>
+    %full_operands = pdl_interp.get_operands of %root : !pdl.range<value>
+    pdl_interp.are_equal %operands, %full_operands : !pdl.range<value> -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.get_operands_1
+// CHECK: "test.success"
+module @ir attributes { test.get_operands_1 } {
+  %inputs:2 = "test.producer"() : () -> (i32, i32)
+  "test.op"(%inputs#0, %inputs#1) : (i32, i32) -> ()
+}
+
+// -----
+
+// Test all of the various combinations related to `AttrSizedOperandSegments`.
+// With `operand_segment_sizes = [0, 4, 1, 0]`, the matcher expects:
+//  * segment 0 (empty): a non-null range, but a null single value;
+//  * segment 1 (4 operands): a non-null range, but a null single value;
+//  * segment 2 (1 operand): non-null both as a range and as a single value;
+//  * an out-of-range index (50): a null single value.
+// The rewriter then creates one op from each recorded range or value.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.attr_sized_operands" -> ^pat1, ^end
+
+  ^pat1:
+    %operands_0 = pdl_interp.get_operands 0 of %root : !pdl.range<value>
+    pdl_interp.is_not_null %operands_0 : !pdl.range<value> -> ^pat2, ^end
+
+  ^pat2:
+    %operands_0_single = pdl_interp.get_operands 0 of %root : !pdl.value
+    pdl_interp.is_not_null %operands_0_single : !pdl.value -> ^end, ^pat3
+
+  ^pat3:
+    %operands_1 = pdl_interp.get_operands 1 of %root : !pdl.range<value>
+    pdl_interp.is_not_null %operands_1 : !pdl.range<value> -> ^pat4, ^end
+
+  ^pat4:
+    %operands_1_single = pdl_interp.get_operands 1 of %root : !pdl.value
+    pdl_interp.is_not_null %operands_1_single : !pdl.value -> ^end, ^pat5
+
+  ^pat5:
+    %operands_2 = pdl_interp.get_operands 2 of %root : !pdl.range<value>
+    pdl_interp.is_not_null %operands_2 : !pdl.range<value> -> ^pat6, ^end
+
+  ^pat6:
+    %operands_2_single = pdl_interp.get_operands 2 of %root : !pdl.value
+    pdl_interp.is_not_null %operands_2_single : !pdl.value -> ^pat7, ^end
+
+  ^pat7:
+    %invalid_operands = pdl_interp.get_operands 50 of %root : !pdl.value
+    pdl_interp.is_not_null %invalid_operands : !pdl.value -> ^end, ^pat8
+
+  ^pat8:
+    pdl_interp.record_match @rewriters::@success(%root, %operands_0, %operands_1, %operands_2, %operands_2_single : !pdl.operation, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.value) : benefit(1), loc([%root]) -> ^end
+
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root: !pdl.operation, %operands_0: !pdl.range<value>, %operands_1: !pdl.range<value>, %operands_2: !pdl.range<value>, %operands_2_single: !pdl.value) {
+      %op0 = pdl_interp.create_operation "test.success"(%operands_0 : !pdl.range<value>)
+      %op1 = pdl_interp.create_operation "test.success"(%operands_1 : !pdl.range<value>)
+      %op2 = pdl_interp.create_operation "test.success"(%operands_2 : !pdl.range<value>)
+      %op3 = pdl_interp.create_operation "test.success"(%operands_2_single : !pdl.value)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.get_operands_2
+// CHECK-NEXT:  %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32)
+// CHECK-NEXT:  "test.success"() : () -> ()
+// CHECK-NEXT:  "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> ()
+// CHECK-NEXT:  "test.success"(%[[INPUTS]]#4) : (i32) -> ()
+// CHECK-NEXT:  "test.success"(%[[INPUTS]]#4) : (i32) -> ()
+module @ir attributes { test.get_operands_2 } {
+  %inputs:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32)
+  "test.attr_sized_operands"(%inputs#0, %inputs#1, %inputs#2, %inputs#3, %inputs#4) {operand_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : (i32, i32, i32, i32, i32) -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::GetResultOp
 //===----------------------------------------------------------------------===//
@@ -506,6 +790,119 @@ module @ir attributes { test.get_result_1 } {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultsOp
+//===----------------------------------------------------------------------===//
+
+// Check that, for this five-result op, `get_results 0` returns the same range
+// as `get_results` without an index.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end
+
+  ^pat1:
+    %results = pdl_interp.get_results 0 of %root : !pdl.range<value>
+    %full_results = pdl_interp.get_results of %root : !pdl.range<value>
+    pdl_interp.are_equal %results, %full_results : !pdl.range<value> -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.get_results_1
+// CHECK: "test.success"
+module @ir attributes { test.get_results_1 } {
+  %a:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32)
+}
+
+// -----
+
+// Test all of the various combinations related to `AttrSizedResultSegments`.
+// With `result_segment_sizes = [0, 4, 1, 0]`, the matcher expects:
+//  * segment 0 (empty): a non-null range, but a null single value;
+//  * segment 1 (4 results): a non-null range, but a null single value;
+//  * segment 2 (1 result): non-null both as a range and as a single value;
+//  * an out-of-range index (50): a null single value.
+// The rewriter creates a `test.success` op from the types of each recorded
+// range or value. It then replaces the root with the combined results of the
+// ops built from the three ranges.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.attr_sized_results" -> ^pat1, ^end
+
+  ^pat1:
+    %results_0 = pdl_interp.get_results 0 of %root : !pdl.range<value>
+    pdl_interp.is_not_null %results_0 : !pdl.range<value> -> ^pat2, ^end
+
+  ^pat2:
+    %results_0_single = pdl_interp.get_results 0 of %root : !pdl.value
+    pdl_interp.is_not_null %results_0_single : !pdl.value -> ^end, ^pat3
+
+  ^pat3:
+    %results_1 = pdl_interp.get_results 1 of %root : !pdl.range<value>
+    pdl_interp.is_not_null %results_1 : !pdl.range<value> -> ^pat4, ^end
+
+  ^pat4:
+    %results_1_single = pdl_interp.get_results 1 of %root : !pdl.value
+    pdl_interp.is_not_null %results_1_single : !pdl.value -> ^end, ^pat5
+
+  ^pat5:
+    %results_2 = pdl_interp.get_results 2 of %root : !pdl.range<value>
+    pdl_interp.is_not_null %results_2 : !pdl.range<value> -> ^pat6, ^end
+
+  ^pat6:
+    %results_2_single = pdl_interp.get_results 2 of %root : !pdl.value
+    pdl_interp.is_not_null %results_2_single : !pdl.value -> ^pat7, ^end
+
+  ^pat7:
+    %invalid_results = pdl_interp.get_results 50 of %root : !pdl.value
+    pdl_interp.is_not_null %invalid_results : !pdl.value -> ^end, ^pat8
+
+  ^pat8:
+    pdl_interp.record_match @rewriters::@success(%root, %results_0, %results_1, %results_2, %results_2_single : !pdl.operation, !pdl.range<value>, !pdl.range<value>, !pdl.range<value>, !pdl.value) : benefit(1), loc([%root]) -> ^end
+
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root: !pdl.operation, %results_0: !pdl.range<value>, %results_1: !pdl.range<value>, %results_2: !pdl.range<value>, %results_2_single: !pdl.value) {
+      %results_0_types = pdl_interp.get_value_type of %results_0 : !pdl.range<type>
+      %results_1_types = pdl_interp.get_value_type of %results_1 : !pdl.range<type>
+      %results_2_types = pdl_interp.get_value_type of %results_2 : !pdl.range<type>
+      %results_2_single_types = pdl_interp.get_value_type of %results_2_single : !pdl.type
+
+      %op0 = pdl_interp.create_operation "test.success" -> (%results_0_types : !pdl.range<type>)
+      %op1 = pdl_interp.create_operation "test.success" -> (%results_1_types : !pdl.range<type>)
+      %op2 = pdl_interp.create_operation "test.success" -> (%results_2_types : !pdl.range<type>)
+      %op3 = pdl_interp.create_operation "test.success" -> (%results_2_single_types : !pdl.type)
+
+      %new_results_0 = pdl_interp.get_results of %op0 : !pdl.range<value>
+      %new_results_1 = pdl_interp.get_results of %op1 : !pdl.range<value>
+      %new_results_2 = pdl_interp.get_results of %op2 : !pdl.range<value>
+
+      pdl_interp.replace %root with (%new_results_0, %new_results_1, %new_results_2 : !pdl.range<value>, !pdl.range<value>, !pdl.range<value>)
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.get_results_2
+// CHECK: "test.success"() : () -> ()
+// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32)
+// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32
+// CHECK: %[[RESULTS_2_SINGLE:.*]] = "test.success"() : () -> i32
+// CHECK: "test.consumer"(%[[RESULTS_1]]#0, %[[RESULTS_1]]#1, %[[RESULTS_1]]#2, %[[RESULTS_1]]#3, %[[RESULTS_2]]) : (i32, i32, i32, i32, i32) -> ()
+module @ir attributes { test.get_results_2 } {
+  %results:5 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : () -> (i32, i32, i32, i32, i32)
+  "test.consumer"(%results#0, %results#1, %results#2, %results#3, %results#4) : (i32, i32, i32, i32, i32) -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::GetValueTypeOp
 //===----------------------------------------------------------------------===//
@@ -564,6 +961,43 @@ module @ir attributes { test.record_match_1 } {
 
 // -----
 
+// Check that value and type ranges passed to `record_match` are properly
+// forwarded to the rewriter, including when the root operation is not the
+// first recorded argument.
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end
+
+  ^pat1:
+    %operands = pdl_interp.get_operands of %root : !pdl.range<value>
+    %results = pdl_interp.get_results of %root : !pdl.range<value>
+    %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+    pdl_interp.record_match @rewriters::@success(%operands, %types, %root : !pdl.range<value>, !pdl.range<type>, !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%operands: !pdl.range<value>, %types: !pdl.range<type>, %root: !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"(%operands : !pdl.range<value>) -> (%types : !pdl.range<type>)
+      %results = pdl_interp.get_results of %op : !pdl.range<value>
+      pdl_interp.replace %root with (%results : !pdl.range<value>)
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.record_match_2
+// CHECK: %[[OPERAND:.*]] = "test.producer"() : () -> i32
+// CHECK: %[[RESULTS:.*]]:2 = "test.success"(%[[OPERAND]]) : (i32) -> (i64, i32)
+// CHECK: "test.consumer"(%[[RESULTS]]#0, %[[RESULTS]]#1) : (i64, i32) -> ()
+module @ir attributes { test.record_match_2 } {
+  %input = "test.producer"() : () -> i32
+  %results:2 = "test.op"(%input) : (i32) -> (i64, i32)
+  "test.consumer"(%results#0, %results#1) : (i64, i32) -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::ReplaceOp
 //===----------------------------------------------------------------------===//
@@ -780,3 +1214,40 @@ module @patterns {
 module @ir attributes { test.switch_type_1 } {
   "test.op"() { test_attr = 10 : i32 } : () -> ()
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypesOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {  // Exercises a matching case (^pat2) and a default destination (^pat3).
+  func @matcher(%root : !pdl.operation) {
+    %results = pdl_interp.get_results of %root : !pdl.range<value>
+    %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+    pdl_interp.switch_types %types to [[i64, i64], [i32]](^pat2, ^end) -> ^end  // (i64, i64) hits the first case.
+
+  ^pat2:
+    pdl_interp.switch_types %types to [[i32], [i64, i32]](^end, ^end) -> ^pat3  // No case matches (i64, i64): default.
+
+  ^pat3:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.create_operation "test.success"
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.switch_types_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_types_1 } {
+  %results:2 = "test.op"() : () -> (i64, i64)  // Result types the matcher switches on.
+}

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index e60022ba94cc..bc45d7b083aa 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -24,6 +24,18 @@ static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
                                                  PatternRewriter &rewriter) {
   return customSingleEntityConstraint(values[1], constantParams, rewriter);
 }
+/// Succeeds when the ValueRange values[0] and TypeRange values[1] have size 2.
+static LogicalResult customMultiEntityVariadicConstraint(
+    ArrayRef<PDLValue> values, ArrayAttr constantParams,
+    PatternRewriter &rewriter) {
+  // Fail, rather than index out of bounds, on missing or null arguments.
+  if (values.size() < 2 ||
+      llvm::any_of(values, [](const PDLValue &value) { return !value; }))
+    return failure();
+  ValueRange operandValues = values[0].cast<ValueRange>();
+  TypeRange typeValues = values[1].cast<TypeRange>();
+  return success(operandValues.size() == 2 && typeValues.size() == 2);
+}
 
 // Custom creator invoked from PDL.
 static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
@@ -31,6 +43,19 @@ static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
   results.push_back(rewriter.createOperation(
       OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
 }
+/// Pushes the operands of the operation in args[0], then their types.
+static void customVariadicResultCreate(
+    ArrayRef<PDLValue> args, ArrayAttr constantParams,
+    PatternRewriter &rewriter, PDLResultList &results) {
+  auto rootOperands = args[0].cast<Operation *>()->getOperands();
+  results.push_back(rootOperands);
+  results.push_back(rootOperands.getTypes());
+}
+static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+                             PatternRewriter &rewriter,
+                             PDLResultList &results) {
+  results.push_back(rewriter.getF32Type()); // Sole result: f32; args unused.
+}
 
 /// Custom rewriter invoked from PDL.
 static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
@@ -63,7 +88,12 @@ struct TestPDLByteCodePass
                                           customMultiEntityConstraint);
     pdlPattern.registerConstraintFunction("single_entity_constraint",
                                           customSingleEntityConstraint);
+    pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
+                                          customMultiEntityVariadicConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
+    pdlPattern.registerRewriteFunction("var_creator",
+                                       customVariadicResultCreate);
+    pdlPattern.registerRewriteFunction("type_creator", customCreateType);
     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
 
     OwningRewritePatternList patternList(std::move(pdlPattern));


        


More information about the Mlir-commits mailing list