[Mlir-commits] [mlir] c80e6ed - Revert "[mlir][PDL] Add support for native constraints with results (#82760)"
Matthias Gehre
llvmlistbot at llvm.org
Thu Feb 29 22:44:58 PST 2024
Author: Matthias Gehre
Date: 2024-03-01T07:44:30+01:00
New Revision: c80e6edba4a9593f0587e27fa0ac825ebe174afd
URL: https://github.com/llvm/llvm-project/commit/c80e6edba4a9593f0587e27fa0ac825ebe174afd
DIFF: https://github.com/llvm/llvm-project/commit/c80e6edba4a9593f0587e27fa0ac825ebe174afd.diff
LOG: Revert "[mlir][PDL] Add support for native constraints with results (#82760)"
Due to buildbot failure https://lab.llvm.org/buildbot/#/builders/88/builds/72130
This reverts commit dca32a3b594b3c91f9766a9312b5d82534910fa1.
Added:
Modified:
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/IR/PDLPatternMatch.h.inc
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
mlir/test/Dialect/PDL/ops.mlir
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
mlir/test/mlir-pdll/Parser/constraint-failure.pdll
mlir/test/mlir-pdll/Parser/constraint.pdll
mlir/test/python/dialects/pdl_ops.py
Removed:
mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 1e108c3d8ac77a..4e9ebccba77d88 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -35,25 +35,20 @@ def PDL_ApplyNativeConstraintOp
let description = [{
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
has been registered externally with the consumer of PDL, to a given set of
- entities and optionally return a number of values.
+ entities.
Example:
```mlir
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
- // Apply constraint `with_result` to `root`. This constraint returns an attribute.
- %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
```
}];
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
- let results = (outs Variadic<PDL_AnyType>:$results);
- let assemblyFormat = [{
- $name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
- }];
+ let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 901acc0e6733bb..48f625bd1fa3fd 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -88,9 +88,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
- values.
- The constraint function may return any number of results.
- On success, this operation branches to the true destination,
+ values. On success, this operation branches to the true destination,
otherwise the false destination is taken. This behavior can be reversed
by setting the attribute `isNegated` to true.
@@ -106,10 +104,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
- let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
- $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
- `->` successors
+ $name `(` $args `:` type($args) `)` attr-dict `->` successors
}];
}
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index 66286ed7a4c898..a215da8cb6431d 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -318,9 +318,8 @@ protected:
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
-using PDLConstraintFunction = std::function<LogicalResult(
- PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
-
+using PDLConstraintFunction =
+ std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
@@ -727,7 +726,7 @@ std::enable_if_t<
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
- PatternRewriter &rewriter, PDLResultList &,
+ PatternRewriter &rewriter,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -843,13 +842,10 @@ public:
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
- /// * `LogicalResult (PatternRewriter &,
- /// PDLResultList &,
- /// ArrayRef<PDLValue>)`
+ /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
- /// the low-level PDLValue form, and the results are manually appended to
- /// the given result list.
+ /// the low-level PDLValue form.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
@@ -964,8 +960,8 @@ public:
}
};
class PDLResultList {};
-using PDLConstraintFunction = std::function<LogicalResult(
- PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction =
+ std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index b00cd0dee3ae80..e911631a4bc52a 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -50,8 +50,7 @@ struct PatternLowering {
/// Generate interpreter operations for the tree rooted at the given matcher
/// node, in the specified region.
- Block *generateMatcher(MatcherNode &node, Region &region,
- Block *block = nullptr);
+ Block *generateMatcher(MatcherNode &node, Region &region);
/// Get or create an access to the provided positional value in the current
/// block. This operation may mutate the provided block pointer if nested
@@ -149,10 +148,6 @@ struct PatternLowering {
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
-
- /// A mapping from a constraint question to the ApplyConstraintOp
- /// that implements it.
- DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
};
} // namespace
@@ -187,11 +182,9 @@ void PatternLowering::lower(ModuleOp module) {
firstMatcherBlock->erase();
}
-Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
- Block *block) {
+Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
// Push a new scope for the values used by this matcher.
- if (!block)
- block = &region.emplaceBlock();
+ Block *block = &region.emplaceBlock();
ValueMapScope scope(values);
// If this is the return node, simply insert the corresponding interpreter
@@ -371,15 +364,6 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
loc, cast<ArrayAttr>(rawTypeAttr));
break;
}
- case Predicates::ConstraintResultPos: {
- // Due to the order of traversal, the ApplyConstraintOp has already been
- // created and we can find it in constraintOpMap.
- auto *constrResPos = cast<ConstraintPosition>(pos);
- auto i = constraintOpMap.find(constrResPos->getQuestion());
- assert(i != constraintOpMap.end());
- value = i->second->getResult(constrResPos->getIndex());
- break;
- }
default:
llvm_unreachable("Generating unknown Position getter");
break;
@@ -406,11 +390,12 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
args.push_back(getValueAt(currentBlock, position));
}
- // Generate a new block as success successor and get the failure successor.
- Block *success = &region->emplaceBlock();
+ // Generate the matcher in the current (potentially nested) region
+ // and get the failure successor.
+ Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
Block *failure = failureBlockStack.back();
- // Create the predicate.
+ // Finally, create the predicate.
builder.setInsertionPointToEnd(currentBlock);
Predicates::Kind kind = question->getKind();
switch (kind) {
@@ -462,20 +447,14 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
- auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
- loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
- cstQuestion->getIsNegated(), success, failure);
-
- constraintOpMap.insert({cstQuestion, applyConstraintOp});
+ builder.create<pdl_interp::ApplyConstraintOp>(
+ loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
+ failure);
break;
}
default:
llvm_unreachable("Generating unknown Predicate operation");
}
-
- // Generate the matcher in the current (potentially nested) region.
- // This might use the results of the current predicate.
- generateMatcher(*boolNode->getSuccessNode(), *region, success);
}
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 5ad2c477573a5b..2c9b63f86d6efa 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -47,7 +47,6 @@ enum Kind : unsigned {
OperandPos,
OperandGroupPos,
AttributePos,
- ConstraintResultPos,
ResultPos,
ResultGroupPos,
TypePos,
@@ -280,28 +279,6 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
bool isOperandDefiningOp() const;
};
-//===----------------------------------------------------------------------===//
-// ConstraintPosition
-
-struct ConstraintQuestion;
-
-/// A position describing the result of a native constraint. It saves the
-/// corresponding ConstraintQuestion and result index to enable referring
-/// back to them
-struct ConstraintPosition
- : public PredicateBase<ConstraintPosition, Position,
- std::pair<ConstraintQuestion *, unsigned>,
- Predicates::ConstraintResultPos> {
- using PredicateBase::PredicateBase;
-
- /// Returns the ConstraintQuestion to enable keeping track of the native
- /// constraint this position stems from.
- ConstraintQuestion *getQuestion() const { return key.first; }
-
- // Returns the result index of this position
- unsigned getIndex() const { return key.second; }
-};
-
//===----------------------------------------------------------------------===//
// ResultPosition
@@ -470,13 +447,11 @@ struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};
-/// Apply a parameterized constraint to multiple position values and possibly
-/// produce results.
+/// Apply a parameterized constraint to multiple position values.
struct ConstraintQuestion
- : public PredicateBase<
- ConstraintQuestion, Qualifier,
- std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
- Predicates::ConstraintQuestion> {
+ : public PredicateBase<ConstraintQuestion, Qualifier,
+ std::tuple<StringRef, ArrayRef<Position *>, bool>,
+ Predicates::ConstraintQuestion> {
using Base::Base;
/// Return the name of the constraint.
@@ -485,19 +460,15 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
- /// Return the result types of the constraint.
- ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
-
/// Return the negation status of the constraint.
- bool getIsNegated() const { return std::get<3>(key); }
+ bool getIsNegated() const { return std::get<2>(key); }
/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
- alloc.copyInto(std::get<2>(key)),
- std::get<3>(key)});
+ std::get<2>(key)});
}
/// Returns a hash suitable for the given keytype.
@@ -555,7 +526,6 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
- registerParametricStorageType<ConstraintPosition>();
registerParametricStorageType<ForEachPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
@@ -618,12 +588,6 @@ class PredicateBuilder {
return OperationPosition::get(uniquer, p);
}
- // Returns a position for a new value created by a constraint.
- ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
- unsigned index) {
- return ConstraintPosition::get(uniquer, std::make_pair(q, index));
- }
-
/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -709,11 +673,11 @@ class PredicateBuilder {
}
/// Create a predicate that applies a generic constraint.
- Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
- ArrayRef<Type> resultTypes, bool isNegated) {
- return {ConstraintQuestion::get(
- uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
- TrueAnswer::get(uniquer)};
+ Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
+ bool isNegated) {
+ return {
+ ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
+ TrueAnswer::get(uniquer)};
}
/// Create a predicate comparing a value with null.
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index f3d0e08ef49b97..a9c3b0a71ef0d7 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,7 +15,6 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <queue>
@@ -50,15 +49,14 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
+ pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
- if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
- // If the attribute has a type or value, add a constraint.
- if (Value type = attr.getValueType())
- getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
- else if (Attribute value = attr.getValueAttr())
- predList.emplace_back(pos, builder.getAttributeConstraint(value));
- }
+ // If the attribute has a type or value, add a constraint.
+ if (Value type = attr.getValueType())
+ getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
+ else if (Attribute value = attr.getValueAttr())
+ predList.emplace_back(pos, builder.getAttributeConstraint(value));
}
/// Collect all of the predicates for the given operand position.
@@ -274,25 +272,8 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
- ResultRange results = op.getResults();
- PredicateBuilder::Predicate pred = builder.getConstraint(
- op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
- op.getIsNegated());
-
- // For each result register a position so it can be used later
- for (auto [i, result] : llvm::enumerate(results)) {
- ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
- ConstraintPosition *pos = builder.getConstraintPosition(q, i);
- auto [it, inserted] = inputs.insert({result, pos});
- // If this is an input value that has been visited in the tree, add a
- // constraint to ensure that both instances refer to the same value.
- if (!inserted) {
- auto minMaxPositions =
- std::minmax<Position *>(pos, it->second, comparePosDepth);
- predList.emplace_back(minMaxPositions.second,
- builder.getEqualTo(minMaxPositions.first));
- }
- }
+ PredicateBuilder::Predicate pred =
+ builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
predList.emplace_back(pos, pred);
}
@@ -894,49 +875,6 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
*root = std::make_unique<ExitNode>();
}
-/// Sorts the range begin/end with the partial order given by cmp.
-template <typename Iterator, typename Compare>
-static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
- while (begin != end) {
- // Cannot compute sortBeforeOthers in the predicate of stable_partition
- // because stable_partition will not keep the [begin, end) range intact
- // while it runs.
- llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
- for (auto i = begin; i != end; ++i) {
- if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
- sortBeforeOthers.insert(*i);
- }
-
- auto const next = std::stable_partition(begin, end, [&](auto const &a) {
- return sortBeforeOthers.contains(a);
- });
- assert(next != begin && "not a partial ordering");
- begin = next;
- }
-}
-
-/// Returns true if 'b' depends on a result of 'a'.
-static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
- auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
- if (!cqa)
- return false;
-
- auto positionDependsOnA = [&](Position *p) {
- auto *cp = dyn_cast<ConstraintPosition>(p);
- return cp && cp->getQuestion() == cqa;
- };
-
- if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
- // Does any argument of b use a?
- return llvm::any_of(cqb->getArgs(), positionDependsOnA);
- }
- if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
- return positionDependsOnA(b->position) ||
- positionDependsOnA(equalTo->getValue());
- }
- return positionDependsOnA(b->position);
-}
-
/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
std::unique_ptr<MatcherNode>
@@ -1017,10 +955,6 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
return *lhs < *rhs;
});
- // Mostly keep the now established order, but also ensure that
- // ConstraintQuestions come after the results they use.
- stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
-
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;
for (OrderedPredicateList &list : lists)
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 428b19f4c74af8..d5f34679f06c60 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -94,12 +94,6 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
LogicalResult ApplyNativeConstraintOp::verify() {
if (getNumOperands() == 0)
return emitOpError("expected at least one argument");
- if (llvm::any_of(getResults(), [](OpResult result) {
- return isa<OperationType>(result.getType());
- })) {
- return emitOpError(
- "returning an operation from a constraint is not supported");
- }
return success();
}
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 559ce7f466a83d..6e6992dcdeea78 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,25 +769,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
- // Constraints that should return a value have to be registered as rewrites.
- // If a constraint and a rewrite of similar name are registered the
- // constraint takes precedence
+ assert(constraintToMemIndex.count(op.getName()) &&
+ "expected index for constraint function");
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
writer.appendPDLValueList(op.getArgs());
writer.append(ByteCodeField(op.getIsNegated()));
- ResultRange results = op.getResults();
- writer.append(ByteCodeField(results.size()));
- for (Value result : results) {
- // We record the expected kind of the result, so that we can provide extra
- // verification of the native rewrite function and handle the failure case
- // of constraints accordingly.
- writer.appendPDLValueKind(result);
-
- // Range results also need to append the range storage index.
- if (isa<pdl::RangeType>(result.getType()))
- writer.append(getRangeStorageIndex(result));
- writer.append(result);
- }
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -800,9 +786,11 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
ResultRange results = op.getResults();
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
- // We record the expected kind of the result, so that we
+ // In debug mode we also record the expected kind of the result, so that we
// can provide extra verification of the native rewrite function.
+#ifndef NDEBUG
writer.appendPDLValueKind(result);
+#endif
// Range results also need to append the range storage index.
if (isa<pdl::RangeType>(result.getType()))
@@ -1088,28 +1076,6 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
// ByteCode Execution
namespace {
-/// This class is an instantiation of the PDLResultList that provides access to
-/// the returned results. This API is not on `PDLResultList` to avoid
-/// overexposing access to information specific solely to the ByteCode.
-class ByteCodeRewriteResultList : public PDLResultList {
-public:
- ByteCodeRewriteResultList(unsigned maxNumResults)
- : PDLResultList(maxNumResults) {}
-
- /// Return the list of PDL results.
- MutableArrayRef<PDLValue> getResults() { return results; }
-
- /// Return the type ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
- return allocatedTypeRanges;
- }
-
- /// Return the value ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
- return allocatedValueRanges;
- }
-};
-
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
@@ -1186,9 +1152,6 @@ class ByteCodeExecutor {
void executeSwitchResultCount();
void executeSwitchType();
void executeSwitchTypes();
- void processNativeFunResults(ByteCodeRewriteResultList &results,
- unsigned numResults,
- LogicalResult &rewriteResult);
/// Pushes a code iterator to the stack.
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@@ -1262,8 +1225,6 @@ class ByteCodeExecutor {
return T::getFromOpaquePointer(pointer);
}
- void skip(size_t skipN) { curCodeIt += skipN; }
-
/// Jump to a specific successor based on a predicate value.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index.
@@ -1420,11 +1381,33 @@ class ByteCodeExecutor {
ArrayRef<PDLConstraintFunction> constraintFunctions;
ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
+
+/// This class is an instantiation of the PDLResultList that provides access to
+/// the returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+ ByteCodeRewriteResultList(unsigned maxNumResults)
+ : PDLResultList(maxNumResults) {}
+
+ /// Return the list of PDL results.
+ MutableArrayRef<PDLValue> getResults() { return results; }
+
+ /// Return the type ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+ return allocatedTypeRanges;
+ }
+
+ /// Return the value ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+ return allocatedValueRanges;
+ }
+};
} // namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
- ByteCodeField fun_idx = read();
+ const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
@@ -1439,29 +1422,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
llvm::interleaveComma(args, llvm::dbgs());
});
-
- ByteCodeField numResults = read();
- const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
- ByteCodeRewriteResultList results(numResults);
- LogicalResult rewriteResult = constraintFn(rewriter, results, args);
- ArrayRef<PDLValue> constraintResults = results.getResults();
- LLVM_DEBUG({
- if (succeeded(rewriteResult)) {
- llvm::dbgs() << " * Constraint succeeded\n";
- llvm::dbgs() << " * Results: ";
- llvm::interleaveComma(constraintResults, llvm::dbgs());
- llvm::dbgs() << "\n";
- } else {
- llvm::dbgs() << " * Constraint failed\n";
- }
- });
- assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
- "native PDL rewrite function succeeded but returned "
- "unexpected number of results");
- processNativeFunResults(results, numResults, rewriteResult);
-
- // Depending on the constraint jump to the proper destination.
- selectJump(isNegated != succeeded(rewriteResult));
+ // Invoke the constraint and jump to the proper destination.
+ selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
}
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1483,39 +1445,16 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
- processNativeFunResults(results, numResults, rewriteResult);
+ // Store the results in the bytecode memory.
+ for (PDLValue &result : results.getResults()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
- if (failed(rewriteResult)) {
- LLVM_DEBUG(llvm::dbgs() << " - Failed");
- return failure();
- }
- return success();
-}
+// In debug mode we also verify the expected kind of the result.
+#ifndef NDEBUG
+ assert(result.getKind() == read<PDLValue::Kind>() &&
+ "native PDL rewrite function returned an unexpected type of result");
+#endif
-void ByteCodeExecutor::processNativeFunResults(
- ByteCodeRewriteResultList &results, unsigned numResults,
- LogicalResult &rewriteResult) {
- // Store the results in the bytecode memory or handle missing results on
- // failure.
- for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
- PDLValue::Kind resultKind = read<PDLValue::Kind>();
-
- // Skip the according number of values on the buffer on failure and exit
- // early as there are no results to process.
- if (failed(rewriteResult)) {
- if (resultKind == PDLValue::Kind::TypeRange ||
- resultKind == PDLValue::Kind::ValueRange) {
- skip(2);
- } else {
- skip(1);
- }
- return;
- }
- PDLValue result = results.getResults()[resultIdx];
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
- assert(result.getKind() == resultKind &&
- "native PDL rewrite function returned an unexpected type of "
- "result");
// If the result is a range, we need to copy it over to the bytecodes
// range memory.
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@@ -1537,6 +1476,13 @@ void ByteCodeExecutor::processNativeFunResults(
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
+
+ // Process the result of the rewrite.
+ if (failed(rewriteResult)) {
+ LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ return failure();
+ }
+ return success();
}
void ByteCodeExecutor::executeAreEqual() {
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 206781ed146692..97ff8bd0d8584d 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -1362,6 +1362,12 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
if (failed(parseToken(Token::semicolon,
"expected `;` after native declaration")))
return failure();
+ // TODO: PDL should be able to support constraint results in certain
+ // situations, we should revise this.
+ if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
+ return emitError(
+ "native Constraints currently do not support returning results");
+ }
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 92afb765b5ab4e..02bb8316c02db0 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -79,57 +79,6 @@ module @constraints {
// -----
-// CHECK-LABEL: module @constraint_with_result
-module @constraint_with_result {
- // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
- // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
- // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
- pdl.pattern : benefit(1) {
- %root = operation
- %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
- rewrite %root with "rewriter"(%attr : !pdl.attribute)
- }
-}
-
-// -----
-
-// CHECK-LABEL: module @constraint_with_unused_result
-module @constraint_with_unused_result {
- // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
- // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
- // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
- pdl.pattern : benefit(1) {
- %root = operation
- %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
- rewrite %root with "rewriter"
- }
-}
-
-// -----
-
-// CHECK-LABEL: module @constraint_with_result_multiple
-module @constraint_with_result_multiple {
- // check that native constraints work as expected even when multiple identical constraints are fused
-
- // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
- // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
- // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
- // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
- // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
- pdl.pattern : benefit(1) {
- %root = operation
- %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
- rewrite %root with "rewriter"(%attr : !pdl.attribute)
- }
- pdl.pattern : benefit(1) {
- %root = operation
- %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
- rewrite %root with "rewriter"(%attr : !pdl.attribute)
- }
-}
-
-// -----
-
// CHECK-LABEL: module @negated_constraint
module @negated_constraint {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
diff --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
deleted file mode 100644
index 4511baff2015a4..00000000000000
--- a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
+++ /dev/null
@@ -1,77 +0,0 @@
-// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
-
-// Ensuse that the dependency between add & less
-// causes them to be in the correct order.
-// CHECK-LABEL: matcher
-// CHECK: apply_constraint "return_attr_constraint"
-// CHECK: apply_constraint "use_attr_constraint"
-
-module {
- pdl.pattern : benefit(1) {
- %0 = attribute
- %1 = types
- %2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
- %3 = attribute = 0 : i32
- %4 = attribute = 1 : i32
- %5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
- apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
- rewrite %2 with "rewriter"
- }
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
-// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
- %inputOp = operation
- %result = result 0 of %inputOp
- %attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
- %root = operation(%result : !pdl.value) {"attr" = %attr}
- rewrite %root with "rewriter"(%attr : !pdl.attribute)
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
-// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
-// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
- %attr = attribute = 10
- %value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
- %root = operation(%value : !pdl.value)
- rewrite %root with "rewriter"
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
-// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
-// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
- %attr = attribute = 10
- %type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
- %root = operation -> (%type : !pdl.type)
- rewrite %root with "rewriter"
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
-// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
-// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
- %attr = attribute = 10
- %types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
- %root = operation -> (%types : !pdl.range<type>)
- rewrite %root with "rewriter"
-}
diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 20e40deea5f863..6e6da5cce446ae 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -134,24 +134,6 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
// -----
-pdl.pattern @apply_constraint_with_no_results : benefit(1) {
- %root = operation
- apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
- rewrite %root with "rewriter"
-}
-
-// -----
-
-pdl.pattern @apply_constraint_with_results : benefit(1) {
- %root = operation
- %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
- rewrite %root {
- apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
- }
-}
-
-// -----
-
pdl.pattern @attribute_with_dict : benefit(1) {
%root = operation
rewrite %root {
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index f8e4f2e83b296a..ae61c1a079548a 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -109,74 +109,6 @@ module @ir attributes { test.apply_constraint_3 } {
// -----
-// Test returning a type from a native constraint.
-module @patterns {
- pdl_interp.func @matcher(%root : !pdl.operation) {
- pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
-
- ^pat:
- %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
-
- ^pat2:
- pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
-
- ^end:
- pdl_interp.finalize
- }
-
- module @rewriters {
- pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
- %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
- pdl_interp.erase %root
- pdl_interp.finalize
- }
- }
-}
-
-// CHECK-LABEL: test.apply_constraint_4
-// CHECK-NOT: "test.replaced_by_pattern"
-// CHECK: "test.replaced_by_pattern"() : () -> f32
-module @ir attributes { test.apply_constraint_4 } {
- "test.failure_op"() : () -> ()
- "test.success_op"() : () -> ()
-}
-
-// -----
-
-// Test success and failure cases of native constraints with pdl.range results.
-module @patterns {
- pdl_interp.func @matcher(%root : !pdl.operation) {
- pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
-
- ^pat:
- %num_results = pdl_interp.create_attribute 2 : i32
- %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
-
- ^pat1:
- pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
-
- ^end:
- pdl_interp.finalize
- }
-
- module @rewriters {
- pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
- %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
- pdl_interp.erase %root
- pdl_interp.finalize
- }
- }
-}
-
-// CHECK-LABEL: test.apply_constraint_5
-// CHECK-NOT: "test.replaced_by_pattern"
-// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
-module @ir attributes { test.apply_constraint_5 } {
- "test.failure_op"() : () -> ()
- "test.success_op"() : () -> ()
-}
-
-// -----
//===----------------------------------------------------------------------===//
// pdl_interp::ApplyRewriteOp
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index b9424e06bf0318..50caf8f9cfc709 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -887,7 +887,7 @@ class TestTransformDialectExtension
#include "TestTransformDialectExtensionTypes.cpp.inc"
>();
- auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
+ auto verboseConstraint = [](PatternRewriter &rewriter,
ArrayRef<PDLValue> pdlValues) {
for (const PDLValue &pdlValue : pdlValues) {
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 56af3c15b905f1..daa1c371f27c92 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -30,50 +30,6 @@ static LogicalResult customMultiEntityVariadicConstraint(
return success();
}
-// Custom constraint that returns a value if the op is named test.success_op
-static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
- PDLResultList &results,
- ArrayRef<PDLValue> args) {
- auto *op = args[0].cast<Operation *>();
- if (op->getName().getStringRef() == "test.success_op") {
- StringAttr customAttr = rewriter.getStringAttr("test.success");
- results.push_back(customAttr);
- return success();
- }
- return failure();
-}
-
-// Custom constraint that returns a type if the op is named test.success_op
-static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
- PDLResultList &results,
- ArrayRef<PDLValue> args) {
- auto *op = args[0].cast<Operation *>();
- if (op->getName().getStringRef() == "test.success_op") {
- results.push_back(rewriter.getF32Type());
- return success();
- }
- return failure();
-}
-
-// Custom constraint that returns a type range of variable length if the op is
-// named test.success_op
-static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
- PDLResultList &results,
- ArrayRef<PDLValue> args) {
- auto *op = args[0].cast<Operation *>();
- int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
-
- if (op->getName().getStringRef() == "test.success_op") {
- SmallVector<Type> types;
- for (int i = 0; i < numTypes; i++) {
- types.push_back(rewriter.getF32Type());
- }
- results.push_back(TypeRange(types));
- return success();
- }
- return failure();
-}
-
// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -146,12 +102,6 @@ struct TestPDLByteCodePass
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
customMultiEntityVariadicConstraint);
- pdlPattern.registerConstraintFunction("op_constr_return_attr",
- customValueResultConstraint);
- pdlPattern.registerConstraintFunction("op_constr_return_type",
- customTypeResultConstraint);
- pdlPattern.registerConstraintFunction("op_constr_return_type_range",
- customTypeRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
index 48747d3fa2e681..18877b4bcc50ec 100644
--- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -158,3 +158,8 @@ Pattern {
// CHECK: expected `;` after native declaration
Constraint Foo() [{}]
+
+// -----
+
+// CHECK: native Constraints currently do not support returning results
+Constraint Foo() -> Op;
diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
index e2a52ff130cc84..1c0a015ab4a7b4 100644
--- a/mlir/test/mlir-pdll/Parser/constraint.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -12,14 +12,6 @@ Constraint Foo() [{ /* Native Code */ }];
// -----
-// Test that native constraints support returning results.
-
-// CHECK: Module
-// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
-Constraint Foo() -> Attr;
-
-// -----
-
// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
// CHECK: `Inputs`
diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
index 95cb25c14873db..0d364f9222a657 100644
--- a/mlir/test/python/dialects/pdl_ops.py
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -298,6 +298,6 @@ def test_apply_native_constraint():
pattern = PatternOp(1)
with InsertionPoint(pattern.body):
resultType = TypeOp()
- ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
+ ApplyNativeConstraintOp("typeConstraint", args=[resultType])
root = OperationOp(types=[resultType])
RewriteOp(root, name="rewrite")
More information about the Mlir-commits
mailing list