[Mlir-commits] [mlir] 233e947 - [mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that create constants
River Riddle
llvmlistbot at llvm.org
Fri Dec 10 11:50:37 PST 2021
Author: River Riddle
Date: 2021-12-10T19:38:43Z
New Revision: 233e9476d8bec85b7711041d7c7536ac6f87e9ef
URL: https://github.com/llvm/llvm-project/commit/233e9476d8bec85b7711041d7c7536ac6f87e9ef
DIFF: https://github.com/llvm/llvm-project/commit/233e9476d8bec85b7711041d7c7536ac6f87e9ef.diff
LOG: [mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that create constants
This allows for passing in these attributes/types to constraints/rewrites as arguments.
Differential Revision: https://reviews.llvm.org/D114817
Added:
Modified:
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 6918b9681ddc7..931d12a6687a3 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -512,6 +512,13 @@ def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> {
let arguments = (ins TypeArrayAttr:$value);
let results = (outs PDL_RangeOf<PDL_Type>:$result);
let assemblyFormat = "$value attr-dict";
+
+ let builders = [
+ OpBuilder<(ins "ArrayAttr":$type), [{
+ build($_builder, $_state,
+ pdl::RangeType::get($_builder.getType<pdl::TypeType>()), type);
+ }]>
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 87dcd2723686f..7db7dc03dc80d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -237,10 +237,12 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
return val;
// Get the value for the parent position.
- Value parentVal = getValueAt(currentBlock, pos->getParent());
+ Value parentVal;
+ if (Position *parent = pos->getParent())
+ parentVal = getValueAt(currentBlock, pos->getParent());
// TODO: Use a location from the position.
- Location loc = parentVal.getLoc();
+ Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
builder.setInsertionPointToEnd(currentBlock);
Value value;
switch (pos->getKind()) {
@@ -331,6 +333,22 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
parentVal, resPos->getResultGroupNumber());
break;
}
+ case Predicates::AttributeLiteralPos: {
+ auto *attrPos = cast<AttributeLiteralPosition>(pos);
+ value =
+ builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
+ break;
+ }
+ case Predicates::TypeLiteralPos: {
+ auto *typePos = cast<TypeLiteralPosition>(pos);
+ Attribute rawTypeAttr = typePos->getValue();
+ if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
+ value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
+ else
+ value = builder.create<pdl_interp::CreateTypesOp>(
+ loc, rawTypeAttr.cast<ArrayAttr>());
+ break;
+ }
default:
llvm_unreachable("Generating unknown Position getter");
break;
@@ -353,7 +371,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
args = {getValueAt(currentBlock, equalToQuestion->getValue())};
} else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
- for (Position *position : std::get<1>(cstQuestion->getValue()))
+ for (Position *position : cstQuestion->getArgs())
args.push_back(getValueAt(currentBlock, position));
}
@@ -413,10 +431,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
break;
}
case Predicates::ConstraintQuestion: {
- auto value = cast<ConstraintQuestion>(question)->getValue();
+ auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(
- loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(),
- success, failure);
+ loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
+ failure);
break;
}
default:
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
index 8d560ebbcde80..8d6d8776bde32 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
@@ -21,7 +21,7 @@ Position::~Position() {}
unsigned Position::getOperationDepth() const {
if (const auto *operationPos = dyn_cast<OperationPosition>(this))
return operationPos->getDepth();
- return parent->getOperationDepth();
+ return parent ? parent->getOperationDepth() : 0;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 4a7dcdc696f48..266580bd41f59 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -50,6 +50,8 @@ enum Kind : unsigned {
ResultPos,
ResultGroupPos,
TypePos,
+ AttributeLiteralPos,
+ TypeLiteralPos,
// Questions, ordered by dependency and decreasing priority.
IsNotNullQuestion,
@@ -173,6 +175,16 @@ struct AttributePosition
StringAttr getName() const { return key.second; }
};
+//===----------------------------------------------------------------------===//
+// AttributeLiteralPosition
+
+/// A position describing a literal attribute.
+struct AttributeLiteralPosition
+ : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
+ Predicates::AttributeLiteralPos> {
+ using PredicateBase::PredicateBase;
+};
+
//===----------------------------------------------------------------------===//
// OperandPosition
@@ -317,6 +329,17 @@ struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
}
};
+//===----------------------------------------------------------------------===//
+// TypeLiteralPosition
+
+/// A position describing a literal type or type range. The value is stored as
+/// either a TypeAttr, or an ArrayAttr of TypeAttr.
+struct TypeLiteralPosition
+ : public PredicateBase<TypeLiteralPosition, Position, Attribute,
+ Predicates::TypeLiteralPos> {
+ using PredicateBase::PredicateBase;
+};
+
//===----------------------------------------------------------------------===//
// Qualifiers
//===----------------------------------------------------------------------===//
@@ -404,6 +427,17 @@ struct ConstraintQuestion
Predicates::ConstraintQuestion> {
using Base::Base;
+ /// Return the name of the constraint.
+ StringRef getName() const { return std::get<0>(key); }
+
+ /// Return the arguments of the constraint.
+ ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
+
+ /// Return the constant parameters of the constraint.
+ ArrayAttr getParams() const {
+ return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
+ }
+
/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
@@ -461,12 +495,14 @@ class PredicateUniquer : public StorageUniquer {
PredicateUniquer() {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
+ registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
registerParametricStorageType<ResultGroupPosition>();
registerParametricStorageType<TypePosition>();
+ registerParametricStorageType<TypeLiteralPosition>();
// Register the types of Questions with the uniquer.
registerParametricStorageType<AttributeAnswer>();
@@ -527,6 +563,11 @@ class PredicateBuilder {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
}
+ /// Returns an attribute position for the given attribute.
+ Position *getAttributeLiteral(Attribute attr) {
+ return AttributeLiteralPosition::get(uniquer, attr);
+ }
+
/// Returns an operand position for an operand of the given operation.
Position *getOperand(OperationPosition *p, unsigned operand) {
return OperandPosition::get(uniquer, p, operand);
@@ -558,6 +599,12 @@ class PredicateBuilder {
/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
+ /// Returns a type position for the given type value. The value is stored
+ /// as either a TypeAttr, or an ArrayAttr of TypeAttr.
+ Position *getTypeLiteral(Attribute attr) {
+ return TypeLiteralPosition::get(uniquer, attr);
+ }
+
//===--------------------------------------------------------------------===//
// Qualifiers
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 7e2219ca1416a..b22e50549dfd8 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -243,8 +243,18 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}
-/// Collect all of the predicates related to constraints within the given
-/// pattern operation.
+static void getAttributePredicates(pdl::AttributeOp op,
+ std::vector<PositionalPredicate> &predList,
+ PredicateBuilder &builder,
+ DenseMap<Value, Position *> &inputs) {
+ Position *&attrPos = inputs[op];
+ if (attrPos)
+ return;
+ Attribute value = op.valueAttr();
+ assert(value && "expected non-tree `pdl.attribute` to contain a value");
+ attrPos = builder.getAttributeLiteral(value);
+}
+
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
@@ -296,6 +306,19 @@ static void getResultPredicates(pdl::ResultsOp op,
predList.emplace_back(resultPos, builder.getIsNotNull());
}
+static void getTypePredicates(Value typeValue,
+ function_ref<Attribute()> typeAttrFn,
+ PredicateBuilder &builder,
+ DenseMap<Value, Position *> &inputs) {
+ Position *&typePos = inputs[typeValue];
+ if (typePos)
+ return;
+ Attribute typeAttr = typeAttrFn();
+ assert(typeAttr &&
+ "expected non-tree `pdl.type`/`pdl.types` to contain a value");
+ typePos = builder.getTypeLiteral(typeAttr);
+}
+
/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
@@ -304,11 +327,22 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
TypeSwitch<Operation *>(&op)
+ .Case([&](pdl::AttributeOp attrOp) {
+ getAttributePredicates(attrOp, predList, builder, inputs);
+ })
.Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
getConstraintPredicates(constraintOp, predList, builder, inputs);
})
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
getResultPredicates(resultOp, predList, builder, inputs);
+ })
+ .Case([&](pdl::TypeOp typeOp) {
+ getTypePredicates(
+ typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
+ })
+ .Case([&](pdl::TypesOp typeOp) {
+ getTypePredicates(
+ typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
});
}
}
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 81a8f60610bce..2363668618e5e 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -114,12 +114,15 @@ static LogicalResult verify(AttributeOp op) {
Value attrType = op.type();
Optional<Attribute> attrValue = op.value();
- if (!attrValue && isa<RewriteOp>(op->getParentOp()))
- return op.emitOpError("expected constant value when specified within a "
- "`pdl.rewrite`");
- if (attrValue && attrType)
+ if (!attrValue) {
+ if (isa<RewriteOp>(op->getParentOp()))
+ return op.emitOpError("expected constant value when specified within a "
+ "`pdl.rewrite`");
+ return verifyHasBindingUse(op);
+ }
+ if (attrType)
return op.emitOpError("expected only one of [`type`, `value`] to be set");
- return verifyHasBindingUse(op);
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -431,13 +434,21 @@ static LogicalResult verify(RewriteOp op) {
// pdl::TypeOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); }
+static LogicalResult verify(TypeOp op) {
+ if (!op.typeAttr())
+ return verifyHasBindingUse(op);
+ return success();
+}
//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); }
+static LogicalResult verify(TypesOp op) {
+ if (!op.typesAttr())
+ return verifyHasBindingUse(op);
+ return success();
+}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 0efcd60945a7b..984a31790a8bc 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -573,3 +573,42 @@ module @variadic_results_at {
pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
}
}
+
+// -----
+
+// CHECK-LABEL: module @attribute_literal
+module @attribute_literal {
+ // CHECK: func @matcher(%{{.*}}: !pdl.operation)
+ // CHECK: %[[ATTR:.*]] = pdl_interp.create_attribute 10 : i64
+ // CHECK: pdl_interp.apply_constraint "constraint"(%[[ATTR]] : !pdl.attribute)
+
+ // Check the correct lowering of an attribute that hasn't been bound.
+ pdl.pattern : benefit(1) {
+ %attr = pdl.attribute 10
+ pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute)
+
+ %root = pdl.operation
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @type_literal
+module @type_literal {
+ // CHECK: func @matcher(%{{.*}}: !pdl.operation)
+ // CHECK: %[[TYPE:.*]] = pdl_interp.create_type i32
+ // CHECK: %[[TYPES:.*]] = pdl_interp.create_types [i32, i64]
+ // CHECK: pdl_interp.apply_constraint "constraint"(%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)
+
+ // Check the correct lowering of a type that hasn't been bound.
+ pdl.pattern : benefit(1) {
+ %type = pdl.type : i32
+ %types = pdl.types : [i32, i64]
+ pdl.apply_native_constraint "constraint"(%type, %types: !pdl.type, !pdl.range<type>)
+
+ %root = pdl.operation
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
More information about the Mlir-commits mailing list