[Mlir-commits] [mlir] 310c3ee - [mlir:PDL][NFC] Update PDL API to use prefixed accessors
River Riddle
llvmlistbot at llvm.org
Fri Sep 30 15:34:22 PDT 2022
Author: River Riddle
Date: 2022-09-30T15:27:10-07:00
New Revision: 310c3ee4724435464db36148a30c40aaf89bcc1d
URL: https://github.com/llvm/llvm-project/commit/310c3ee4724435464db36148a30c40aaf89bcc1d
DIFF: https://github.com/llvm/llvm-project/commit/310c3ee4724435464db36148a30c40aaf89bcc1d.diff
LOG: [mlir:PDL][NFC] Update PDL API to use prefixed accessors
This doesn't flip the switch for prefix generation yet, that'll be
done in a followup.
Added:
Modified:
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e521b5382b0a4..301fa68e59d03 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -591,7 +591,7 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
builder.setInsertionPointToEnd(currentBlock);
builder.create<pdl_interp::RecordMatchOp>(
pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
- rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
+ rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
failureBlockStack.back());
}
@@ -616,7 +616,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
// Prefer materializing constants directly when possible.
Operation *oldOp = oldValue.getDefiningOp();
if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
- if (Attribute value = attrOp.valueAttr()) {
+ if (Attribute value = attrOp.getValueAttr()) {
return newValue = builder.create<pdl_interp::CreateAttributeOp>(
attrOp.getLoc(), value);
}
@@ -643,11 +643,12 @@ SymbolRefAttr PatternLowering::generateRewriter(
// If this is a custom rewriter, simply dispatch to the registered rewrite
// method.
pdl::RewriteOp rewriter = pattern.getRewriter();
- if (StringAttr rewriteName = rewriter.nameAttr()) {
+ if (StringAttr rewriteName = rewriter.getNameAttr()) {
SmallVector<Value> args;
- if (rewriter.root())
- args.push_back(mapRewriteValue(rewriter.root()));
- auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
+ if (rewriter.getRoot())
+ args.push_back(mapRewriteValue(rewriter.getRoot()));
+ auto mappedArgs =
+ llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
builder.create<pdl_interp::ApplyRewriteOp>(
rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
@@ -679,10 +680,10 @@ void PatternLowering::generateRewriter(
pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 2> arguments;
- for (Value argument : rewriteOp.args())
+ for (Value argument : rewriteOp.getArgs())
arguments.push_back(mapRewriteValue(argument));
auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
- rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
+ rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
arguments);
for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
rewriteValues[std::get<0>(it)] = std::get<1>(it);
@@ -692,7 +693,7 @@ void PatternLowering::generateRewriter(
pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
- attrOp.getLoc(), attrOp.valueAttr());
+ attrOp.getLoc(), attrOp.getValueAttr());
rewriteValues[attrOp] = newAttr;
}
@@ -724,7 +725,7 @@ void PatternLowering::generateRewriter(
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
attributes, operationOp.getAttributeValueNames());
- rewriteValues[operationOp.op()] = createdOp;
+ rewriteValues[operationOp.getOp()] = createdOp;
// Generate accesses for any results that have their types constrained.
// Handle the case where there is a single range representing all of the
@@ -771,7 +772,7 @@ void PatternLowering::generateRewriter(
// If the replacement was another operation, get its results. `pdl` allows
// for using an operation for simplicitly, but the interpreter isn't as
// user facing.
- if (Value replOp = replaceOp.replOperation()) {
+ if (Value replOp = replaceOp.getReplOperation()) {
// Don't use replace if we know the replaced operation has no results.
auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
if (!opOp || !opOp.getTypeValues().empty()) {
@@ -779,7 +780,7 @@ void PatternLowering::generateRewriter(
replOp.getLoc(), mapRewriteValue(replOp)));
}
} else {
- for (Value operand : replaceOp.replValues())
+ for (Value operand : replaceOp.getReplValues())
replOperands.push_back(mapRewriteValue(operand));
}
@@ -800,15 +801,15 @@ void PatternLowering::generateRewriter(
function_ref<Value(Value)> mapRewriteValue) {
rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
resultOp.getLoc(), builder.getType<pdl::ValueType>(),
- mapRewriteValue(resultOp.parent()), resultOp.index());
+ mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
}
void PatternLowering::generateRewriter(
pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
- resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
- resultOp.index());
+ resultOp.getLoc(), resultOp.getType(),
+ mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
}
void PatternLowering::generateRewriter(
@@ -878,7 +879,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
// Look for an operation that was replaced by `op`. The result types will be
// inferred from the results that were replaced.
- for (OpOperand &use : op.op().getUses()) {
+ for (OpOperand &use : op.getOp().getUses()) {
// Check that the use corresponds to a ReplaceOp and that it is the
// replacement value, not the operation being replaced.
pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
@@ -907,9 +908,9 @@ void PatternLowering::generateOperationResultTypeRewriter(
if (resultTypeValues.empty())
return;
- // The verifier asserts that the result types of each pdl.operation can be
+ // The verifier asserts that the result types of each pdl.getOperation can be
// inferred. If we reach here, there is a bug either in the logic above or
- // in the verifier for pdl.operation.
+ // in the verifier for pdl.getOperation.
op->emitOpError() << "unable to infer result type for operation";
llvm_unreachable("unable to infer result type for operation");
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 8b6e4c280c0ac..422182e970242 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -55,7 +55,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
// If the attribute has a type or value, add a constraint.
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
- else if (Attribute value = attr.valueAttr())
+ else if (Attribute value = attr.getValueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
}
@@ -81,7 +81,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
builder.getType(pos));
})
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
- Optional<unsigned> index = op.index();
+ Optional<unsigned> index = op.getIndex();
// Prevent traversal into a null value if the result has a proper index.
if (index)
@@ -101,7 +101,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(resultPos, builder.getEqualTo(pos));
// Collect the predicates of the parent operation.
- getTreePredicates(predList, op.parent(), builder, inputs,
+ getTreePredicates(predList, op.getParent(), builder, inputs,
(Position *)parentPos);
});
}
@@ -253,7 +253,7 @@ static void getAttributePredicates(pdl::AttributeOp op,
Position *&attrPos = inputs[op];
if (attrPos)
return;
- Attribute value = op.valueAttr();
+ Attribute value = op.getValueAttr();
assert(value && "expected non-tree `pdl.attribute` to contain a value");
attrPos = builder.getAttributeLiteral(value);
}
@@ -262,7 +262,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
- OperandRange arguments = op.args();
+ OperandRange arguments = op.getArgs();
std::vector<Position *> allPositions;
allPositions.reserve(arguments.size());
@@ -273,7 +273,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
- builder.getConstraint(op.name(), allPositions);
+ builder.getConstraint(op.getName(), allPositions);
predList.emplace_back(pos, pred);
}
@@ -286,8 +286,8 @@ static void getResultPredicates(pdl::ResultOp op,
return;
// Ensure that the result isn't null.
- auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
- resultPos = builder.getResult(parentPos, op.index());
+ auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
+ resultPos = builder.getResult(parentPos, op.getIndex());
predList.emplace_back(resultPos, builder.getIsNotNull());
}
@@ -300,9 +300,9 @@ static void getResultPredicates(pdl::ResultsOp op,
return;
// Ensure that the result isn't null if the result has an index.
- auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
+ auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
bool isVariadic = op.getType().isa<pdl::RangeType>();
- Optional<unsigned> index = op.index();
+ Optional<unsigned> index = op.getIndex();
resultPos = builder.getResultGroup(parentPos, index, isVariadic);
if (index)
predList.emplace_back(resultPos, builder.getIsNotNull());
@@ -375,12 +375,12 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
for (Value operand : operationOp.getOperandValues())
TypeSwitch<Operation *>(operand.getDefiningOp())
.Case<pdl::ResultOp, pdl::ResultsOp>(
- [&used](auto resultOp) { used.insert(resultOp.parent()); });
+ [&used](auto resultOp) { used.insert(resultOp.getParent()); });
}
// Remove the specified root from the use set, so that we can
// always select it as a root, even if it is used by other operations.
- if (Value root = pattern.getRewriter().root())
+ if (Value root = pattern.getRewriter().getRoot())
used.erase(root);
// Finally, collect all the unused operations.
@@ -470,8 +470,8 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
entry.depth + 1);
})
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
- toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
- entry.depth);
+ toVisit.emplace(resultOp.getParent(), entry.value,
+ resultOp.getIndex(), entry.depth);
});
}
}
@@ -616,7 +616,7 @@ static Value buildPredicateList(pdl::PatternOp pattern,
// Solve the optimal branching problem for each candidate root, or use the
// provided one.
- Value bestRoot = pattern.getRewriter().root();
+ Value bestRoot = pattern.getRewriter().getRoot();
OptimalBranching::EdgeList bestEdges;
if (!bestRoot) {
unsigned bestCost = 0;
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 73c57d195495a..b96f34bcedb88 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -78,7 +78,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
visit(operand.getDefiningOp(), visited);
})
.Case<ResultOp, ResultsOp>([&visited](auto result) {
- visit(result.parent().getDefiningOp(), visited);
+ visit(result.getParent().getDefiningOp(), visited);
});
// Traverse the users.
@@ -112,7 +112,7 @@ LogicalResult ApplyNativeRewriteOp::verify() {
LogicalResult AttributeOp::verify() {
Value attrType = getValueType();
- Optional<Attribute> attrValue = value();
+ Optional<Attribute> attrValue = getValue();
if (!attrValue) {
if (isa<RewriteOp>((*this)->getParentOp()))
@@ -196,7 +196,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
// Check to see if the uses of the operation itself can be used to infer
// types.
- if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
+ if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
return success();
// Handle the case where the operation has no explicit result types.
@@ -402,7 +402,7 @@ StringRef PatternOp::getDefaultDialect() {
//===----------------------------------------------------------------------===//
LogicalResult ReplaceOp::verify() {
- if (replOperation() && !replValues().empty())
+ if (getReplOperation() && !getReplValues().empty())
return emitOpError() << "expected no replacement values to be provided"
" when the replacement operation is present";
return success();
@@ -430,7 +430,7 @@ static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
}
LogicalResult ResultsOp::verify() {
- if (!index() && getType().isa<pdl::ValueType>()) {
+ if (!getIndex() && getType().isa<pdl::ValueType>()) {
return emitOpError() << "expected `pdl.range<value>` result type when "
"no index is specified, but got: "
<< getType();
@@ -446,7 +446,7 @@ LogicalResult RewriteOp::verifyRegions() {
Region &rewriteRegion = getBodyRegion();
// Handle the case where the rewrite is external.
- if (name()) {
+ if (getName()) {
if (!rewriteRegion.empty()) {
return emitOpError()
<< "expected rewrite region to be empty when rewrite is external";
@@ -461,7 +461,7 @@ LogicalResult RewriteOp::verifyRegions() {
}
// Check that no additional arguments were provided.
- if (!externalArgs().empty()) {
+ if (!getExternalArgs().empty()) {
return emitOpError() << "expected no external arguments when the "
"rewrite is specified inline";
}
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index bf73be9cbff6d..4632533b5d43f 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -79,7 +79,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
int patternIndex = 0;
for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
// If the pattern has a name, use that. Otherwise, generate a unique name.
- if (Optional<StringRef> patternName = pattern.sym_name()) {
+ if (Optional<StringRef> patternName = pattern.getSymName()) {
patternNames.insert(patternName->str());
} else {
std::string name;
@@ -124,9 +124,9 @@ struct {0} : ::mlir::PDLPatternModule {{
};
pattern.walk([&](Operation *op) {
if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
- checkRegisterNativeFn(constraintOp.name(), "Constraint");
+ checkRegisterNativeFn(constraintOp.getName(), "Constraint");
else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
- checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
+ checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
});
os << " }\n};\n\n";
}
@@ -140,7 +140,7 @@ void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
module.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
- [&](auto op) { usedFns.insert(op.name()); });
+ [&](auto op) { usedFns.insert(op.getName()); });
});
for (const ast::Decl *decl : astModule.getChildren()) {
More information about the Mlir-commits
mailing list