[Mlir-commits] [mlir] 8de482e - [MLIR] Modify Partial op conversion mode to optionally track all non-legalizable operations.

Lucy Fox llvmlistbot at llvm.org
Thu Apr 30 09:54:50 PDT 2020


Author: Lucy Fox
Date: 2020-04-30T09:52:37-07:00
New Revision: 8de482ea9aa6fda6aa1050ab500157790b864454

URL: https://github.com/llvm/llvm-project/commit/8de482ea9aa6fda6aa1050ab500157790b864454
DIFF: https://github.com/llvm/llvm-project/commit/8de482ea9aa6fda6aa1050ab500157790b864454.diff

LOG: [MLIR] Modify Partial op conversion mode to optionally track all non-legalizable operations.

There are three op conversion modes: Partial, Full, and Analysis. This change modifies the Partial mode to optionally take a set of non-legalizable ops. If this parameter is specified, all ops that are not legalizable (i.e. would cause full conversion to fail) are tracked throughout the partial legalization.

Differential Revision: https://reviews.llvm.org/D78788

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2298b3bb3c73..4f1fafb191ac 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -660,20 +660,25 @@ class ConversionTarget {
 /// ConversionPatternRewriter, to see what additional constraints are imposed on
 /// the use of the PatternRewriter.
 
-/// Apply a partial conversion on the given operations, and all nested
+/// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there are unreachable blocks in any of the regions nested
-/// within 'ops'. If 'converter' is provided, the signatures of blocks and
-/// regions are also converted.
+/// returns failure if any op is explicitly marked as illegal. If `converter` is
+/// provided, the signatures of blocks and regions are also converted.
+/// If an `unconvertedOps` set is provided, all operations that are found not
+/// to be legalizable to the given `target` are placed within that set. (Note
+/// that if there is an op explicitly marked as illegal, the conversion
+/// terminates and the `unconvertedOps` set will not necessarily be complete.)
 LLVM_NODISCARD LogicalResult
 applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
                        const OwningRewritePatternList &patterns,
-                       TypeConverter *converter = nullptr);
+                       TypeConverter *converter = nullptr,
+                       DenseSet<Operation *> *unconvertedOps = nullptr);
 LLVM_NODISCARD LogicalResult
 applyPartialConversion(Operation *op, ConversionTarget &target,
                        const OwningRewritePatternList &patterns,
-                       TypeConverter *converter = nullptr);
+                       TypeConverter *converter = nullptr,
+                       DenseSet<Operation *> *unconvertedOps = nullptr);
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index e57fb0983b8b..14a7084c50bf 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1541,9 +1541,8 @@ struct OperationConverter {
   explicit OperationConverter(ConversionTarget &target,
                               const OwningRewritePatternList &patterns,
                               OpConversionMode mode,
-                              DenseSet<Operation *> *legalizableOps = nullptr)
-      : opLegalizer(target, patterns), mode(mode),
-        legalizableOps(legalizableOps) {}
+                              DenseSet<Operation *> *trackedOps = nullptr)
+      : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops,
@@ -1563,9 +1562,11 @@ struct OperationConverter {
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
 
-  /// A set of pre-existing operations that were found to be legalizable to the
-  /// target. This field is only used when mode == OpConversionMode::Analysis.
-  DenseSet<Operation *> *legalizableOps;
+  /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+  /// this is populated with ops found to be legalizable to the target.
+  /// When mode == OpConversionMode::Partial, this is populated with ops found
+  /// *not* to be legalizable to the target.
+  DenseSet<Operation *> *trackedOps;
 };
 } // end anonymous namespace
 
@@ -1594,17 +1595,22 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       return op->emitError()
              << "failed to legalize operation '" << op->getName() << "'";
     /// Partial conversions allow conversions to fail iff the operation was not
-    /// explicitly marked as illegal.
-    if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
-      return op->emitError()
-             << "failed to legalize operation '" << op->getName()
-             << "' that was explicitly marked illegal";
+    /// explicitly marked as illegal. If the user provided a `trackedOps` set,
+    /// non-legalizable ops are recorded in it.
+    if (mode == OpConversionMode::Partial) {
+      if (opLegalizer.isIllegal(op))
+        return op->emitError()
+               << "failed to legalize operation '" << op->getName()
+               << "' that was explicitly marked illegal";
+      if (trackedOps)
+        trackedOps->insert(op);
+    }
   } else {
     /// Analysis conversions don't fail if any operations fail to legalize,
     /// they are only interested in the operations that were successfully
     /// legalized.
     if (mode == OpConversionMode::Analysis)
-      legalizableOps->insert(op);
+      trackedOps->insert(op);
 
     // If legalization succeeded, convert the types any of the blocks within
     // this operation.
@@ -1932,21 +1938,30 @@ auto ConversionTarget::getOpInfo(OperationName op) const
 // Op Conversion Entry Points
 //===----------------------------------------------------------------------===//
 
-/// Apply a partial conversion on the given operations, and all nested
+/// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
-/// possible, ignoring operations that failed to legalize.
+/// possible, ignoring operations that failed to legalize. This method only
+/// returns failure if any op is explicitly marked as illegal. If `converter` is
+/// provided, the signatures of blocks and regions are also converted.
+/// If an `unconvertedOps` set is provided, all operations that are found not
+/// to be legalizable to the given `target` are placed within that set. (Note
+/// that if there is an op explicitly marked as illegal, the conversion
+/// terminates and the `unconvertedOps` set will not necessarily be complete.)
 LogicalResult mlir::applyPartialConversion(
     ArrayRef<Operation *> ops, ConversionTarget &target,
-    const OwningRewritePatternList &patterns, TypeConverter *converter) {
-  OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
+    const OwningRewritePatternList &patterns, TypeConverter *converter,
+    DenseSet<Operation *> *unconvertedOps) {
+  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
+                                 unconvertedOps);
   return opConverter.convertOperations(ops, converter);
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
                              const OwningRewritePatternList &patterns,
-                             TypeConverter *converter) {
+                             TypeConverter *converter,
+                             DenseSet<Operation *> *unconvertedOps) {
   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
-                                converter);
+                                converter, unconvertedOps);
 }
 
 /// Apply a complete conversion on the given operations, and all nested

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 5c5434446abe..b1f9ffe88b2b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -4,6 +4,7 @@
 func @verifyDirectPattern() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() {status = "Success"}
   %result = "test.illegal_op_a"() : () -> (i32)
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return %result : i32
 }
 
@@ -11,6 +12,7 @@ func @verifyDirectPattern() -> i32 {
 func @verifyLargerBenefit() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() {status = "Success"}
   %result = "test.illegal_op_c"() : () -> (i32)
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return %result : i32
 }
 
@@ -26,7 +28,9 @@ func @remap_input_1_to_1(%arg0: i64) {
 // CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
 func @remap_call_1_to_1(%arg0: i64) {
   // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
+  // expected-remark at +1 {{op 'std.call' is not legalizable}}
   call @remap_input_1_to_1(%arg0) : (i64) -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -40,6 +44,7 @@ func @remap_input_1_to_N(%arg0: f32) -> f32 {
 func @remap_input_1_to_N_remaining_use(%arg0: f32) {
   // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
   // CHECK-NEXT: "work"([[CAST]]) : (f32) -> ()
+  // expected-remark at +1 {{op 'work' is not legalizable}}
   "work"(%arg0) : (f32) -> ()
 }
 
@@ -47,6 +52,7 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
 func @remap_input_to_self(%arg0: index) {
   // CHECK-NOT: test.cast
   // CHECK: "work"
+  // expected-remark at +1 {{op 'work' is not legalizable}}
   "work"(%arg0) : (index) -> ()
 }
 
@@ -59,12 +65,14 @@ func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
 // CHECK-LABEL: func @no_remap_nested
 func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
+  // expected-remark at +1 {{op 'foo.region' is not legalizable}}
   "foo.region"() ({
     // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
     ^bb0(%i0: i64, %unused: i16, %i1: i64):
       // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
       "test.invalid"(%i0, %i1) : (i64, i64) -> ()
   }) : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -78,6 +86,7 @@ func @remap_moved_region_args() {
     ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
       "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -91,6 +100,7 @@ func @remap_cloned_region_args() {
     ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
       "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) {legalizer.should_clone} : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -102,6 +112,7 @@ func @remap_drop_region() {
     ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
       "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
   }) : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -109,6 +120,7 @@ func @remap_drop_region() {
 func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
   // CHECK-NEXT: "work"{{.*}} : (i16)
+  // expected-remark at +1 {{op 'work' is not legalizable}}
   "work"(%arg) : (i16) -> ()
 }
 
@@ -117,6 +129,7 @@ func @up_to_date_replacement(%arg: i8) -> i8 {
   // CHECK-NEXT: return
   %repl_1 = "test.rewrite"(%arg) : (i8) -> i8
   %repl_2 = "test.rewrite"(%repl_1) : (i8) -> i8
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return %repl_2 : i8
 }
 
@@ -127,11 +140,13 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) {
   %0 = "test.op_with_region_fold"(%arg0) ({
     "foo.op_with_region_terminator"() : () -> ()
   }) : (i32) -> (i32)
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return %0 : i32
 }
 
 // CHECK-LABEL: @create_block
 func @create_block() {
+  // expected-remark at +1 {{op 'test.container' is not legalizable}}
   "test.container"() ({
     // Check that we created a block with arguments.
     // CHECK-NOT: test.create_block
@@ -140,6 +155,7 @@ func @create_block() {
     "test.create_block"() : () -> ()
     "test.finish"() : () -> ()
   }) : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -147,6 +163,7 @@ func @create_block() {
 func @bounded_recursion() {
   // CHECK: test.recursive_rewrite 0
   test.recursive_rewrite 3
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -188,13 +205,16 @@ func @fail_to_convert_region() {
 
 // CHECK-LABEL: @create_illegal_block
 func @create_illegal_block() {
+  // expected-remark at +1 {{op 'test.container' is not legalizable}}
   "test.container"() ({
     // Check that we can undo block creation, i.e. that the block was removed.
     // CHECK: test.create_illegal_block
     // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+    // expected-remark at +1 {{op 'test.create_illegal_block' is not legalizable}}
     "test.create_illegal_block"() : () -> ()
     "test.finish"() : () -> ()
   }) : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
 
@@ -202,6 +222,7 @@ func @create_illegal_block() {
 
 // CHECK-LABEL: @undo_block_arg_replace
 func @undo_block_arg_replace() {
+  // expected-remark at +1 {{op 'test.undo_block_arg_replace' is not legalizable}}
   "test.undo_block_arg_replace"() ({
   ^bb0(%arg0: i32):
     // CHECK: ^bb0(%[[ARG:.*]]: i32):
@@ -209,5 +230,6 @@ func @undo_block_arg_replace() {
 
     "test.return"(%arg0) : (i32) -> ()
   }) : () -> ()
+  // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 655cd1dd1609..deb1cf5bb075 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -515,8 +515,12 @@ struct TestLegalizePatternDriver
 
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
-      (void)applyPartialConversion(getOperation(), target, patterns,
-                                   &converter);
+      DenseSet<Operation *> unlegalizedOps;
+      (void)applyPartialConversion(getOperation(), target, patterns, &converter,
+                                   &unlegalizedOps);
+      // Emit a remark for each operation that was not legalizable.
+      for (auto *op : unlegalizedOps)
+        op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
       return;
     }
 


        


More information about the Mlir-commits mailing list