[Mlir-commits] [mlir] ec03bbe - [mlir] Fix bug in partial dialect conversion

Vladislav Vinogradov llvmlistbot at llvm.org
Mon Sep 20 00:38:05 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-09-20T10:39:10+03:00
New Revision: ec03bbe8a74ae593d0ea5d8bf55c337e395873d1

URL: https://github.com/llvm/llvm-project/commit/ec03bbe8a74ae593d0ea5d8bf55c337e395873d1
DIFF: https://github.com/llvm/llvm-project/commit/ec03bbe8a74ae593d0ea5d8bf55c337e395873d1.diff

LOG: [mlir] Fix bug in partial dialect conversion

The discussion on the forum:
https://llvm.discourse.group/t/bug-in-partial-dialect-conversion/4115

`applyPartialConversion` didn't handle operations that were
marked as illegal by a dynamic legality callback.
Instead of reporting an error when such an operation was not converted to the legal set,
the method just added it to `unconvertedSet`, in the same way as unknown operations.

This patch fixes that and handles dynamically illegal operations as well.

The patch includes 2 fixes for existing passes:

* `tensor-bufferize` - explicitly mark `std.return` as legal.
* `convert-parallel-loops-to-gpu` - a (somewhat ugly) workaround that marks visited
  operations to avoid recursive legality checks.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D108505

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
    mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
    mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer-full.mlir
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
index ac1ba0e2f24b4..483867933b7d2 100644
--- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
+++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
@@ -16,6 +16,7 @@ class ConversionTarget;
 struct LogicalResult;
 class MLIRContext;
 class Value;
+class Operation;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
@@ -49,6 +50,9 @@ void populateParallelLoopToGPUPatterns(RewritePatternSet &patterns);
 /// are not rewritten by the provided patterns are legal.
 void configureParallelLoopToGPULegality(ConversionTarget &target);
 
+/// Clean up after applyPartialConversion/applyFullConversion call.
+void finalizeParallelLoopToGPUConversion(Operation *op);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_

diff  --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index d13cebe3c3a2f..9770299d3b971 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -37,6 +37,24 @@
 using namespace mlir;
 using namespace mlir::scf;
 
+// Name of internal attribute to mark visited operations during conversion.
+//
+// NOTE: The conversion originally used the following legality criterion:
+//   `!parallelOp->hasAttr(gpu::getMappingAttrName())`
+// But the provided pattern might reject some cases based on more detailed
+// analysis of the `mapping` attribute.
+// To avoid a dialect conversion failure due to a non-converted illegal op,
+// we use this extra Unit attribute as a marker that the operation was checked
+// by the pattern and should be considered legal in subsequent legality
+// checks. The `finalizeParallelLoopToGPUConversion` function cleans up
+// these extra attributes and is supposed to be called after the dialect
+// conversion.
+//
+// TODO: Implement a cleaner solution, factoring out the "matching" logic
+// from the pattern and its callees into a separate function that can be called
+// from both the pattern and the op legality check.
+static constexpr StringLiteral kVisitedAttrName = "SCFToGPU_visited";
+
 // Extract an indexed value from KernelDim3.
 static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) {
   switch (pos) {
@@ -567,6 +585,9 @@ static LogicalResult processParallelLoop(
 LogicalResult
 ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
                                              PatternRewriter &rewriter) const {
+  // Mark the operation as visited for recursive legality check.
+  parallelOp->setAttr(kVisitedAttrName, rewriter.getUnitAttr());
+
   // We can only transform starting at the outer-most loop. Launches inside of
   // parallel loops are not supported.
   if (auto parentLoop = parallelOp->getParentOfType<ParallelOp>())
@@ -649,6 +670,13 @@ void mlir::populateParallelLoopToGPUPatterns(RewritePatternSet &patterns) {
 void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
   target.addLegalDialect<memref::MemRefDialect>();
   target.addDynamicallyLegalOp<scf::ParallelOp>([](scf::ParallelOp parallelOp) {
-    return !parallelOp->getAttr(gpu::getMappingAttrName());
+    return !parallelOp->hasAttr(gpu::getMappingAttrName()) ||
+           parallelOp->hasAttr(kVisitedAttrName);
+  });
+}
+
+void mlir::finalizeParallelLoopToGPUConversion(Operation *op) {
+  op->walk([](scf::ParallelOp parallelOp) {
+    parallelOp->removeAttr(kVisitedAttrName);
   });
 }

diff  --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
index 43c6798091e73..e9a8df02e685c 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
@@ -55,6 +55,7 @@ struct ParallelLoopToGpuPass
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
+    finalizeParallelLoopToGPUConversion(getOperation());
   }
 };
 

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index f9faba08cf9f2..f5f7b0f5faf19 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -175,6 +175,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
     target.addLegalDialect<memref::MemRefDialect>();
     target.addDynamicallyLegalDialect<StandardOpsDialect>(
         [&](Operation *op) { return typeConverter.isLegal(op); });
+    target.addLegalOp<ReturnOp>();
     target.addLegalDialect<scf::SCFDialect>();
 
     if (failed(

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6aa42f64059fd..4f1c8cfa70c53 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1650,7 +1650,13 @@ OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
 
 bool OperationLegalizer::isIllegal(Operation *op) const {
   // Check if the target explicitly marked this operation as illegal.
-  return target.getOpAction(op->getName()) == LegalizationAction::Illegal;
+  if (auto info = target.getOpAction(op->getName())) {
+    if (*info == LegalizationAction::Dynamic)
+      return !target.isLegal(op);
+    return *info == LegalizationAction::Illegal;
+  }
+
+  return false;
 }
 
 LogicalResult

diff  --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 3cbc1736b6aa2..5480d3d3d7286 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -47,55 +47,88 @@ func @recursively_legal_invalid_op() {
 
 // -----
 
-// Test that region cloning can be properly undone.
-func @test_undo_region_clone() {
-  "test.region"() ({
-    ^bb1(%i0: i64):
-      "test.invalid"(%i0) : (i64) -> ()
-  }) {legalizer.should_clone} : () -> ()
-
-  // expected-error at +1 {{failed to legalize operation 'test.illegal_op_f'}}
-  %ignored = "test.illegal_op_f"() : () -> (i32)
-  "test.return"() : () -> ()
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+
+  // Test that region cloning can be properly undone.
+  func @test_undo_region_clone() {
+    "test.region"() ({
+      ^bb1(%i0: i64):
+        "test.invalid"(%i0) : (i64) -> ()
+    }) {legalizer.should_clone} : () -> ()
+
+    // expected-error at +1 {{failed to legalize operation 'test.illegal_op_f'}}
+    %ignored = "test.illegal_op_f"() : () -> (i32)
+    "test.return"() : () -> ()
+  }
+
 }
 
 // -----
 
-// Test that unknown operations can be dynamically legal.
-func @test_unknown_dynamically_legal() {
-  "foo.unknown_op"() {test.dynamically_legal} : () -> ()
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+
+  // Test that unknown operations can be dynamically legal.
+  func @test_unknown_dynamically_legal() {
+    "foo.unknown_op"() {test.dynamically_legal} : () -> ()
+
+    // expected-error at +1 {{failed to legalize operation 'foo.unknown_op'}}
+    "foo.unknown_op"() {} : () -> ()
+    "test.return"() : () -> ()
+  }
 
-  // expected-error at +1 {{failed to legalize operation 'foo.unknown_op'}}
-  "foo.unknown_op"() {} : () -> ()
-  "test.return"() : () -> ()
 }
 
 // -----
 
-// Test that region inlining can be properly undone.
-func @test_undo_region_inline() {
-  "test.region"() ({
-    ^bb1(%i0: i64):
-       // expected-error at +1 {{failed to legalize operation 'std.br'}}
-       br ^bb2(%i0 : i64)
-    ^bb2(%i1: i64):
-      "test.invalid"(%i1) : (i64) -> ()
-  }) {} : () -> ()
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+
+  // Test that region inlining can be properly undone.
+  func @test_undo_region_inline() {
+    "test.region"() ({
+      ^bb1(%i0: i64):
+        // expected-error at +1 {{failed to legalize operation 'std.br'}}
+        br ^bb2(%i0 : i64)
+      ^bb2(%i1: i64):
+        "test.invalid"(%i1) : (i64) -> ()
+    }) {} : () -> ()
+
+    "test.return"() : () -> ()
+  }
 
-  "test.return"() : () -> ()
 }
 
 // -----
 
-// Test that multiple block erases can be properly undone.
-func @test_undo_block_erase() {
-   // expected-error at +1 {{failed to legalize operation 'test.region'}}
-  "test.region"() ({
-    ^bb1(%i0: i64):
-       br ^bb2(%i0 : i64)
-    ^bb2(%i1: i64):
-      "test.invalid"(%i1) : (i64) -> ()
-  }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+
+  // Test that multiple block erases can be properly undone.
+  func @test_undo_block_erase() {
+    // expected-error at +1 {{failed to legalize operation 'test.region'}}
+    "test.region"() ({
+      ^bb1(%i0: i64):
+        br ^bb2(%i0 : i64)
+      ^bb2(%i1: i64):
+        "test.invalid"(%i1) : (i64) -> ()
+    }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
+
+    "test.return"() : () -> ()
+  }
+
+}
+
+// -----
+
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+
+  func @create_unregistered_op_in_pattern() -> i32 {
+    // expected-error at +1 {{failed to legalize operation 'test.illegal_op_g'}}
+    %0 = "test.illegal_op_g"() : () -> (i32)
+    "test.return"(%0) : (i32) -> ()
+  }
 
-  "test.return"() : () -> ()
 }

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 0603883ce1490..25c3eb34f849a 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -173,18 +173,28 @@ func @bounded_recursion() {
 
 // -----
 
-func @fail_to_convert_illegal_op() -> i32 {
-  // expected-error at +1 {{failed to legalize operation 'test.illegal_op_f'}}
-  %result = "test.illegal_op_f"() : () -> (i32)
-  return %result : i32
+// expected-remark at +1 {{applyPartialConversion failed}}
+builtin.module {
+
+  func @fail_to_convert_illegal_op() -> i32 {
+    // expected-error at +1 {{failed to legalize operation 'test.illegal_op_f'}}
+    %result = "test.illegal_op_f"() : () -> (i32)
+    return %result : i32
+  }
+
 }
 
 // -----
 
-func @fail_to_convert_illegal_op_in_region() {
-  // expected-error at +1 {{failed to legalize operation 'test.region_builder'}}
-  "test.region_builder"() : () -> ()
-  return
+// expected-remark at +1 {{applyPartialConversion failed}}
+builtin.module {
+
+  func @fail_to_convert_illegal_op_in_region() {
+    // expected-error at +1 {{failed to legalize operation 'test.region_builder'}}
+    "test.region_builder"() : () -> ()
+    return
+  }
+
 }
 
 // -----
@@ -192,17 +202,21 @@ func @fail_to_convert_illegal_op_in_region() {
 // Check that the entry block arguments of a region are untouched in the case
 // of failure.
 
-// CHECK-LABEL: func @fail_to_convert_region
-func @fail_to_convert_region() {
-  // CHECK-NEXT: "test.region"
-  // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64):
-  "test.region"() ({
-    ^bb1(%i0: i64):
-      // expected-error at +1 {{failed to legalize operation 'test.region_builder'}}
-      "test.region_builder"() : () -> ()
-      "test.valid"() : () -> ()
-  }) : () -> ()
-  return
+// expected-remark at +1 {{applyPartialConversion failed}}
+builtin.module {
+
+  func @fail_to_convert_region() {
+    // CHECK: "test.region"
+    // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64):
+    "test.region"() ({
+      ^bb1(%i0: i64):
+        // expected-error at +1 {{failed to legalize operation 'test.region_builder'}}
+        "test.region_builder"() : () -> ()
+        "test.valid"() : () -> ()
+    }) : () -> ()
+    return
+  }
+
 }
 
 // -----
@@ -271,10 +285,8 @@ func @undo_child_created_before_parent() {
   return
 }
 
-
 // -----
 
-
 // Check that a conversion pattern on `test.blackhole` can mark the producer
 // for deletion.
 // CHECK-LABEL: @blackhole
@@ -284,3 +296,16 @@ func @blackhole() {
   // expected-remark at +1 {{op 'std.return' is not legalizable}}
   return
 }
+
+// -----
+
+// expected-remark at +1 {{applyPartialConversion failed}}
+builtin.module {
+
+  func @create_unregistered_op_in_pattern() -> i32 {
+    // expected-error at +1 {{failed to legalize operation 'test.illegal_op_g'}}
+    %0 = "test.illegal_op_g"() : () -> (i32)
+    "test.return"(%0) : (i32) -> ()
+  }
+
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e17a76b67a514..a887adbc27055 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1415,9 +1415,12 @@ def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32)>;
 def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>;
 def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
 def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
+def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>;
 def LegalOpA : TEST_Op<"legal_op_a">,
   Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
 def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
+def LegalOpC : TEST_Op<"legal_op_c">,
+  Arguments<(ins I32)>, Results<(outs I32)>;
 
 // Check that the conversion infrastructure can properly undo the creation of
 // operations where an operation was created before its parent, in this case,

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index de141dc72e981..d51cf5ea1824c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -562,6 +562,20 @@ struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> {
     return success();
   };
 };
+
+// This pattern replaces a dynamically illegal op with an explicitly legal op,
+// but in addition creates an op (`std.constant`) unknown to the target.
+struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
+  using OpRewritePattern<ILLegalOpG>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ILLegalOpG op,
+                                PatternRewriter &rewriter) const final {
+    IntegerAttr attr = rewriter.getI32IntegerAttr(0);
+    Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
+    rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
+    return success();
+  };
+};
 } // namespace
 
 namespace {
@@ -632,6 +646,10 @@ struct TestLegalizePatternDriver
 
   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect>();
+  }
+
   void runOnOperation() override {
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
@@ -643,8 +661,8 @@ struct TestLegalizePatternDriver
              TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
-             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>(
-            &getContext());
+             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+             TestCreateUnregisteredOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -652,7 +670,7 @@ struct TestLegalizePatternDriver
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
-    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
+    target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                       TerminatorOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
@@ -666,6 +684,11 @@ struct TestLegalizePatternDriver
              converter.isLegal(&op.getBody());
     });
 
+    // TestCreateUnregisteredOp creates `std.constant` operation,
+    // which was not added to target intentionally to test
+    // correct error code from conversion driver.
+    target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; });
+
     // Expect the type_producer/type_consumer operations to only operate on f64.
     target.addDynamicallyLegalOp<TestTypeProducerOp>(
         [](TestTypeProducerOp op) { return op.getType().isF64(); });
@@ -686,8 +709,10 @@ struct TestLegalizePatternDriver
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
-      (void)applyPartialConversion(getOperation(), target, std::move(patterns),
-                                   &unlegalizedOps);
+      if (failed(applyPartialConversion(
+              getOperation(), target, std::move(patterns), &unlegalizedOps))) {
+        getOperation()->emitRemark() << "applyPartialConversion failed";
+      }
       // Emit remarks for each legalizable operation.
       for (auto *op : unlegalizedOps)
         op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
@@ -701,7 +726,10 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
-      (void)applyFullConversion(getOperation(), target, std::move(patterns));
+      if (failed(applyFullConversion(getOperation(), target,
+                                     std::move(patterns)))) {
+        getOperation()->emitRemark() << "applyFullConversion failed";
+      }
       return;
     }
 


        


More information about the Mlir-commits mailing list