[Mlir-commits] [mlir] 55f878b - [mlir][pdl] Add a new !pdl.range<> type

River Riddle llvmlistbot at llvm.org
Wed Mar 3 15:55:31 PST 2021


Author: River Riddle
Date: 2021-03-03T15:48:00-08:00
New Revision: 55f878bad96421489dbe7ec8cec239acc02a899b

URL: https://github.com/llvm/llvm-project/commit/55f878bad96421489dbe7ec8cec239acc02a899b
DIFF: https://github.com/llvm/llvm-project/commit/55f878bad96421489dbe7ec8cec239acc02a899b.diff

LOG: [mlir][pdl] Add a new !pdl.range<> type

This type represents a range of positional values. It will be used in followup revisions to add support for variadic constructs to PDL, such as operand and result ranges.

Differential Revision: https://reviews.llvm.org/D95717

Added: 
    mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
    mlir/test/Dialect/PDL/invalid-types.mlir

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
    mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/lib/Dialect/PDL/IR/CMakeLists.txt
    mlir/lib/Dialect/PDL/IR/PDL.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index e4a9ebe8900d..27fdd75791e4 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -52,7 +52,7 @@ def PDL_ApplyConstraintOp
   }];
 
   let arguments = (ins StrAttr:$name,
-                       Variadic<PDL_PositionalValue>:$args,
+                       Variadic<PDL_AnyType>:$args,
                        OptionalAttr<ArrayAttr>:$constParams);
   let assemblyFormat = [{
     $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict
@@ -136,9 +136,9 @@ def PDL_CreateNativeOp
   }];
 
   let arguments = (ins StrAttr:$name,
-                       Variadic<PDL_PositionalValue>:$args,
+                       Variadic<PDL_AnyType>:$args,
                        OptionalAttr<ArrayAttr>:$constParams);
-  let results = (outs PDL_PositionalValue:$result);
+  let results = (outs PDL_AnyType:$result);
   let assemblyFormat = [{
     $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
     attr-dict
@@ -403,7 +403,7 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
 
   let arguments = (ins PDL_Operation:$root,
                        OptionalAttr<StrAttr>:$name,
-                       Variadic<PDL_PositionalValue>:$externalArgs,
+                       Variadic<PDL_AnyType>:$externalArgs,
                        OptionalAttr<ArrayAttr>:$externalConstParams);
   let regions = (region AnyRegion:$body);
   let assemblyFormat = [{

diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
index a9028f3d5972..8cbe31fd2a6f 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h
@@ -19,6 +19,18 @@
 // PDL Dialect Types
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace pdl {
+/// This class represents the base class of all PDL types.
+class PDLType : public Type {
+public:
+  using Type::Type;
+
+  static bool classof(Type type);
+};
+} // namespace pdl
+} // namespace mlir
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/PDL/IR/PDLOpsTypes.h.inc"
 

diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
index 1cf0b1e25af5..c854616fbc8f 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
@@ -19,7 +19,8 @@ include "mlir/Dialect/PDL/IR/PDLDialect.td"
 // PDL Types
 //===----------------------------------------------------------------------===//
 
-class PDL_Type<string name, string typeMnemonic> : TypeDef<PDL_Dialect, name> {
+class PDL_Type<string name, string typeMnemonic>
+    : TypeDef<PDL_Dialect, name, "::mlir::pdl::PDLType"> {
   let mnemonic = typeMnemonic;
 }
 
@@ -47,6 +48,27 @@ def PDL_Operation : PDL_Type<"Operation", "operation"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::RangeType
+//===----------------------------------------------------------------------===//
+
+def PDL_Range : PDL_Type<"Range", "range"> {
+  let summary = "PDL handle to a range of a given sub-type";
+  let description = [{
+    This type represents a range of instances of the given PDL element type,
+    i.e. `Attribute`, `Operation`, `Type`, or `Value`.
+  }];
+  let parameters = (ins "Type":$elementType);
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
+      return $_get(elementType.getContext(), elementType);
+    }]>,
+  ];
+  let genVerifyDecl = 1;
+  let skipDefaultBuilders = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::TypeType
 //===----------------------------------------------------------------------===//
@@ -75,10 +97,8 @@ def PDL_Value : PDL_Type<"Value", "value"> {
 // Additional Type Constraints
 //===----------------------------------------------------------------------===//
 
-// A positional value is a location on a pattern DAG, which may be an attribute,
-// operation, or operand/result.
-def PDL_PositionalValue :
-    AnyTypeOf<[PDL_Attribute, PDL_Operation, PDL_Type, PDL_Value],
-              "Positional Value">;
+def PDL_AnyType : Type<
+  CPred<"$_self.isa<::mlir::pdl::PDLType>()">, "pdl type",
+        "::mlir::pdl::PDLType">;
 
 #endif // MLIR_DIALECT_PDL_IR_PDLTYPES

diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 5720a624ef95..46513c1906ff 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -113,7 +113,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   }];
 
   let arguments = (ins StrAttr:$name,
-                       Variadic<PDL_PositionalValue>:$args,
+                       Variadic<PDL_AnyType>:$args,
                        OptionalAttr<ArrayAttr>:$constParams);
   let assemblyFormat = [{
     $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->`
@@ -151,7 +151,7 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
   }];
   let arguments = (ins StrAttr:$name,
                        PDL_Operation:$root,
-                       Variadic<PDL_PositionalValue>:$args,
+                       Variadic<PDL_AnyType>:$args,
                        OptionalAttr<ArrayAttr>:$constParams);
   let assemblyFormat = [{
     $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root
@@ -178,8 +178,7 @@ def PDLInterp_AreEqualOp
     ```
   }];
 
-  let arguments = (ins PDL_PositionalValue:$lhs,
-                       PDL_PositionalValue:$rhs);
+  let arguments = (ins PDL_AnyType:$lhs, PDL_AnyType:$rhs);
   let assemblyFormat = "operands `:` type($lhs) attr-dict `->` successors";
 }
 
@@ -374,9 +373,9 @@ def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
   }];
 
   let arguments = (ins StrAttr:$name,
-                       Variadic<PDL_PositionalValue>:$args,
+                       Variadic<PDL_AnyType>:$args,
                        OptionalAttr<ArrayAttr>:$constParams);
-  let results = (outs PDL_PositionalValue:$result);
+  let results = (outs PDL_AnyType:$result);
   let assemblyFormat = [{
     $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
     attr-dict
@@ -691,7 +690,7 @@ def PDLInterp_IsNotNullOp
     ```
   }];
 
-  let arguments = (ins PDL_PositionalValue:$value);
+  let arguments = (ins PDL_AnyType:$value);
   let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
 }
 
@@ -716,7 +715,7 @@ def PDLInterp_RecordMatchOp
     ```
   }];
 
-  let arguments = (ins Variadic<PDL_PositionalValue>:$inputs,
+  let arguments = (ins Variadic<PDL_AnyType>:$inputs,
                        Variadic<PDL_Operation>:$matchedOps,
                        SymbolRefAttr:$rewriter,
                        OptionalAttr<StrAttr>:$rootKind,

diff  --git a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt
index 98ec697df925..6ffa06ba766c 100644
--- a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRPDL
   PDL.cpp
+  PDLTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/PDL

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 256e6aedc6e2..6ee8a1bf4491 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -10,10 +10,8 @@
 #include "mlir/Dialect/PDL/IR/PDLOps.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/DialectImplementation.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/StringSwitch.h"
-#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::pdl;
@@ -430,27 +428,3 @@ static LogicalResult verify(TypeOp op) {
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
-
-//===----------------------------------------------------------------------===//
-// TableGen'd type method definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
-
-Type PDLDialect::parseType(DialectAsmParser &parser) const {
-  StringRef keyword;
-  if (parser.parseKeyword(&keyword))
-    return Type();
-  if (Type type = generatedTypeParser(getContext(), parser, keyword))
-    return type;
-
-  parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
-      << keyword << "'";
-  return Type();
-}
-
-void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
-  if (failed(generatedTypePrinter(type, printer)))
-    llvm_unreachable("unknown 'pdl' type");
-}

diff  --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
new file mode 100644
index 000000000000..37cdf772338e
--- /dev/null
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -0,0 +1,100 @@
+//===- PDLTypes.cpp - Pattern Descriptor Language Types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::pdl;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd type method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// PDLDialect
+//===----------------------------------------------------------------------===//
+
+static Type parsePDLType(DialectAsmParser &parser) {
+  StringRef keyword;
+  if (parser.parseKeyword(&keyword))
+    return Type();
+  if (Type type = generatedTypeParser(parser.getBuilder().getContext(), parser,
+                                      keyword))
+    return type;
+
+  // FIXME: This ends up with a double error being emitted if `RangeType` also
+  // emits an error. We should rework the `generatedTypeParser` to better
+  // support when the keyword is valid but the individual type parser itself
+  // emits an error.
+  parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
+      << keyword << "'";
+  return Type();
+}
+
+Type PDLDialect::parseType(DialectAsmParser &parser) const {
+  return parsePDLType(parser);
+}
+
+void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  if (failed(generatedTypePrinter(type, printer)))
+    llvm_unreachable("unknown 'pdl' type");
+}
+
+//===----------------------------------------------------------------------===//
+// PDL Types
+//===----------------------------------------------------------------------===//
+
+bool PDLType::classof(Type type) {
+  return llvm::isa<PDLDialect>(type.getDialect());
+}
+
+//===----------------------------------------------------------------------===//
+// RangeType
+//===----------------------------------------------------------------------===//
+
+Type RangeType::parse(MLIRContext *context, DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  llvm::SMLoc elementLoc = parser.getCurrentLocation();
+  Type elementType = parsePDLType(parser);
+  if (!elementType || parser.parseGreater())
+    return Type();
+
+  if (elementType.isa<RangeType>()) {
+    parser.emitError(elementLoc)
+        << "element of pdl.range cannot be another range, but got"
+        << elementType;
+    return Type();
+  }
+  return RangeType::get(elementType);
+}
+
+void RangeType::print(DialectAsmPrinter &printer) const {
+  printer << "range<";
+  (void)generatedTypePrinter(getElementType(), printer);
+  printer << ">";
+}
+
+LogicalResult RangeType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                Type elementType) {
+  if (!elementType.isa<PDLType>() || elementType.isa<RangeType>()) {
+    return emitError()
+           << "expected element of pdl.range to be one of [!pdl.attribute, "
+              "!pdl.operation, !pdl.type, !pdl.value], but got "
+           << elementType;
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/PDL/invalid-types.mlir b/mlir/test/Dialect/PDL/invalid-types.mlir
new file mode 100644
index 000000000000..8d677db27adf
--- /dev/null
+++ b/mlir/test/Dialect/PDL/invalid-types.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// pdl::RangeType
+//===----------------------------------------------------------------------===//
+
+// expected-error@+2 {{element of pdl.range cannot be another range, but got'!pdl.range<value>'}}
+// expected-error@+1 {{invalid 'pdl' type}}
+#invalid_element = !pdl.range<range<value>>


        


More information about the Mlir-commits mailing list