[Mlir-commits] [mlir] ce57789 - [mlir:PDL] Add support for creating ranges in rewrites
River Riddle
llvmlistbot at llvm.org
Tue Nov 8 01:58:26 PST 2022
Author: River Riddle
Date: 2022-11-08T01:57:57-08:00
New Revision: ce57789d8e5dc109dc9bd330232b31a22a80ad3a
URL: https://github.com/llvm/llvm-project/commit/ce57789d8e5dc109dc9bd330232b31a22a80ad3a
DIFF: https://github.com/llvm/llvm-project/commit/ce57789d8e5dc109dc9bd330232b31a22a80ad3a.diff
LOG: [mlir:PDL] Add support for creating ranges in rewrites
This commit adds support for building a concatenated range from
a given set of elements, either single elements or other ranges, within a
rewrite. We could conceptually extend this to support constraining
input ranges, but the logic there is quite a bit more complex, so it is
left for later work when a need arises.
Differential Revision: https://reviews.llvm.org/D133719
Added:
Modified:
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
mlir/test/Dialect/PDL/invalid.mlir
mlir/test/Dialect/PDLInterp/invalid.mlir
mlir/test/Rewrite/pdl-bytecode.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index fbe991a61a4cf..c85687e199b74 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -436,6 +436,48 @@ def PDL_PatternOp : PDL_Op<"pattern", [
let hasRegionVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// pdl::RangeOp
+//===----------------------------------------------------------------------===//
+
+def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> {
+ let summary = "Construct a range of pdl entities";
+ let description = [{
+ `pdl.range` operations construct a range from a given set of PDL entities,
+ which all share the same underlying element type. For example, a
+ `!pdl.range<value>` may be constructed from a list of `!pdl.value`
+ or `!pdl.range<value>` entities.
+
+ Example:
+
+ ```mlir
+ // Construct a range of values.
+ %valueRange = pdl.range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
+
+ // Construct a range of types.
+ %typeRange = pdl.range %inputType, %inputRange : !pdl.type, !pdl.range<type>
+
+ // Construct an empty range of types.
+ %emptyTypeRange = pdl.range : !pdl.range<type>
+ ```
+
+ TODO: Range construction is currently limited to rewrites, but it could
+ be extended to constraints under certain circumstances, i.e. if we can
+ determine how to extract the underlying elements. If we can't (e.g. if
+ there are multiple sub-ranges used for construction), we won't be able
+ to determine their sizes during constraint time.
+ }];
+
+ let arguments = (ins Variadic<PDL_AnyType>:$arguments);
+ let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
+ let assemblyFormat = [{
+ ($arguments^ `:` type($arguments))?
+ custom<RangeType>(ref(type($arguments)), type($result))
+ attr-dict
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
index 8cbe31fd2a6f0..a342dcc6233f5 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
@@ -28,6 +28,11 @@ class PDLType : public Type {
static bool classof(Type type);
};
+
+/// If the given type is a range, return its element type, otherwise return
+/// the type itself.
+Type getRangeElementTypeOrSelf(Type type);
+
} // namespace pdl
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 659bfbcac8605..96d631bd474a4 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -992,6 +992,43 @@ def PDLInterp_IsNotNullOp
let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateRangeOp : PDLInterp_Op<"create_range", [Pure]> {
+ let summary = "Construct a range of PDL entities";
+ let description = [{
+ `pdl_interp.create_range` operations construct a range from a given set of PDL
+ entities, which all share the same underlying element type. For example, a
+ `!pdl.range<value>` may be constructed from a list of `!pdl.value`
+ or `!pdl.range<value>` entities.
+
+ Example:
+
+ ```mlir
+ // Construct a range of values.
+ %valueRange = pdl_interp.create_range %inputValue, %inputRange : !pdl.value, !pdl.range<value>
+
+ // Construct a range of types.
+ %typeRange = pdl_interp.create_range %inputType, %inputRange : !pdl.type, !pdl.range<type>
+
+ // Construct an empty range of types.
+ %emptyTypeRange = pdl_interp.create_range : !pdl.range<type>
+ ```
+ }];
+
+ let arguments = (ins Variadic<PDL_AnyType>:$arguments);
+ let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result);
+ let assemblyFormat = [{
+ ($arguments^ `:` type($arguments))?
+ custom<RangeType>(ref(type($arguments)), type($result))
+ attr-dict
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// pdl_interp::RecordMatchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 987e7a36ea890..fdc95ab7a820a 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -89,6 +89,9 @@ struct PatternLowering {
void generateRewriter(pdl::OperationOp operationOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
+ void generateRewriter(pdl::RangeOp rangeOp,
+ DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ReplaceOp replaceOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
@@ -668,8 +671,8 @@ SymbolRefAttr PatternLowering::generateRewriter(
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
.Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
- pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
- pdl::TypeOp, pdl::TypesOp>([&](auto op) {
+ pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
+ pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
});
}
@@ -775,6 +778,16 @@ void PatternLowering::generateRewriter(
}
}
+void PatternLowering::generateRewriter(
+ pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue) {
+ SmallVector<Value, 4> replOperands;
+ for (Value operand : rangeOp.getArguments())
+ replOperands.push_back(mapRewriteValue(operand));
+ rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
+ rangeOp.getLoc(), rangeOp.getType(), replOperands);
+}
+
void PatternLowering::generateRewriter(
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index b96f34bcedb88..e33ba7153968e 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -397,6 +397,39 @@ StringRef PatternOp::getDefaultDialect() {
return PDLDialect::getDialectNamespace();
}
+//===----------------------------------------------------------------------===//
+// pdl::RangeOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
+ Type &resultType) {
+ // If arguments were provided, infer the result type from the argument list.
+ if (!argumentTypes.empty()) {
+ resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
+ return success();
+ }
+ // Otherwise, parse the type as a trailing type.
+ return p.parseColonType(resultType);
+}
+
+static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
+ Type resultType) {
+ if (argumentTypes.empty())
+ p << ": " << resultType;
+}
+
+LogicalResult RangeOp::verify() {
+ Type elementType = getType().getElementType();
+ for (Type operandType : getOperandTypes()) {
+ Type operandElementType = getRangeElementTypeOrSelf(operandType);
+ if (operandElementType != elementType) {
+ return emitOpError("expected operand to have element type ")
+ << elementType << ", but got " << operandElementType;
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
index f4dbbfb5b0506..49eee1afe0964 100644
--- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -59,6 +59,12 @@ bool PDLType::classof(Type type) {
return llvm::isa<PDLDialect>(type.getDialect());
}
+Type pdl::getRangeElementTypeOrSelf(Type type) {
+ if (auto rangeType = type.dyn_cast<RangeType>())
+ return rangeType.getElementType();
+ return type;
+}
+
//===----------------------------------------------------------------------===//
// RangeType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 01670e31d70a9..e8a61ef4c6a4d 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -237,6 +237,40 @@ static Type getGetValueTypeOpValueType(Type type) {
return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
}
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
+ Type &resultType) {
+ // If arguments were provided, infer the result type from the argument list.
+ if (!argumentTypes.empty()) {
+ resultType =
+ pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
+ return success();
+ }
+ // Otherwise, parse the type as a trailing type.
+ return p.parseColonType(resultType);
+}
+
+static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
+ TypeRange argumentTypes, Type resultType) {
+ if (argumentTypes.empty())
+ p << ": " << resultType;
+}
+
+LogicalResult CreateRangeOp::verify() {
+ Type elementType = getType().getElementType();
+ for (Type operandType : getOperandTypes()) {
+ Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
+ if (operandElementType != elementType) {
+ return emitOpError("expected operand to have element type ")
+ << elementType << ", but got " << operandElementType;
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// pdl_interp::SwitchAttributeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 9cc51da9fcf33..6b1dfb93a0c5f 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -99,10 +99,14 @@ enum OpCode : ByteCodeField {
CheckTypes,
/// Continue to the next iteration of a loop.
Continue,
+ /// Create a type range from a list of constant types.
+ CreateConstantTypeRange,
/// Create an operation.
CreateOperation,
- /// Create a range of types.
- CreateTypes,
+ /// Create a type range from a list of dynamic types.
+ CreateDynamicTypeRange,
+ /// Create a value range from a list of dynamic values.
+ CreateDynamicValueRange,
/// Erase an operation.
EraseOp,
/// Extract the op from a range at the specified index.
@@ -265,6 +269,7 @@ class Generator {
void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
@@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
- pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
- pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
- pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
+ pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
+ pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
+ pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
@@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
else
writer.appendPDLValueList(op.getInputResultTypes());
}
+void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
+ // Append the correct opcode for the range type.
+ TypeSwitch<Type>(op.getType().getElementType())
+ .Case(
+ [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
+ .Case([&](pdl::ValueType) {
+ writer.append(OpCode::CreateDynamicValueRange);
+ });
+
+ writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
+ writer.appendPDLValueList(op->getOperands());
+}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
- writer.append(OpCode::CreateTypes, op.getResult(),
+ writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
getRangeStorageIndex(op.getResult()), op.getValue());
}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
@@ -1103,9 +1120,11 @@ class ByteCodeExecutor {
void executeCheckResultCount();
void executeCheckTypes();
void executeContinue();
+ void executeCreateConstantTypeRange();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
- void executeCreateTypes();
+ template <typename T>
+ void executeDynamicCreateRange(StringRef type);
void executeEraseOp(PatternRewriter &rewriter);
template <typename T, typename Range, PDLValue::Kind kind>
void executeExtract();
@@ -1172,8 +1191,18 @@ class ByteCodeExecutor {
}
/// Read a list of values from the bytecode buffer. The values may be encoded
- /// as either Value or ValueRange elements.
- void readValueList(SmallVectorImpl<Value> &list) {
+ /// either as a single element or a range of elements.
+ void readList(SmallVectorImpl<Type> &list) {
+ for (unsigned i = 0, e = read(); i != e; ++i) {
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+ list.push_back(read<Type>());
+ } else {
+ TypeRange *values = read<TypeRange *>();
+ list.append(values->begin(), values->end());
+ }
+ }
+ }
+ void readList(SmallVectorImpl<Value> &list) {
for (unsigned i = 0, e = read(); i != e; ++i) {
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
list.push_back(read<Value>());
@@ -1292,6 +1321,39 @@ class ByteCodeExecutor {
return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
}
+ /// Assign the given range to the given memory index. This allocates a new
+ /// range object if necessary.
+ template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
+ void assignRangeToMemory(RangeT &&range, unsigned memIndex,
+ unsigned rangeIndex) {
+ // Utility functor used to type-erase the assignment.
+ auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
+ // If the input range is empty, we don't need to allocate anything.
+ if (range.empty()) {
+ rangeMemory[rangeIndex] = {};
+ } else {
+ // Allocate a buffer for this range.
+ llvm::OwningArrayRef<T> storage(llvm::size(range));
+ llvm::copy(range, storage.begin());
+
+ // Assign this to the range slot and use the range as the value for the
+ // memory index.
+ allocatedRangeMemory.emplace_back(std::move(storage));
+ rangeMemory[rangeIndex] = allocatedRangeMemory.back();
+ }
+ memory[memIndex] = &rangeMemory[rangeIndex];
+ };
+
+ // Dispatch based on the concrete range type.
+ if constexpr (std::is_same_v<T, Type>) {
+ return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
+ } else if constexpr (std::is_same_v<T, Value>) {
+ return assignRange(allocatedValueRangeMemory, valueRangeMemory);
+ } else {
+ llvm_unreachable("unhandled range type");
+ }
+ }
+
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
@@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() {
popCodeIt();
}
-void ByteCodeExecutor::executeCreateTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
+void ByteCodeExecutor::executeCreateConstantTypeRange() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
unsigned memIndex = read();
unsigned rangeIndex = read();
ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
-
- // Allocate a buffer for this type range.
- llvm::OwningArrayRef<Type> storage(typesAttr.size());
- llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
- allocatedTypeRangeMemory.emplace_back(std::move(storage));
-
- // Assign this to the range slot and use the range as the value for the
- // memory index.
- typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
- memory[memIndex] = &typeRangeMemory[rangeIndex];
+ assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
+ rangeIndex);
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
- readValueList(state.operands);
+ readList(state.operands);
for (unsigned i = 0, e = read(); i != e; ++i) {
StringAttr name = read<StringAttr>();
if (Attribute attr = read<Attribute>())
@@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
});
}
+template <typename T>
+void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
+ LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ SmallVector<T> values;
+ readList(values);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n * " << type << "s: ";
+ llvm::interleaveComma(values, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ assignRangeToMemory(values, memIndex, rangeIndex);
+}
+
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
Operation *op = read<Operation *>();
@@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
- readValueList(args);
+ readList(args);
LLVM_DEBUG({
llvm::dbgs() << " * Operation: " << *op << "\n"
@@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case Continue:
executeContinue();
break;
+ case CreateConstantTypeRange:
+ executeCreateConstantTypeRange();
+ break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
- case CreateTypes:
- executeCreateTypes();
+ case CreateDynamicTypeRange:
+ executeDynamicCreateRange<Type>("Type");
+ break;
+ case CreateDynamicValueRange:
+ executeDynamicCreateRange<Value>("Value");
break;
case EraseOp:
executeEraseOp(rewriter);
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index d6e8f4ab322e4..e5a84d69dcad9 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -243,3 +243,20 @@ module @unbound_rewrite_op {
}
// -----
+
+// CHECK-LABEL: module @range_op
+module @range_op {
+ // CHECK: module @rewriters
+ // CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value)
+ // CHECK: %[[RANGE1:.*]] = pdl_interp.create_range : !pdl.range<value>
+ // CHECK: %[[RANGE2:.*]] = pdl_interp.create_range %[[OPERAND]], %[[RANGE1]] : !pdl.value, !pdl.range<value>
+ // CHECK: pdl_interp.finalize
+ pdl.pattern : benefit(1) {
+ %operand = pdl.operand
+ %root = operation "foo.op"(%operand : !pdl.value)
+ rewrite %root {
+ %emptyRange = pdl.range : !pdl.range<value>
+ %range = pdl.range %operand, %emptyRange : !pdl.value, !pdl.range<value>
+ }
+ }
+}
diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index 61c0aaeb69546..522e9fbbe4a2c 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -237,6 +237,23 @@ pdl.pattern : benefit(1) {
// -----
+//===----------------------------------------------------------------------===//
+// pdl::RangeOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+ %operand = pdl.operand
+ %resultType = pdl.type
+ %root = pdl.operation "baz.op"(%operand : !pdl.value) -> (%resultType : !pdl.type)
+
+ rewrite %root {
+ // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
+ %range = pdl.range %operand, %resultType : !pdl.value, !pdl.type
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl::ResultsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/PDLInterp/invalid.mlir b/mlir/test/Dialect/PDLInterp/invalid.mlir
index f194d3246ffaf..0457a158430a2 100644
--- a/mlir/test/Dialect/PDLInterp/invalid.mlir
+++ b/mlir/test/Dialect/PDLInterp/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
//===----------------------------------------------------------------------===//
-// pdl::CreateOperationOp
+// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
pdl_interp.func @rewriter() {
@@ -23,3 +23,15 @@ pdl_interp.func @rewriter() {
} : (!pdl.type) -> (!pdl.operation)
pdl_interp.finalize
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+pdl_interp.func @rewriter(%value: !pdl.value, %type: !pdl.type) {
+ // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}}
+ %range = pdl_interp.create_range %value, %type : !pdl.value, !pdl.type
+ pdl_interp.finalize
+}
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index 20e2490c2de79..565874fe0ebce 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -568,6 +568,48 @@ module @ir attributes { test.create_op_infer_results } {
// -----
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateRangeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end
+
+ ^pat1:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ pdl_interp.func @success(%root: !pdl.operation) {
+ %rootOperand = pdl_interp.get_operand 0 of %root
+ %rootOperands = pdl_interp.get_operands of %root : !pdl.range<value>
+ %operandRange = pdl_interp.create_range %rootOperand, %rootOperands : !pdl.value, !pdl.range<value>
+
+ %operandType = pdl_interp.get_value_type of %rootOperand : !pdl.type
+ %operandTypes = pdl_interp.get_value_type of %rootOperands : !pdl.range<type>
+ %typeRange = pdl_interp.create_range %operandType, %operandTypes : !pdl.type, !pdl.range<type>
+
+ %op = pdl_interp.create_operation "test.success"(%operandRange : !pdl.range<value>) -> (%typeRange : !pdl.range<type>)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.create_range_1
+// CHECK: %[[INPUTS:.*]]:2 = "test.input"()
+// CHECK: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#0, %[[INPUTS]]#1) : (i32, i32, i32) -> (i32, i32, i32)
+module @ir attributes { test.create_range_1 } {
+ %values:2 = "test.input"() : () -> (i32, i32)
+ "test.op"(%values#0, %values#1) : (i32, i32) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::CreateTypeOp
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list