[Mlir-commits] [mlir] 9e57210 - [mlir:PDLL] Add support for building a range from a tuple within a rewrite
River Riddle
llvmlistbot at llvm.org
Tue Nov 8 01:58:28 PST 2022
Author: River Riddle
Date: 2022-11-08T01:57:57-08:00
New Revision: 9e57210ad9f76b93bb8e9fc441be5cf3597b2928
URL: https://github.com/llvm/llvm-project/commit/9e57210ad9f76b93bb8e9fc441be5cf3597b2928
DIFF: https://github.com/llvm/llvm-project/commit/9e57210ad9f76b93bb8e9fc441be5cf3597b2928.diff
LOG: [mlir:PDLL] Add support for building a range from a tuple within a rewrite
This allows for constructing type and value ranges from various sub-elements,
which makes it easier to construct operations that take a range as an operand
or result type. Range construction is currently limited to within rewrites, to match
the current constraint on the PDL side.
Differential Revision: https://reviews.llvm.org/D133720
Added:
Modified:
mlir/include/mlir/Tools/PDLL/AST/Nodes.h
mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
mlir/lib/Tools/PDLL/AST/Nodes.cpp
mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
mlir/test/mlir-pdll/Parser/expr-failure.pdll
mlir/test/mlir-pdll/Parser/expr.pdll
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index 716b2be9f51ff..5f282b0c8884f 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -566,6 +566,40 @@ class OperationExpr final
}
};
+//===----------------------------------------------------------------------===//
+// RangeExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression builds a range from a set of element values (which may be
+/// ranges themselves).
+class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
+                        private llvm::TrailingObjects<RangeExpr, Expr *> {
+public:
+  /// Allocate a new RangeExpr within `ctx` producing a value of range type
+  /// `type`, whose element expressions are `elements`.
+  static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
+                           RangeType type);
+
+  /// Return the element expressions of this range, in order.
+  MutableArrayRef<Expr *> getElements() {
+    Expr **first = getTrailingObjects<Expr *>();
+    return MutableArrayRef<Expr *>(first, numElements);
+  }
+  ArrayRef<Expr *> getElements() const {
+    Expr *const *first = getTrailingObjects<Expr *>();
+    return ArrayRef<Expr *>(first, numElements);
+  }
+
+  /// Return the range type produced by this expression.
+  RangeType getType() const { return Base::getType().cast<RangeType>(); }
+
+private:
+  RangeExpr(SMRange loc, RangeType type, unsigned numElements)
+      : Base(loc, type), numElements(numElements) {}
+
+  /// Count of `Expr *` entries stored as trailing objects.
+  unsigned numElements;
+
+  /// Allow TrailingObjects to access this class's private members.
+  friend class llvm::TrailingObjects<RangeExpr, Expr *>;
+};
+
//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//
@@ -1284,7 +1318,7 @@ inline bool CoreConstraintDecl::classof(const Node *node) {
+/// Expr::classof accepts exactly the node kinds listed in the isa<> below, so
+/// every Expr subclass (such as RangeExpr) must be listed here to be
+/// recognized as an Expr.
inline bool Expr::classof(const Node *node) {
return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr,
- OperationExpr, TupleExpr, TypeExpr>(node);
+ OperationExpr, RangeExpr, TupleExpr, TypeExpr>(node);
}
inline bool OpRewriteStmt::classof(const Node *node) {
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
index 3b08f5413c9ae..cc27bb6cdfceb 100644
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -84,6 +84,7 @@ class NodePrinter {
void printImpl(const DeclRefExpr *expr);
void printImpl(const MemberAccessExpr *expr);
void printImpl(const OperationExpr *expr);
+ void printImpl(const RangeExpr *expr);
void printImpl(const TupleExpr *expr);
void printImpl(const TypeExpr *expr);
@@ -169,8 +170,8 @@ void NodePrinter::print(const Node *node) {
// Expressions.
const AttributeExpr, const CallExpr, const DeclRefExpr,
- const MemberAccessExpr, const OperationExpr, const TupleExpr,
- const TypeExpr,
+ const MemberAccessExpr, const OperationExpr, const RangeExpr,
+ const TupleExpr, const TypeExpr,
// Decls.
const AttrConstraintDecl, const OpConstraintDecl,
@@ -254,6 +255,14 @@ void NodePrinter::printImpl(const OperationExpr *expr) {
printChildren("Attributes", expr->getAttributes());
}
+/// Print the RangeExpr header line (node address and result range type),
+/// followed by each element expression as a child node.
+void NodePrinter::printImpl(const RangeExpr *expr) {
+  os << "RangeExpr " << expr << " Type<";
+  print(expr->getType());
+  os << ">\n";
+
+  auto elements = expr->getElements();
+  printChildren(elements);
+}
+
void NodePrinter::printImpl(const TupleExpr *expr) {
os << "TupleExpr " << expr << " Type<";
print(expr->getType());
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 57b52fde62617..0129f9fd1d0bc 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -57,8 +57,8 @@ class NodeVisitor {
// Expressions.
const AttributeExpr, const CallExpr, const DeclRefExpr,
- const MemberAccessExpr, const OperationExpr, const TupleExpr,
- const TypeExpr,
+ const MemberAccessExpr, const OperationExpr, const RangeExpr,
+ const TupleExpr, const TypeExpr,
// Core Constraint Decls.
const AttrConstraintDecl, const OpConstraintDecl,
@@ -109,6 +109,10 @@ class NodeVisitor {
for (const Node *child : expr->getAttributes())
visit(child);
}
+  /// Visit each element expression of the range, in order.
+  void visitImpl(const RangeExpr *expr) {
+    auto elements = expr->getElements();
+    for (unsigned i = 0, e = elements.size(); i != e; ++i) {
+      const Node *child = elements[i];
+      visit(child);
+    }
+  }
void visitImpl(const TupleExpr *expr) {
for (const Node *child : expr->getElements())
visit(child);
@@ -325,6 +329,21 @@ Optional<StringRef> OperationExpr::getName() const {
return getNameDecl()->getName();
}
+//===----------------------------------------------------------------------===//
+// RangeExpr
+//===----------------------------------------------------------------------===//
+
+/// Allocate a RangeExpr from the context's allocator and copy the element
+/// expression pointers into the node's trailing storage.
+RangeExpr *RangeExpr::create(Context &ctx, SMRange loc,
+                             ArrayRef<Expr *> elements, RangeType type) {
+  size_t allocSize = RangeExpr::totalSizeToAlloc<Expr *>(elements.size());
+  // The storage holds a RangeExpr (plus trailing Expr pointers), so align it
+  // for RangeExpr itself rather than for an unrelated node type.
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(RangeExpr));
+
+  RangeExpr *expr = new (rawData) RangeExpr(loc, type, elements.size());
+  std::uninitialized_copy(elements.begin(), elements.end(),
+                          expr->getElements().begin());
+  return expr;
+}
+
//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 86a24e4e9b5fa..33ede71b987b9 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -97,6 +97,7 @@ class CodeGen {
SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
Value genExprImpl(const ast::MemberAccessExpr *expr);
Value genExprImpl(const ast::OperationExpr *expr);
+ Value genExprImpl(const ast::RangeExpr *expr);
SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
Value genExprImpl(const ast::TypeExpr *expr);
@@ -377,7 +378,8 @@ void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
Value CodeGen::genSingleExpr(const ast::Expr *expr) {
return TypeSwitch<const ast::Expr *, Value>(expr)
.Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
- const ast::OperationExpr, const ast::TypeExpr>(
+ const ast::OperationExpr, const ast::RangeExpr,
+ const ast::TypeExpr>(
[&](auto derivedNode) { return this->genExprImpl(derivedNode); })
.Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
[&](auto derivedNode) {
@@ -517,6 +519,15 @@ Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
attrValues, results);
}
+/// Lower a RangeExpr to a `pdl.range` operation. Every value generated for each
+/// element expression is appended, in order, to the range's operand list.
+Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {
+  SmallVector<Value> rangeOperands;
+  for (const ast::Expr *element : expr->getElements()) {
+    auto elementValues = genExpr(element);
+    llvm::append_range(rangeOperands, elementValues);
+  }
+
+  auto loc = genLoc(expr->getLoc());
+  auto resultType = genType(expr->getType());
+  return builder.create<pdl::RangeOp>(loc, resultType, rangeOperands);
+}
+
SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
SmallVector<Value> elements;
for (const ast::Expr *element : expr->getElements())
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 78bd9a8e1f0ef..ffa7f0cf52ff5 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -47,10 +47,9 @@ class Parser {
bool enableDocumentation, CodeCompleteContext *codeCompleteContext)
: ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext),
curToken(lexer.lexToken()), enableDocumentation(enableDocumentation),
- valueTy(ast::ValueType::get(ctx)),
- valueRangeTy(ast::ValueRangeType::get(ctx)),
- typeTy(ast::TypeType::get(ctx)),
+ typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)),
typeRangeTy(ast::TypeRangeType::get(ctx)),
+ valueRangeTy(ast::ValueRangeType::get(ctx)),
attrTy(ast::AttributeType::get(ctx)),
codeCompleteContext(codeCompleteContext) {}
@@ -116,6 +115,14 @@ class Parser {
LogicalResult convertExpressionTo(
ast::Expr *&expr, ast::Type type,
function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
+ LogicalResult
+ convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType,
+ ast::Type type,
+ function_ref<ast::InFlightDiagnostic()> emitErrorFn);
+ LogicalResult convertTupleExpressionTo(
+ ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
+ function_ref<ast::InFlightDiagnostic()> emitErrorFn,
+ function_ref<void(ast::Diagnostic &diag)> noteAttachFn);
/// Given an operation expression, convert it to a Value or ValueRange
/// typed expression.
@@ -555,8 +562,8 @@ class Parser {
ParserContext parserContext = ParserContext::Global;
/// Cached types to simplify verification and expression creation.
- ast::Type valueTy, valueRangeTy;
- ast::Type typeTy, typeRangeTy;
+ ast::Type typeTy, valueTy;
+ ast::RangeType typeRangeTy, valueRangeTy;
ast::Type attrTy;
/// A counter used when naming anonymous constraints and rewrites.
@@ -619,55 +626,8 @@ LogicalResult Parser::convertExpressionTo(
return diag;
};
- if (auto exprOpType = exprType.dyn_cast<ast::OperationType>()) {
- // Two operation types are compatible if they have the same name, or if the
- // expected type is more general.
- if (auto opType = type.dyn_cast<ast::OperationType>()) {
- if (opType.getName())
- return emitConvertError();
- return success();
- }
-
- // An operation can always convert to a ValueRange.
- if (type == valueRangeTy) {
- expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
- valueRangeTy);
- return success();
- }
-
- // Allow conversion to a single value by constraining the result range.
- if (type == valueTy) {
- // If the operation is registered, we can verify if it can ever have a
- // single result.
- if (const ods::Operation *odsOp = exprOpType.getODSOperation()) {
- if (odsOp->getResults().empty()) {
- return emitConvertError()->attachNote(
- llvm::formatv("see the definition of `{0}`, which was defined "
- "with zero results",
- odsOp->getName()),
- odsOp->getLoc());
- }
-
- unsigned numSingleResults = llvm::count_if(
- odsOp->getResults(), [](const ods::OperandOrResult &result) {
- return result.getVariableLengthKind() ==
- ods::VariableLengthKind::Single;
- });
- if (numSingleResults > 1) {
- return emitConvertError()->attachNote(
- llvm::formatv("see the definition of `{0}`, which was defined "
- "with at least {1} results",
- odsOp->getName(), numSingleResults),
- odsOp->getLoc());
- }
- }
-
- expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
- valueTy);
- return success();
- }
- return emitConvertError();
- }
+ if (auto exprOpType = exprType.dyn_cast<ast::OperationType>())
+ return convertOpExpressionTo(expr, exprOpType, type, emitConvertError);
// FIXME: Decide how to allow/support converting a single result to multiple,
// and multiple to a single result. For now, we just allow Single->Range,
@@ -681,22 +641,85 @@ LogicalResult Parser::convertExpressionTo(
return success();
// Handle tuple types.
- if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
- auto tupleType = type.dyn_cast<ast::TupleType>();
- if (!tupleType || tupleType.size() != exprTupleType.size())
- return emitConvertError();
+ if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>())
+ return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError,
+ noteAttachFn);
+
+ return emitConvertError();
+}
+
+/// Try to convert `expr`, an expression of operation type `exprType`, to the
+/// expected type `type`, reporting failures through `emitErrorFn`. On success
+/// `expr` may be replaced by an expression accessing the operation's results.
+LogicalResult Parser::convertOpExpressionTo(
+    ast::Expr *&expr, ast::OperationType exprType, ast::Type type,
+    function_ref<ast::InFlightDiagnostic()> emitErrorFn) {
+  // Operation -> Operation: only a generic (unnamed) expected type is accepted
+  // here; a named expected type is rejected. (Identical types are presumably
+  // accepted by the caller before reaching this point.)
+  if (auto opType = type.dyn_cast<ast::OperationType>()) {
+    if (!opType.getName())
+      return success();
+    return emitErrorFn();
+  }
+
+  // Operation -> ValueRange: refer to all results of the operation.
+  if (type == valueRangeTy) {
+    expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
+                                                   valueRangeTy);
+    return success();
+  }
+
+  // The only remaining supported target is a single Value.
+  bool convertsToSingleValue = type == valueTy;
+  if (!convertsToSingleValue)
+    return emitErrorFn();
+
+  // For operations with an ODS definition, reject the conversion when the
+  // definition has no results, or more than one result whose variable-length
+  // kind is `Single`.
+  if (const ods::Operation *odsOp = exprType.getODSOperation()) {
+    const auto &odsResults = odsOp->getResults();
+    if (odsResults.empty()) {
+      return emitErrorFn()->attachNote(
+          llvm::formatv("see the definition of `{0}`, which was defined "
+                        "with zero results",
+                        odsOp->getName()),
+          odsOp->getLoc());
+    }
+
+    unsigned numSingleResults = 0;
+    for (const ods::OperandOrResult &result : odsResults) {
+      if (result.getVariableLengthKind() == ods::VariableLengthKind::Single)
+        ++numSingleResults;
+    }
+    if (numSingleResults > 1) {
+      return emitErrorFn()->attachNote(
+          llvm::formatv("see the definition of `{0}`, which was defined "
+                        "with at least {1} results",
+                        odsOp->getName(), numSingleResults),
+          odsOp->getLoc());
+    }
+  }
+
+  // Constrain the result range of the operation to a single value.
+  expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
+                                                 valueTy);
+  return success();
+}
+
+LogicalResult Parser::convertTupleExpressionTo(
+ ast::Expr *&expr, ast::TupleType exprType, ast::Type type,
+ function_ref<ast::InFlightDiagnostic()> emitErrorFn,
+ function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
+ // Handle conversions between tuples.
+ if (auto tupleType = type.dyn_cast<ast::TupleType>()) {
+ if (tupleType.size() != exprType.size())
+ return emitErrorFn();
// Build a new tuple expression using each of the elements of the current
// tuple.
SmallVector<ast::Expr *> newExprs;
- for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
+ for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
newExprs.push_back(ast::MemberAccessExpr::create(
ctx, expr->getLoc(), expr, llvm::to_string(i),
- exprTupleType.getElementTypes()[i]));
+ exprType.getElementTypes()[i]));
auto diagFn = [&](ast::Diagnostic &diag) {
diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
- i, exprTupleType));
+ i, exprType));
if (noteAttachFn)
noteAttachFn(diag);
};
@@ -709,7 +732,37 @@ LogicalResult Parser::convertExpressionTo(
return success();
}
- return emitConvertError();
+  // Handle conversion to a range. The tuple's elements must all be of the
+  // range's element type or of the range type itself.
+  auto convertToRange = [&](ArrayRef<ast::Type> allowedElementTypes,
+                            ast::RangeType resultTy) -> LogicalResult {
+    // TODO: We currently only allow range conversion within a rewrite context.
+    if (parserContext != ParserContext::Rewrite) {
+      return emitErrorFn()->attachNote("Tuple to Range conversion is currently "
+                                       "only allowed within a rewrite context");
+    }
+
+    // All of the tuple elements must be allowed types.
+    auto elementTypes = exprType.getElementTypes();
+    for (ast::Type elementType : elementTypes)
+      if (!llvm::is_contained(allowedElementTypes, elementType))
+        return emitErrorFn();
+
+    // Build a new range expression whose elements access each member of the
+    // current tuple.
+    SmallVector<ast::Expr *> newExprs;
+    newExprs.reserve(exprType.size());
+    for (unsigned i = 0, e = exprType.size(); i < e; ++i) {
+      newExprs.push_back(ast::MemberAccessExpr::create(
+          ctx, expr->getLoc(), expr, llvm::to_string(i), elementTypes[i]));
+    }
+    expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy);
+    return success();
+  };
+  if (type == valueRangeTy)
+    return convertToRange({valueTy, valueRangeTy}, valueRangeTy);
+  if (type == typeRangeTy)
+    return convertToRange({typeTy, typeRangeTy}, typeRangeTy);
+
+  return emitErrorFn();
}
//===----------------------------------------------------------------------===//
@@ -2955,6 +3008,10 @@ LogicalResult Parser::validateOperationOperandsOrResults(
}
}
+ // Otherwise, try to convert the expression to a range.
+ if (succeeded(convertExpressionTo(valueExpr, rangeTy)))
+ continue;
+
return emitError(
valueExpr->getLoc(),
llvm::formatv(
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
index 671a3b4818c63..950b90d75d6a4 100644
--- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
@@ -114,6 +114,26 @@ Pattern TupleMemberAccessName {
// -----
+//===----------------------------------------------------------------------===//
+// RangeExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @RangeExpr
+// CHECK: %[[ARG:.*]] = operand
+// CHECK: %[[ARGS:.*]] = operands
+// CHECK: %[[TYPE:.*]] = type
+// CHECK: %[[TYPES:.*]] = types
+// CHECK: range : !pdl.range<value>
+// CHECK: range %[[ARG]], %[[ARGS]] : !pdl.value, !pdl.range<value>
+// CHECK: range : !pdl.range<type>
+// CHECK: range %[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>
+// Tuples used as the operand list and result-type list of a rewrite-side
+// operation are converted to ranges; as verified above, the empty tuple `()`
+// yields a `range` op with no elements.
+Pattern RangeExpr {
+ replace op<>(arg: Value, args: ValueRange) -> (type: Type, types: TypeRange)
+ with op<test.op>((), (arg, args)) -> ((), (type, types));
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index de8c6163895e3..31258cb99ebc4 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -124,6 +124,25 @@ Pattern {
// -----
+//===----------------------------------------------------------------------===//
+// Range Expr
+//===----------------------------------------------------------------------===//
+
+Pattern {
+ // The matched operation expression is outside a rewrite context, so the
+ // empty tuple cannot be converted to the expected ValueRange operand list.
+ // CHECK: unable to convert expression of type `Tuple<>` to the expected type of `ValueRange`
+ // CHECK: Tuple to Range conversion is currently only allowed within a rewrite context
+ erase op<>(());
+}
+
+// -----
+
+Pattern {
+ // `type` is a Type, which is not an allowed element of a ValueRange, so the
+ // tuple is rejected even though it appears in the replacement operation.
+ // CHECK: unable to convert expression of type `Tuple<Value, Type>` to the expected type of `ValueRange`
+ replace op<>(arg: Value) -> (type: Type) with op<test.op>((arg, type));
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Tuple Expr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index 5320efd2e690a..6e68883f4edee 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -235,6 +235,29 @@ Pattern {
// -----
+//===----------------------------------------------------------------------===//
+// RangeExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `Operands`
+// CHECK: -RangeExpr {{.*}} Type<ValueRange>
+// CHECK: -RangeExpr {{.*}} Type<ValueRange>
+// CHECK: -MemberAccessExpr {{.*}} Member<0> Type<Value>
+// CHECK: -MemberAccessExpr {{.*}} Member<1> Type<ValueRange>
+// CHECK: `Result Types`
+// CHECK: -RangeExpr {{.*}} Type<TypeRange>
+// CHECK: -RangeExpr {{.*}} Type<TypeRange>
+// CHECK: -MemberAccessExpr {{.*}} Member<0> Type<Type>
+// CHECK: -MemberAccessExpr {{.*}} Member<1> Type<TypeRange>
+// Within a `rewrite` block, tuples used as operand and result-type lists are
+// converted to RangeExprs whose elements are member accesses of the tuple.
+Pattern {
+ rewrite op<>(arg: Value, args: ValueRange) -> (type: Type, types: TypeRange) with {
+ op<test.op>((), (arg, args)) -> ((), (type, types));
+ };
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list