[Mlir-commits] [mlir] 310c3ee - [mlir:PDL][NFC] Update PDL API to use prefixed accessors

River Riddle llvmlistbot at llvm.org
Fri Sep 30 15:34:22 PDT 2022


Author: River Riddle
Date: 2022-09-30T15:27:10-07:00
New Revision: 310c3ee4724435464db36148a30c40aaf89bcc1d

URL: https://github.com/llvm/llvm-project/commit/310c3ee4724435464db36148a30c40aaf89bcc1d
DIFF: https://github.com/llvm/llvm-project/commit/310c3ee4724435464db36148a30c40aaf89bcc1d.diff

LOG: [mlir:PDL][NFC] Update PDL API to use prefixed accessors

This doesn't flip the switch for prefix generation yet; that will be
done in a follow-up.

Added: 
    

Modified: 
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e521b5382b0a4..301fa68e59d03 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -591,7 +591,7 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
   builder.setInsertionPointToEnd(currentBlock);
   builder.create<pdl_interp::RecordMatchOp>(
       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
-      rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
+      rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
       failureBlockStack.back());
 }
 
@@ -616,7 +616,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
     // Prefer materializing constants directly when possible.
     Operation *oldOp = oldValue.getDefiningOp();
     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
-      if (Attribute value = attrOp.valueAttr()) {
+      if (Attribute value = attrOp.getValueAttr()) {
         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
                    attrOp.getLoc(), value);
       }
@@ -643,11 +643,12 @@ SymbolRefAttr PatternLowering::generateRewriter(
   // If this is a custom rewriter, simply dispatch to the registered rewrite
   // method.
   pdl::RewriteOp rewriter = pattern.getRewriter();
-  if (StringAttr rewriteName = rewriter.nameAttr()) {
+  if (StringAttr rewriteName = rewriter.getNameAttr()) {
     SmallVector<Value> args;
-    if (rewriter.root())
-      args.push_back(mapRewriteValue(rewriter.root()));
-    auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
+    if (rewriter.getRoot())
+      args.push_back(mapRewriteValue(rewriter.getRoot()));
+    auto mappedArgs =
+        llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
     args.append(mappedArgs.begin(), mappedArgs.end());
     builder.create<pdl_interp::ApplyRewriteOp>(
         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
@@ -679,10 +680,10 @@ void PatternLowering::generateRewriter(
     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
   SmallVector<Value, 2> arguments;
-  for (Value argument : rewriteOp.args())
+  for (Value argument : rewriteOp.getArgs())
     arguments.push_back(mapRewriteValue(argument));
   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
-      rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
+      rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
       arguments);
   for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
     rewriteValues[std::get<0>(it)] = std::get<1>(it);
@@ -692,7 +693,7 @@ void PatternLowering::generateRewriter(
     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
-      attrOp.getLoc(), attrOp.valueAttr());
+      attrOp.getLoc(), attrOp.getValueAttr());
   rewriteValues[attrOp] = newAttr;
 }
 
@@ -724,7 +725,7 @@ void PatternLowering::generateRewriter(
   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
       loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
       attributes, operationOp.getAttributeValueNames());
-  rewriteValues[operationOp.op()] = createdOp;
+  rewriteValues[operationOp.getOp()] = createdOp;
 
   // Generate accesses for any results that have their types constrained.
   // Handle the case where there is a single range representing all of the
@@ -771,7 +772,7 @@ void PatternLowering::generateRewriter(
   // If the replacement was another operation, get its results. `pdl` allows
   // for using an operation for simplicitly, but the interpreter isn't as
   // user facing.
-  if (Value replOp = replaceOp.replOperation()) {
+  if (Value replOp = replaceOp.getReplOperation()) {
     // Don't use replace if we know the replaced operation has no results.
     auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
     if (!opOp || !opOp.getTypeValues().empty()) {
@@ -779,7 +780,7 @@ void PatternLowering::generateRewriter(
           replOp.getLoc(), mapRewriteValue(replOp)));
     }
   } else {
-    for (Value operand : replaceOp.replValues())
+    for (Value operand : replaceOp.getReplValues())
       replOperands.push_back(mapRewriteValue(operand));
   }
 
@@ -800,15 +801,15 @@ void PatternLowering::generateRewriter(
     function_ref<Value(Value)> mapRewriteValue) {
   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
-      mapRewriteValue(resultOp.parent()), resultOp.index());
+      mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
 }
 
 void PatternLowering::generateRewriter(
     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
-      resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
-      resultOp.index());
+      resultOp.getLoc(), resultOp.getType(),
+      mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
 }
 
 void PatternLowering::generateRewriter(
@@ -878,7 +879,7 @@ void PatternLowering::generateOperationResultTypeRewriter(
 
   // Look for an operation that was replaced by `op`. The result types will be
   // inferred from the results that were replaced.
-  for (OpOperand &use : op.op().getUses()) {
+  for (OpOperand &use : op.getOp().getUses()) {
     // Check that the use corresponds to a ReplaceOp and that it is the
     // replacement value, not the operation being replaced.
     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
@@ -907,9 +908,9 @@ void PatternLowering::generateOperationResultTypeRewriter(
   if (resultTypeValues.empty())
     return;
 
-  // The verifier asserts that the result types of each pdl.operation can be
+  // The verifier asserts that the result types of each pdl.getOperation can be
   // inferred. If we reach here, there is a bug either in the logic above or
-  // in the verifier for pdl.operation.
+  // in the verifier for pdl.getOperation.
   op->emitOpError() << "unable to infer result type for operation";
   llvm_unreachable("unable to infer result type for operation");
 }

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 8b6e4c280c0ac..422182e970242 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -55,7 +55,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
   // If the attribute has a type or value, add a constraint.
   if (Value type = attr.getValueType())
     getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
-  else if (Attribute value = attr.valueAttr())
+  else if (Attribute value = attr.getValueAttr())
     predList.emplace_back(pos, builder.getAttributeConstraint(value));
 }
 
@@ -81,7 +81,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
                             builder.getType(pos));
       })
       .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
-        Optional<unsigned> index = op.index();
+        Optional<unsigned> index = op.getIndex();
 
         // Prevent traversal into a null value if the result has a proper index.
         if (index)
@@ -101,7 +101,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
         predList.emplace_back(resultPos, builder.getEqualTo(pos));
 
         // Collect the predicates of the parent operation.
-        getTreePredicates(predList, op.parent(), builder, inputs,
+        getTreePredicates(predList, op.getParent(), builder, inputs,
                           (Position *)parentPos);
       });
 }
@@ -253,7 +253,7 @@ static void getAttributePredicates(pdl::AttributeOp op,
   Position *&attrPos = inputs[op];
   if (attrPos)
     return;
-  Attribute value = op.valueAttr();
+  Attribute value = op.getValueAttr();
   assert(value && "expected non-tree `pdl.attribute` to contain a value");
   attrPos = builder.getAttributeLiteral(value);
 }
@@ -262,7 +262,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
                                     std::vector<PositionalPredicate> &predList,
                                     PredicateBuilder &builder,
                                     DenseMap<Value, Position *> &inputs) {
-  OperandRange arguments = op.args();
+  OperandRange arguments = op.getArgs();
 
   std::vector<Position *> allPositions;
   allPositions.reserve(arguments.size());
@@ -273,7 +273,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
                                     comparePosDepth);
   PredicateBuilder::Predicate pred =
-      builder.getConstraint(op.name(), allPositions);
+      builder.getConstraint(op.getName(), allPositions);
   predList.emplace_back(pos, pred);
 }
 
@@ -286,8 +286,8 @@ static void getResultPredicates(pdl::ResultOp op,
     return;
 
   // Ensure that the result isn't null.
-  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
-  resultPos = builder.getResult(parentPos, op.index());
+  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
+  resultPos = builder.getResult(parentPos, op.getIndex());
   predList.emplace_back(resultPos, builder.getIsNotNull());
 }
 
@@ -300,9 +300,9 @@ static void getResultPredicates(pdl::ResultsOp op,
     return;
 
   // Ensure that the result isn't null if the result has an index.
-  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
+  auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
   bool isVariadic = op.getType().isa<pdl::RangeType>();
-  Optional<unsigned> index = op.index();
+  Optional<unsigned> index = op.getIndex();
   resultPos = builder.getResultGroup(parentPos, index, isVariadic);
   if (index)
     predList.emplace_back(resultPos, builder.getIsNotNull());
@@ -375,12 +375,12 @@ static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
     for (Value operand : operationOp.getOperandValues())
       TypeSwitch<Operation *>(operand.getDefiningOp())
           .Case<pdl::ResultOp, pdl::ResultsOp>(
-              [&used](auto resultOp) { used.insert(resultOp.parent()); });
+              [&used](auto resultOp) { used.insert(resultOp.getParent()); });
   }
 
   // Remove the specified root from the use set, so that we can
   // always select it as a root, even if it is used by other operations.
-  if (Value root = pattern.getRewriter().root())
+  if (Value root = pattern.getRewriter().getRoot())
     used.erase(root);
 
   // Finally, collect all the unused operations.
@@ -470,8 +470,8 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
                               entry.depth + 1);
           })
           .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
-            toVisit.emplace(resultOp.parent(), entry.value, resultOp.index(),
-                            entry.depth);
+            toVisit.emplace(resultOp.getParent(), entry.value,
+                            resultOp.getIndex(), entry.depth);
           });
     }
   }
@@ -616,7 +616,7 @@ static Value buildPredicateList(pdl::PatternOp pattern,
 
   // Solve the optimal branching problem for each candidate root, or use the
   // provided one.
-  Value bestRoot = pattern.getRewriter().root();
+  Value bestRoot = pattern.getRewriter().getRoot();
   OptimalBranching::EdgeList bestEdges;
   if (!bestRoot) {
     unsigned bestCost = 0;

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 73c57d195495a..b96f34bcedb88 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -78,7 +78,7 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
           visit(operand.getDefiningOp(), visited);
       })
       .Case<ResultOp, ResultsOp>([&visited](auto result) {
-        visit(result.parent().getDefiningOp(), visited);
+        visit(result.getParent().getDefiningOp(), visited);
       });
 
   // Traverse the users.
@@ -112,7 +112,7 @@ LogicalResult ApplyNativeRewriteOp::verify() {
 
 LogicalResult AttributeOp::verify() {
   Value attrType = getValueType();
-  Optional<Attribute> attrValue = value();
+  Optional<Attribute> attrValue = getValue();
 
   if (!attrValue) {
     if (isa<RewriteOp>((*this)->getParentOp()))
@@ -196,7 +196,7 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
 
   // Check to see if the uses of the operation itself can be used to infer
   // types.
-  if (llvm::any_of(op.op().getUses(), canInferTypeFromUse))
+  if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
     return success();
 
   // Handle the case where the operation has no explicit result types.
@@ -402,7 +402,7 @@ StringRef PatternOp::getDefaultDialect() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReplaceOp::verify() {
-  if (replOperation() && !replValues().empty())
+  if (getReplOperation() && !getReplValues().empty())
     return emitOpError() << "expected no replacement values to be provided"
                             " when the replacement operation is present";
   return success();
@@ -430,7 +430,7 @@ static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
 }
 
 LogicalResult ResultsOp::verify() {
-  if (!index() && getType().isa<pdl::ValueType>()) {
+  if (!getIndex() && getType().isa<pdl::ValueType>()) {
     return emitOpError() << "expected `pdl.range<value>` result type when "
                             "no index is specified, but got: "
                          << getType();
@@ -446,7 +446,7 @@ LogicalResult RewriteOp::verifyRegions() {
   Region &rewriteRegion = getBodyRegion();
 
   // Handle the case where the rewrite is external.
-  if (name()) {
+  if (getName()) {
     if (!rewriteRegion.empty()) {
       return emitOpError()
              << "expected rewrite region to be empty when rewrite is external";
@@ -461,7 +461,7 @@ LogicalResult RewriteOp::verifyRegions() {
   }
 
   // Check that no additional arguments were provided.
-  if (!externalArgs().empty()) {
+  if (!getExternalArgs().empty()) {
     return emitOpError() << "expected no external arguments when the "
                             "rewrite is specified inline";
   }

diff  --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index bf73be9cbff6d..4632533b5d43f 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -79,7 +79,7 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
   int patternIndex = 0;
   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
     // If the pattern has a name, use that. Otherwise, generate a unique name.
-    if (Optional<StringRef> patternName = pattern.sym_name()) {
+    if (Optional<StringRef> patternName = pattern.getSymName()) {
       patternNames.insert(patternName->str());
     } else {
       std::string name;
@@ -124,9 +124,9 @@ struct {0} : ::mlir::PDLPatternModule {{
   };
   pattern.walk([&](Operation *op) {
     if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
-      checkRegisterNativeFn(constraintOp.name(), "Constraint");
+      checkRegisterNativeFn(constraintOp.getName(), "Constraint");
     else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
-      checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
+      checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
   });
   os << "  }\n};\n\n";
 }
@@ -140,7 +140,7 @@ void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
   module.walk([&](Operation *op) {
     TypeSwitch<Operation *>(op)
         .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
-            [&](auto op) { usedFns.insert(op.name()); });
+            [&](auto op) { usedFns.insert(op.getName()); });
   });
 
   for (const ast::Decl *decl : astModule.getChildren()) {


        


More information about the Mlir-commits mailing list