[Mlir-commits] [mlir] 233e947 - [mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that create constants

River Riddle llvmlistbot at llvm.org
Fri Dec 10 11:50:37 PST 2021


Author: River Riddle
Date: 2021-12-10T19:38:43Z
New Revision: 233e9476d8bec85b7711041d7c7536ac6f87e9ef

URL: https://github.com/llvm/llvm-project/commit/233e9476d8bec85b7711041d7c7536ac6f87e9ef
DIFF: https://github.com/llvm/llvm-project/commit/233e9476d8bec85b7711041d7c7536ac6f87e9ef.diff

LOG: [mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that create constants

This allows for passing in these attributes/types to constraints/rewrites as arguments.

Differential Revision: https://reviews.llvm.org/D114817

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 6918b9681ddc7..931d12a6687a3 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -512,6 +512,14 @@ def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> {
   let arguments = (ins TypeArrayAttr:$value);
   let results = (outs PDL_RangeOf<PDL_Type>:$result);
   let assemblyFormat = "$value attr-dict";
+
+  let builders = [
+    // Build from a type array; the result type is `!pdl.range<type>`.
+    OpBuilder<(ins "ArrayAttr":$type), [{
+      build($_builder, $_state,
+            pdl::RangeType::get($_builder.getType<pdl::TypeType>()), type);
+    }]>
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 87dcd2723686f..7db7dc03dc80d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -237,10 +237,13 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
     return val;
 
   // Get the value for the parent position.
-  Value parentVal = getValueAt(currentBlock, pos->getParent());
+  // Null for positions without a parent, such as attribute/type literals.
+  Value parentVal;
+  if (Position *parent = pos->getParent())
+    parentVal = getValueAt(currentBlock, parent);
 
   // TODO: Use a location from the position.
-  Location loc = parentVal.getLoc();
+  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
   builder.setInsertionPointToEnd(currentBlock);
   Value value;
   switch (pos->getKind()) {
@@ -331,6 +334,22 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
         parentVal, resPos->getResultGroupNumber());
     break;
   }
+  case Predicates::AttributeLiteralPos: {
+    auto *attrPos = cast<AttributeLiteralPosition>(pos);
+    value =
+        builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
+    break;
+  }
+  case Predicates::TypeLiteralPos: {
+    auto *typePos = cast<TypeLiteralPosition>(pos);
+    Attribute rawTypeAttr = typePos->getValue();
+    if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
+      value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
+    else
+      value = builder.create<pdl_interp::CreateTypesOp>(
+          loc, rawTypeAttr.cast<ArrayAttr>());
+    break;
+  }
   default:
     llvm_unreachable("Generating unknown Position getter");
     break;
@@ -353,7 +372,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
   if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
     args = {getValueAt(currentBlock, equalToQuestion->getValue())};
   } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
-    for (Position *position : std::get<1>(cstQuestion->getValue()))
+    for (Position *position : cstQuestion->getArgs())
       args.push_back(getValueAt(currentBlock, position));
   }
 
@@ -413,10 +432,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
     break;
   }
   case Predicates::ConstraintQuestion: {
-    auto value = cast<ConstraintQuestion>(question)->getValue();
+    auto *cstQuestion = cast<ConstraintQuestion>(question);
     builder.create<pdl_interp::ApplyConstraintOp>(
-        loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(),
-        success, failure);
+        loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
+        failure);
     break;
   }
   default:

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
index 8d560ebbcde80..8d6d8776bde32 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
@@ -21,7 +21,8 @@ Position::~Position() {}
 unsigned Position::getOperationDepth() const {
   if (const auto *operationPos = dyn_cast<OperationPosition>(this))
     return operationPos->getDepth();
-  return parent->getOperationDepth();
+  // Positions without a parent (e.g. literal positions) have a depth of 0.
+  return parent ? parent->getOperationDepth() : 0;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 4a7dcdc696f48..266580bd41f59 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -50,6 +50,8 @@ enum Kind : unsigned {
   ResultPos,
   ResultGroupPos,
   TypePos,
+  AttributeLiteralPos,
+  TypeLiteralPos,
 
   // Questions, ordered by dependency and decreasing priority.
   IsNotNullQuestion,
@@ -173,6 +175,16 @@ struct AttributePosition
   StringAttr getName() const { return key.second; }
 };
 
+//===----------------------------------------------------------------------===//
+// AttributeLiteralPosition
+
+/// A position describing a literal attribute.
+struct AttributeLiteralPosition
+    : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
+                           Predicates::AttributeLiteralPos> {
+  using PredicateBase::PredicateBase;
+};
+
 //===----------------------------------------------------------------------===//
 // OperandPosition
 
@@ -317,6 +329,17 @@ struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
   }
 };
 
+//===----------------------------------------------------------------------===//
+// TypeLiteralPosition
+
+/// A position describing a literal type or type range. The value is stored as
+/// either a TypeAttr, or an ArrayAttr of TypeAttr.
+struct TypeLiteralPosition
+    : public PredicateBase<TypeLiteralPosition, Position, Attribute,
+                           Predicates::TypeLiteralPos> {
+  using PredicateBase::PredicateBase;
+};
+
 //===----------------------------------------------------------------------===//
 // Qualifiers
 //===----------------------------------------------------------------------===//
@@ -404,6 +427,17 @@ struct ConstraintQuestion
           Predicates::ConstraintQuestion> {
   using Base::Base;
 
+  /// Return the name of the constraint.
+  StringRef getName() const { return std::get<0>(key); }
+
+  /// Return the arguments of the constraint.
+  ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
+
+  /// Return the constant parameters of the constraint; this may be null.
+  ArrayAttr getParams() const {
+    return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
+  }
+
   /// Construct an instance with the given storage allocator.
   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
                                        KeyTy key) {
@@ -461,12 +495,14 @@ class PredicateUniquer : public StorageUniquer {
   PredicateUniquer() {
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
+    registerParametricStorageType<AttributeLiteralPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
     registerParametricStorageType<OperationPosition>();
     registerParametricStorageType<ResultPosition>();
     registerParametricStorageType<ResultGroupPosition>();
     registerParametricStorageType<TypePosition>();
+    registerParametricStorageType<TypeLiteralPosition>();
 
     // Register the types of Questions with the uniquer.
     registerParametricStorageType<AttributeAnswer>();
@@ -527,6 +563,11 @@ class PredicateBuilder {
     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
   }
 
+  /// Returns an attribute literal position for the given attribute.
+  Position *getAttributeLiteral(Attribute attr) {
+    return AttributeLiteralPosition::get(uniquer, attr);
+  }
+
   /// Returns an operand position for an operand of the given operation.
   Position *getOperand(OperationPosition *p, unsigned operand) {
     return OperandPosition::get(uniquer, p, operand);
@@ -558,6 +599,12 @@ class PredicateBuilder {
   /// Returns a type position for the given entity.
   Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
 
+  /// Returns a type literal position for the given type attribute. The value
+  /// is either a TypeAttr, or an ArrayAttr of TypeAttr for a type range.
+  Position *getTypeLiteral(Attribute attr) {
+    return TypeLiteralPosition::get(uniquer, attr);
+  }
+
   //===--------------------------------------------------------------------===//
   // Qualifiers
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 7e2219ca1416a..b22e50549dfd8 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -243,8 +243,23 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
       .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
 }
 
+/// Assign a position to a `pdl.attribute` that was not bound while walking
+/// the operation tree. Such an attribute must hold a constant value, which is
+/// materialized as an attribute literal position.
+static void getAttributePredicates(pdl::AttributeOp op,
+                                   std::vector<PositionalPredicate> &predList,
+                                   PredicateBuilder &builder,
+                                   DenseMap<Value, Position *> &inputs) {
+  Position *&attrPos = inputs[op];
+  if (attrPos)
+    return;
+  Attribute value = op.valueAttr();
+  assert(value && "expected non-tree `pdl.attribute` to contain a value");
+  attrPos = builder.getAttributeLiteral(value);
+}
+
 /// Collect all of the predicates related to constraints within the given
 /// pattern operation.
 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
                                     std::vector<PositionalPredicate> &predList,
                                     PredicateBuilder &builder,
@@ -296,6 +311,22 @@ static void getResultPredicates(pdl::ResultsOp op,
     predList.emplace_back(resultPos, builder.getIsNotNull());
 }
 
+/// Assign a position to a `pdl.type`/`pdl.types` value that was not bound
+/// while walking the operation tree. Such a value must hold a constant,
+/// returned by `typeAttrFn`, that is materialized as a type literal position.
+static void getTypePredicates(Value typeValue,
+                              function_ref<Attribute()> typeAttrFn,
+                              PredicateBuilder &builder,
+                              DenseMap<Value, Position *> &inputs) {
+  Position *&typePos = inputs[typeValue];
+  if (typePos)
+    return;
+  Attribute typeAttr = typeAttrFn();
+  assert(typeAttr &&
+         "expected non-tree `pdl.type`/`pdl.types` to contain a value");
+  typePos = builder.getTypeLiteral(typeAttr);
+}
+
 /// Collect all of the predicates that cannot be determined via walking the
 /// tree.
 static void getNonTreePredicates(pdl::PatternOp pattern,
@@ -304,11 +335,22 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
                                  DenseMap<Value, Position *> &inputs) {
   for (Operation &op : pattern.body().getOps()) {
     TypeSwitch<Operation *>(&op)
+        .Case([&](pdl::AttributeOp attrOp) {
+          getAttributePredicates(attrOp, predList, builder, inputs);
+        })
         .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
           getConstraintPredicates(constraintOp, predList, builder, inputs);
         })
         .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
           getResultPredicates(resultOp, predList, builder, inputs);
+        })
+        .Case([&](pdl::TypeOp typeOp) {
+          getTypePredicates(
+              typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
+        })
+        .Case([&](pdl::TypesOp typeOp) {
+          getTypePredicates(
+              typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
         });
   }
 }

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 81a8f60610bce..2363668618e5e 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -114,12 +114,16 @@ static LogicalResult verify(AttributeOp op) {
   Value attrType = op.type();
   Optional<Attribute> attrValue = op.value();
 
-  if (!attrValue && isa<RewriteOp>(op->getParentOp()))
-    return op.emitOpError("expected constant value when specified within a "
-                          "`pdl.rewrite`");
-  if (attrValue && attrType)
+  if (!attrValue) {
+    if (isa<RewriteOp>(op->getParentOp()))
+      return op.emitOpError("expected constant value when specified within a "
+                            "`pdl.rewrite`");
+    return verifyHasBindingUse(op);
+  }
+  if (attrType)
     return op.emitOpError("expected only one of [`type`, `value`] to be set");
-  return verifyHasBindingUse(op);
+  // Attributes with a constant value do not require a binding use.
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -431,13 +435,23 @@ static LogicalResult verify(RewriteOp op) {
 // pdl::TypeOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); }
+static LogicalResult verify(TypeOp op) {
+  // Types with a constant value do not require a binding use.
+  if (!op.typeAttr())
+    return verifyHasBindingUse(op);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // pdl::TypesOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); }
+static LogicalResult verify(TypesOp op) {
+  // Type ranges with a constant value do not require a binding use.
+  if (!op.typesAttr())
+    return verifyHasBindingUse(op);
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 0efcd60945a7b..984a31790a8bc 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -573,3 +573,42 @@ module @variadic_results_at {
     pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
   }
 }
+
+// -----
+
+// CHECK-LABEL: module @attribute_literal
+module @attribute_literal {
+  // CHECK: func @matcher(%{{.*}}: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.create_attribute 10 : i64
+  // CHECK: pdl_interp.apply_constraint "constraint"(%[[ATTR]] : !pdl.attribute)
+
+  // Check the correct lowering of an attribute that hasn't been bound.
+  pdl.pattern : benefit(1) {
+    %attr = pdl.attribute 10
+    pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute)
+
+    %root = pdl.operation
+    pdl.rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @type_literal
+module @type_literal {
+  // CHECK: func @matcher(%{{.*}}: !pdl.operation)
+  // CHECK: %[[TYPE:.*]] = pdl_interp.create_type i32
+  // CHECK: %[[TYPES:.*]] = pdl_interp.create_types [i32, i64]
+  // CHECK: pdl_interp.apply_constraint "constraint"(%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)
+
+  // Check the correct lowering of a type and a type range that haven't been bound.
+  pdl.pattern : benefit(1) {
+    %type = pdl.type : i32
+    %types = pdl.types : [i32, i64]
+    pdl.apply_native_constraint "constraint"(%type, %types: !pdl.type, !pdl.range<type>)
+
+    %root = pdl.operation
+    pdl.rewrite %root with "rewriter"
+  }
+}
+


        


More information about the Mlir-commits mailing list