[Mlir-commits] [mlir] de6c82d - [MLIR][PDL] Generalize result type verification

Uday Bondhugula llvmlistbot at llvm.org
Mon Jan 3 18:42:29 PST 2022


Author: Stanislav Funiak
Date: 2022-01-04T08:11:46+05:30
New Revision: de6c82d6fdb9a80b50a415bcc0fa9518fa964d40

URL: https://github.com/llvm/llvm-project/commit/de6c82d6fdb9a80b50a415bcc0fa9518fa964d40
DIFF: https://github.com/llvm/llvm-project/commit/de6c82d6fdb9a80b50a415bcc0fa9518fa964d40.diff

LOG: [MLIR][PDL] Generalize result type verification

Presently the result type verification checks whether the type is used by a `pdl::OperationOp` inside the matcher. This is unnecessarily restrictive; the type could come from a `pdl::OperandOp` or `pdl::OperandsOp` and still be inferrable.

Reviewed By: rriddle, Mogball

Differential Revision: https://reviews.llvm.org/D116083

Added: 
    

Modified: 
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/test/Dialect/PDL/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 2a399ec2169ee..95a3fb742fa11 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -207,16 +207,17 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
     if (isa<ApplyNativeRewriteOp>(resultTypeOp))
       continue;
 
-    // If the type operation was defined in the matcher and constrains the
-    // result of an input operation, it can be used.
-    auto constrainsInputOp = [rewriterBlock](Operation *user) {
-      return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
+    // If the type operation was defined in the matcher and constrains an
+    // operand or the result of an input operation, it can be used.
+    auto constrainsInput = [rewriterBlock](Operation *user) {
+      return user->getBlock() != rewriterBlock &&
+             isa<OperandOp, OperandsOp, OperationOp>(user);
     };
     if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
-      if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
+      if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInput))
         continue;
     } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
-      if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
+      if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInput))
         continue;
     }
 

diff  --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 758d5c6ac0314..9c7daf46a0907 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -88,7 +88,7 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) {
 // -----
 
 // Check that the result type of an operation within a rewrite can be inferred
-// from types used within the match block.
+// from the result types of an operation within the match block.
 pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
@@ -101,7 +101,7 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
 // -----
 
 // Check that the result type of an operation within a rewrite can be inferred
-// from types used within the match block.
+// from the result types of an operation within the match block.
 pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %types = pdl.types
   %root = pdl.operation -> (%types : !pdl.range<type>)
@@ -113,6 +113,34 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
 
 // -----
 
+// Check that the result type of an operation within a rewrite can be inferred
+// from the type of an operand within the match block.
+pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
+  %type1 = pdl.type
+  %type2 = pdl.type
+  %operand1 = pdl.operand : %type1
+  %operand2 = pdl.operand : %type2
+  %root = pdl.operation (%operand1, %operand2 : !pdl.value, !pdl.value)
+  pdl.rewrite %root {
+    %newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type)
+  }
+}
+
+// -----
+
+// Check that the result type of an operation within a rewrite can be inferred
+// from the types of operands within the match block.
+pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
+  %types = pdl.types
+  %operands = pdl.operands : %types
+  %root = pdl.operation (%operands : !pdl.range<value>)
+  pdl.rewrite %root {
+    %newOp = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
+  }
+}
+
+// -----
+
 pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
   %root = pdl.operation
   pdl.rewrite %root {


        


More information about the Mlir-commits mailing list