[Mlir-commits] [mlir] 72fddfb - [mlir] Flip PDL to use Both accessors
River Riddle
llvmlistbot at llvm.org
Wed Sep 21 17:52:11 PDT 2022
Author: River Riddle
Date: 2022-09-21T17:36:13-07:00
New Revision: 72fddfb5993e99d6e1a6d204167b6eac4b7eb934
URL: https://github.com/llvm/llvm-project/commit/72fddfb5993e99d6e1a6d204167b6eac4b7eb934
DIFF: https://github.com/llvm/llvm-project/commit/72fddfb5993e99d6e1a6d204167b6eac4b7eb934.diff
LOG: [mlir] Flip PDL to use Both accessors
This allows the old API usages to be updated incrementally, without
needing to update everything at once. PDL will remain on the Both
(raw + prefixed) accessor mode temporarily, and will be flipped to
Prefixed once all API usages have been updated.
Differential Revision: https://reviews.llvm.org/D134387
Added:
Modified:
mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
mlir/python/mlir/dialects/_pdl_ops_ext.py
mlir/test/Dialect/PDL/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
index 545d87827a10e..3aa1d653c613b 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td
@@ -70,9 +70,8 @@ def PDL_Dialect : Dialect {
void registerTypes();
}];
- // FIXME: Prefixed accessors overlap with builtin Operation members. Flip
- // once resolved.
- let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
+ // FIXME: Flip to prefixed.
+ let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}
#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index f96a2cca526b7..c92cf4712cc05 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -122,10 +122,10 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
```
}];
- let arguments = (ins Optional<PDL_Type>:$type,
+ let arguments = (ins Optional<PDL_Type>:$valueType,
OptionalAttr<AnyAttr>:$value);
let results = (outs PDL_Attribute:$attr);
- let assemblyFormat = "(`:` $type^)? (`=` $value^)? attr-dict-with-keyword";
+ let assemblyFormat = "(`:` $valueType^)? (`=` $value^)? attr-dict-with-keyword";
let builders = [
OpBuilder<(ins CArg<"Value", "Value()">:$type), [{
@@ -156,8 +156,8 @@ def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> {
pdl.erase %root
```
}];
- let arguments = (ins PDL_Operation:$operation);
- let assemblyFormat = "$operation attr-dict";
+ let arguments = (ins PDL_Operation:$opValue);
+ let assemblyFormat = "$opValue attr-dict";
}
//===----------------------------------------------------------------------===//
@@ -187,9 +187,9 @@ def PDL_OperandOp
```
}];
- let arguments = (ins Optional<PDL_Type>:$type);
- let results = (outs PDL_Value:$val);
- let assemblyFormat = "(`:` $type^)? attr-dict";
+ let arguments = (ins Optional<PDL_Type>:$valueType);
+ let results = (outs PDL_Value:$value);
+ let assemblyFormat = "(`:` $valueType^)? attr-dict";
let builders = [
OpBuilder<(ins), [{
@@ -226,9 +226,9 @@ def PDL_OperandsOp
```
}];
- let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$type);
- let results = (outs PDL_RangeOf<PDL_Value>:$val);
- let assemblyFormat = "(`:` $type^)? attr-dict";
+ let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$valueType);
+ let results = (outs PDL_RangeOf<PDL_Value>:$value);
+ let assemblyFormat = "(`:` $valueType^)? attr-dict";
let builders = [
OpBuilder<(ins), [{
@@ -341,16 +341,16 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
```
}];
- let arguments = (ins OptionalAttr<StrAttr>:$name,
- Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands,
- Variadic<PDL_Attribute>:$attributes,
- StrArrayAttr:$attributeNames,
- Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types);
+ let arguments = (ins OptionalAttr<StrAttr>:$opName,
+ Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
+ Variadic<PDL_Attribute>:$attributeValues,
+ StrArrayAttr:$attributeValueNames,
+ Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
let results = (outs PDL_Operation:$op);
let assemblyFormat = [{
- ($name^)? (`(` $operands^ `:` type($operands) `)`)?
- custom<OperationOpAttributes>($attributes, $attributeNames)
- (`->` `(` $types^ `:` type($types) `)`)? attr-dict
+ ($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
+ custom<OperationOpAttributes>($attributeValues, $attributeValueNames)
+ (`->` `(` $typeValues^ `:` type($typeValues) `)`)? attr-dict
}];
let builders = [
@@ -413,9 +413,9 @@ def PDL_PatternOp : PDL_Op<"pattern", [
let arguments = (ins ConfinedAttr<I16Attr, [IntNonNegative]>:$benefit,
OptionalAttr<SymbolNameAttr>:$sym_name);
- let regions = (region SizedRegion<1>:$body);
+ let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat = [{
- ($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $body
+ ($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $bodyRegion
}];
let builders = [
@@ -467,11 +467,11 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
pdl.replace %root with %otherOp
```
}];
- let arguments = (ins PDL_Operation:$operation,
+ let arguments = (ins PDL_Operation:$opValue,
Optional<PDL_Operation>:$replOperation,
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$replValues);
let assemblyFormat = [{
- $operation `with` (`(` $replValues^ `:` type($replValues) `)`)?
+ $opValue `with` (`(` $replValues^ `:` type($replValues) `)`)?
($replOperation^)? attr-dict
}];
let hasVerifier = 1;
@@ -603,10 +603,10 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
let arguments = (ins Optional<PDL_Operation>:$root,
OptionalAttr<StrAttr>:$name,
Variadic<PDL_AnyType>:$externalArgs);
- let regions = (region AnyRegion:$body);
+ let regions = (region AnyRegion:$bodyRegion);
let assemblyFormat = [{
($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
- ($body^)?
+ ($bodyRegion^)?
attr-dict-with-keyword
}];
let hasRegionVerifier = 1;
@@ -635,9 +635,9 @@ def PDL_TypeOp : PDL_Op<"type"> {
```
}];
- let arguments = (ins OptionalAttr<TypeAttr>:$type);
+ let arguments = (ins OptionalAttr<TypeAttr>:$constantType);
let results = (outs PDL_Type:$result);
- let assemblyFormat = "attr-dict (`:` $type^)?";
+ let assemblyFormat = "attr-dict (`:` $constantType^)?";
let hasVerifier = 1;
}
@@ -664,9 +664,9 @@ def PDL_TypesOp : PDL_Op<"types"> {
```
}];
- let arguments = (ins OptionalAttr<TypeArrayAttr>:$types);
+ let arguments = (ins OptionalAttr<TypeArrayAttr>:$constantTypes);
let results = (outs PDL_RangeOf<PDL_Type>:$result);
- let assemblyFormat = "attr-dict (`:` $types^)?";
+ let assemblyFormat = "attr-dict (`:` $constantTypes^)?";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e4cce8d21f779..e521b5382b0a4 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -575,8 +575,9 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
// Collect the set of operations generated by the rewriter.
SmallVector<StringRef, 4> generatedOps;
- for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
- generatedOps.push_back(*op.name());
+ for (auto op :
+ pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
+ generatedOps.push_back(*op.getOpName());
ArrayAttr generatedOpsAttr;
if (!generatedOps.empty())
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
@@ -584,7 +585,7 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
// Grab the root kind if present.
StringAttr rootKindAttr;
if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
- if (Optional<StringRef> rootKind = rootOp.name())
+ if (Optional<StringRef> rootKind = rootOp.getOpName())
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
@@ -620,12 +621,12 @@ SymbolRefAttr PatternLowering::generateRewriter(
attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
- if (TypeAttr type = typeOp.typeAttr()) {
+ if (TypeAttr type = typeOp.getConstantTypeAttr()) {
return newValue = builder.create<pdl_interp::CreateTypeOp>(
typeOp.getLoc(), type);
}
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
- if (ArrayAttr type = typeOp.typesAttr()) {
+ if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
return newValue = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), type);
}
@@ -699,18 +700,18 @@ void PatternLowering::generateRewriter(
pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
- mapRewriteValue(eraseOp.operation()));
+ mapRewriteValue(eraseOp.getOpValue()));
}
void PatternLowering::generateRewriter(
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> operands;
- for (Value operand : operationOp.operands())
+ for (Value operand : operationOp.getOperandValues())
operands.push_back(mapRewriteValue(operand));
SmallVector<Value, 4> attributes;
- for (Value attr : operationOp.attributes())
+ for (Value attr : operationOp.getAttributeValues())
attributes.push_back(mapRewriteValue(attr));
bool hasInferredResultTypes = false;
@@ -721,14 +722,14 @@ void PatternLowering::generateRewriter(
// Create the new operation.
Location loc = operationOp.getLoc();
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
- loc, *operationOp.name(), types, hasInferredResultTypes, operands,
- attributes, operationOp.attributeNames());
+ loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
+ attributes, operationOp.getAttributeValueNames());
rewriteValues[operationOp.op()] = createdOp;
// Generate accesses for any results that have their types constrained.
// Handle the case where there is a single range representing all of the
// result types.
- OperandRange resultTys = operationOp.types();
+ OperandRange resultTys = operationOp.getTypeValues();
if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
@@ -772,8 +773,8 @@ void PatternLowering::generateRewriter(
// user facing.
if (Value replOp = replaceOp.replOperation()) {
// Don't use replace if we know the replaced operation has no results.
- auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
- if (!opOp || !opOp.types().empty()) {
+ auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
+ if (!opOp || !opOp.getTypeValues().empty()) {
replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
replOp.getLoc(), mapRewriteValue(replOp)));
}
@@ -784,13 +785,14 @@ void PatternLowering::generateRewriter(
// If there are no replacement values, just create an erase instead.
if (replOperands.empty()) {
- builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
- mapRewriteValue(replaceOp.operation()));
+ builder.create<pdl_interp::EraseOp>(
+ replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
return;
}
- builder.create<pdl_interp::ReplaceOp>(
- replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
+ builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
+ mapRewriteValue(replaceOp.getOpValue()),
+ replOperands);
}
void PatternLowering::generateRewriter(
@@ -814,7 +816,7 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
- if (TypeAttr typeAttr = typeOp.typeAttr()) {
+ if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
rewriteValues[typeOp] =
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
}
@@ -825,7 +827,7 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
- if (ArrayAttr typeAttr = typeOp.typesAttr()) {
+ if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), typeAttr);
}
@@ -840,7 +842,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// Try to handle resolution for each of the result types individually. This is
// preferred over type inferrence because it will allow for us to use existing
// types directly, as opposed to trying to rebuild the type list.
- OperandRange resultTypeValues = op.types();
+ OperandRange resultTypeValues = op.getTypeValues();
auto tryResolveResultTypes = [&] {
types.reserve(resultTypeValues.size());
for (const auto &it : llvm::enumerate(resultTypeValues)) {
@@ -886,7 +888,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// rewrites only have single block regions, so if the op isn't in the
// rewriter block (i.e. the current block of the operation) we already know
// it dominates (i.e. it's in the matcher).
- Value replOpVal = replOpUser.operation();
+ Value replOpVal = replOpUser.getOpValue();
Operation *replacedOp = replOpVal.getDefiningOp();
if (replacedOp->getBlock() == rewriterBlock &&
!replacedOp->isBeforeInBlock(op))
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index b77ea40b8725c..8b6e4c280c0ac 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -53,7 +53,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getIsNotNull());
// If the attribute has a type or value, add a constraint.
- if (Value type = attr.type())
+ if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.valueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
@@ -76,7 +76,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
predList.emplace_back(pos, builder.getIsNotNull());
- if (Value type = op.type())
+ if (Value type = op.getValueType())
getTreePredicates(predList, type, builder, inputs,
builder.getType(pos));
})
@@ -120,12 +120,12 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getIsNotNull());
// Check that this is the correct root operation.
- if (Optional<StringRef> opName = op.name())
+ if (Optional<StringRef> opName = op.getOpName())
predList.emplace_back(pos, builder.getOperationName(*opName));
// Check that the operation has the proper number of operands. If there are
// any variable length operands, we check a minimum instead of an exact count.
- OperandRange operands = op.operands();
+ OperandRange operands = op.getOperandValues();
unsigned minOperands = getNumNonRangeValues(operands);
if (minOperands != operands.size()) {
if (minOperands)
@@ -136,7 +136,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
// Check that the operation has the proper number of results. If there are
// any variable length results, we check a minimum instead of an exact count.
- OperandRange types = op.types();
+ OperandRange types = op.getTypeValues();
unsigned minResults = getNumNonRangeValues(types);
if (minResults == types.size())
predList.emplace_back(pos, builder.getResultCount(types.size()));
@@ -144,11 +144,11 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
// Recurse into any attributes, operands, or results.
- for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
+ for (auto [attrName, attr] :
+ llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
getTreePredicates(
- predList, std::get<1>(it), builder, inputs,
- builder.getAttribute(opPos,
- std::get<0>(it).cast<StringAttr>().getValue()));
+ predList, attr, builder, inputs,
+ builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
}
// Process the operands and results of the operation. For all values up to
@@ -208,10 +208,10 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
TypePosition *pos) {
// Check for a constraint on a constant type.
if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
- if (Attribute type = typeOp.typeAttr())
+ if (Attribute type = typeOp.getConstantTypeAttr())
predList.emplace_back(pos, builder.getTypeConstraint(type));
} else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
- if (Attribute typeAttr = typeOp.typesAttr())
+ if (Attribute typeAttr = typeOp.getConstantTypesAttr())
predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
}
}
@@ -327,7 +327,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
- for (Operation &op : pattern.body().getOps()) {
+ for (Operation &op : pattern.getBodyRegion().getOps()) {
TypeSwitch<Operation *>(&op)
.Case([&](pdl::AttributeOp attrOp) {
getAttributePredicates(attrOp, predList, builder, inputs);
@@ -340,11 +340,13 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
})
.Case([&](pdl::TypeOp typeOp) {
getTypePredicates(
- typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
+ typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
+ inputs);
})
.Case([&](pdl::TypesOp typeOp) {
getTypePredicates(
- typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
+ typeOp, [&] { return typeOp.getConstantTypesAttr(); }, builder,
+ inputs);
});
}
}
@@ -369,8 +371,8 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
// First, collect all the operations that are used as operands
// to other operations. These are not roots by default.
DenseSet<Value> used;
- for (auto operationOp : pattern.body().getOps<pdl::OperationOp>()) {
- for (Value operand : operationOp.operands())
+ for (auto operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
+ for (Value operand : operationOp.getOperandValues())
TypeSwitch<Operation *>(operand.getDefiningOp())
.Case<pdl::ResultOp, pdl::ResultsOp>(
[&used](auto resultOp) { used.insert(resultOp.parent()); });
@@ -383,7 +385,7 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
// Finally, collect all the unused operations.
SmallVector<Value> roots;
- for (Value operationOp : pattern.body().getOps<pdl::OperationOp>())
+ for (Value operationOp : pattern.getBodyRegion().getOps<pdl::OperationOp>())
if (!used.contains(operationOp))
roots.push_back(operationOp);
@@ -451,7 +453,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
// are expensive to join on.
TypeSwitch<Operation *>(entry.value.getDefiningOp())
.Case<pdl::OperationOp>([&](auto operationOp) {
- OperandRange operands = operationOp.operands();
+ OperandRange operands = operationOp.getOperandValues();
// Special case when we pass all the operands in one range.
// For those, the index is empty.
if (operands.size() == 1 &&
@@ -462,7 +464,8 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
}
// Default case: visit all the operands.
- for (const auto &p : llvm::enumerate(operationOp.operands()))
+ for (const auto &p :
+ llvm::enumerate(operationOp.getOperandValues()))
toVisit.emplace(p.value(), entry.value, p.index(),
entry.depth + 1);
})
@@ -507,7 +510,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
/// Returns true if the operand at the given index needs to be queried using an
/// operand group, i.e., if it is variadic itself or follows a variadic operand.
static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
- OperandRange operands = op.operands();
+ OperandRange operands = op.getOperandValues();
assert(index < operands.size() && "operand index out of range");
for (unsigned i = 0; i <= index; ++i)
if (operands[i].getType().isa<pdl::RangeType>())
@@ -537,7 +540,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
operandPos = builder.getAllOperands(opPos);
} else if (useOperandGroup(operationOp, *opIndex.index)) {
// We are querying an operand group.
- Type type = operationOp.operands()[*opIndex.index].getType();
+ Type type = operationOp.getOperandValues()[*opIndex.index].getType();
bool variadic = type.isa<pdl::RangeType>();
operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
} else {
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 9d1684c0b9752..73c57d195495a 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -74,7 +74,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
// Traverse the operands / parent.
TypeSwitch<Operation *>(op)
.Case<OperationOp>([&visited](auto operation) {
- for (Value operand : operation.operands())
+ for (Value operand : operation.getOperandValues())
visit(operand.getDefiningOp(), visited);
})
.Case<ResultOp, ResultsOp>([&visited](auto result) {
@@ -111,7 +111,7 @@ LogicalResult ApplyNativeRewriteOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AttributeOp::verify() {
- Value attrType = type();
+ Value attrType = getValueType();
Optional<Attribute> attrValue = value();
if (!attrValue) {
@@ -189,7 +189,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (!replOpUser || use.getOperandNumber() == 0)
return false;
// Make sure the replaced operation was defined before this one.
- Operation *replacedOp = replOpUser.operation().getDefiningOp();
+ Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
return replacedOp->getBlock() != rewriterBlock ||
replacedOp->isBeforeInBlock(op);
};
@@ -203,7 +203,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
if (resultTypes.empty()) {
// If we don't know the concrete operation, don't attempt any verification.
// We can't make assumptions if we don't know the concrete operation.
- Optional<StringRef> rawOpName = op.name();
+ Optional<StringRef> rawOpName = op.getOpName();
if (!rawOpName)
return success();
Optional<RegisteredOperationName> opName =
@@ -246,10 +246,12 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
isa<OperandOp, OperandsOp, OperationOp>(user);
};
if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
- if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
+ if (typeOp.getConstantType() ||
+ llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
} else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
- if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
+ if (typeOp.getConstantTypes() ||
+ llvm::any_of(typeOp->getUsers(), constrainsInput))
continue;
}
@@ -264,11 +266,11 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
LogicalResult OperationOp::verify() {
bool isWithinRewrite = isa<RewriteOp>((*this)->getParentOp());
- if (isWithinRewrite && !name())
+ if (isWithinRewrite && !getOpName())
return emitOpError("must have an operation name when nested within "
"a `pdl.rewrite`");
- ArrayAttr attributeNames = attributeNamesAttr();
- auto attributeValues = attributes();
+ ArrayAttr attributeNames = getAttributeValueNamesAttr();
+ auto attributeValues = getAttributeValues();
if (attributeNames.size() != attributeValues.size()) {
return emitOpError()
<< "expected the same number of attribute values and attribute "
@@ -280,7 +282,7 @@ LogicalResult OperationOp::verify() {
// If the operation is within a rewrite body and doesn't have type inference,
// ensure that the result types can be resolved.
if (isWithinRewrite && !mightHaveTypeInference()) {
- if (failed(verifyResultTypesAreInferrable(*this, types())))
+ if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
return failure();
}
@@ -288,7 +290,7 @@ LogicalResult OperationOp::verify() {
}
bool OperationOp::hasTypeInference() {
- if (Optional<StringRef> rawOpName = name()) {
+ if (Optional<StringRef> rawOpName = getOpName()) {
OperationName opName(*rawOpName, getContext());
return opName.hasInterface<InferTypeOpInterface>();
}
@@ -296,7 +298,7 @@ bool OperationOp::hasTypeInference() {
}
bool OperationOp::mightHaveTypeInference() {
- if (Optional<StringRef> rawOpName = name()) {
+ if (Optional<StringRef> rawOpName = getOpName()) {
OperationName opName(*rawOpName, getContext());
return opName.mightHaveInterface<InferTypeOpInterface>();
}
@@ -387,7 +389,7 @@ void PatternOp::build(OpBuilder &builder, OperationState &state,
/// Returns the rewrite operation of this pattern.
RewriteOp PatternOp::getRewriter() {
- return cast<RewriteOp>(body().front().getTerminator());
+ return cast<RewriteOp>(getBodyRegion().front().getTerminator());
}
/// The default dialect is `pdl`.
@@ -441,7 +443,7 @@ LogicalResult ResultsOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult RewriteOp::verifyRegions() {
- Region &rewriteRegion = body();
+ Region &rewriteRegion = getBodyRegion();
// Handle the case where the rewrite is external.
if (name()) {
@@ -477,7 +479,7 @@ StringRef RewriteOp::getDefaultDialect() {
//===----------------------------------------------------------------------===//
LogicalResult TypeOp::verify() {
- if (!typeAttr())
+ if (!getConstantTypeAttr())
return verifyHasBindingUse(*this);
return success();
}
@@ -487,7 +489,7 @@ LogicalResult TypeOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult TypesOp::verify() {
- if (!typesAttr())
+ if (!getConstantTypesAttr())
return verifyHasBindingUse(*this);
return success();
}
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index f708d4ffae41f..86a24e4e9b5fa 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -203,7 +203,7 @@ static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
pdl::RewriteOp rewrite =
builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
/*externalArgs=*/ValueRange());
- builder.createBlock(&rewrite.body());
+ builder.createBlock(&rewrite.getBodyRegion());
}
}
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
index bb63fe64dd035..428301b18f208 100644
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -86,14 +86,14 @@ class AttributeOp:
"""Specialization for PDL attribute op class."""
def __init__(self,
- type: Optional[Union[OpView, Operation, Value]] = None,
+ valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None):
- type = type if type is None else _get_value(type)
+ valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
- super().__init__(result, type=type, value=value, loc=loc, ip=ip)
+ super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
class EraseOp:
@@ -118,7 +118,7 @@ def __init__(self,
ip=None):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
- super().__init__(result, type=type, loc=loc, ip=ip)
+ super().__init__(result, valueType=type, loc=loc, ip=ip)
class OperandsOp:
@@ -131,7 +131,7 @@ def __init__(self,
ip=None):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
- super().__init__(result, type=types, loc=loc, ip=ip)
+ super().__init__(result, valueType=types, loc=loc, ip=ip)
class OperationOp:
@@ -147,15 +147,15 @@ def __init__(self,
ip=None):
name = name if name is None else _get_str_attr(name)
args = _get_values(args)
- attributeNames = []
- attributeValues = []
+ attrNames = []
+ attrValues = []
for attrName, attrValue in attributes.items():
- attributeNames.append(StringAttr.get(attrName))
- attributeValues.append(_get_value(attrValue))
- attributeNames = ArrayAttr.get(attributeNames)
+ attrNames.append(StringAttr.get(attrName))
+ attrValues.append(_get_value(attrValue))
+ attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
- super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
+ super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip)
class PatternOp:
@@ -255,24 +255,26 @@ class TypeOp:
"""Specialization for PDL type op class."""
def __init__(self,
- type: Optional[Union[TypeAttr, Type]] = None,
+ constantType: Optional[Union[TypeAttr, Type]] = None,
*,
loc=None,
ip=None):
- type = type if type is None else _get_type_attr(type)
+ constantType = constantType if constantType is None else _get_type_attr(
+ constantType)
result = pdl.TypeType.get()
- super().__init__(result, type=type, loc=loc, ip=ip)
+ super().__init__(result, constantType=constantType, loc=loc, ip=ip)
class TypesOp:
"""Specialization for PDL types op class."""
def __init__(self,
- types: Sequence[Union[TypeAttr, Type]] = [],
+ constantTypes: Sequence[Union[TypeAttr, Type]] = [],
*,
loc=None,
ip=None):
- types = _get_array_attr([_get_type_attr(ty) for ty in types])
- types = None if not types else types
+ constantTypes = _get_array_attr(
+ [_get_type_attr(ty) for ty in constantTypes])
+ constantTypes = None if not constantTypes else constantTypes
result = pdl.RangeType.get(pdl.TypeType.get())
- super().__init__(result, types=types, loc=loc, ip=ip)
+ super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index e93d34cb8f031..61c0aaeb69546 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -121,7 +121,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
// expected-error@below {{expected the same number of attribute values and attribute names, got 1 names and 0 values}}
%op = "pdl.operation"() {
- attributeNames = ["attr"],
+ attributeValueNames = ["attr"],
operand_segment_sizes = array<i32: 0, 0, 0>
} : () -> (!pdl.operation)
rewrite %op with "rewriter"
More information about the Mlir-commits
mailing list