[Mlir-commits] [mlir] 5b91060 - [mlir] Apply source materialization in case of transitive conversion

Alex Zinenko llvmlistbot at llvm.org
Thu Feb 4 02:15:23 PST 2021


Author: Alex Zinenko
Date: 2021-02-04T11:15:11+01:00
New Revision: 5b91060dcc2ef476e9332c5c6d759aed3b05f6be

URL: https://github.com/llvm/llvm-project/commit/5b91060dcc2ef476e9332c5c6d759aed3b05f6be
DIFF: https://github.com/llvm/llvm-project/commit/5b91060dcc2ef476e9332c5c6d759aed3b05f6be.diff

LOG: [mlir] Apply source materialization in case of transitive conversion

In the dialect conversion infrastructure, source materialization is applied, as
part of the finalization procedure, to the results of newly produced operations
that replace existing values with values of a different type. However, such
operations may themselves replace operations created by other patterns. In that
case, the results of the _original_ operation may still be in use with
mismatching types while the results of the _intermediate_ operation that
performed the type change are not in use, so no source materialization is
performed. For example,

  %0 = dialect.produce : !dialect.A
  dialect.use %0 : !dialect.A

can be replaced with

  %0 = dialect.other : !dialect.A
  %1 = dialect.produce : !dialect.A  // replaced, scheduled for removal
  dialect.use %1 : !dialect.A

and then with

  %0 = dialect.final : !dialect.B
  %1 = dialect.other : !dialect.A    // replaced, scheduled for removal
  %2 = dialect.produce : !dialect.A  // replaced, scheduled for removal
  dialect.use %2 : !dialect.A

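within the same rewrite. Currently only the type-changing %1->%0 replacement is
considered, and since %1 itself has no users, no materialization is created for
the remaining use of %2.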

Change the logic in dialect conversion to look up all values that were replaced
by the given value and to perform source materialization if any of those values
is still in use with mismatching types. This is done by computing the inverse
value replacement mapping. Since this is arguably expensive, it is done only if
there were some type-changing replacements. An alternative would be to consider
all replaced operations, not only those whose replacement changed types, but
that would harm pattern-level composability: the pattern that performed the
non-type-changing replacement would have to be made aware of the type converter
in order to call the materialization hook.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95626

Added: 
    

Modified: 
    mlir/include/mlir/IR/BlockAndValueMapping.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalize-type-conversion.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h
index c15920fd3d0a..9959d883f89c 100644
--- a/mlir/include/mlir/IR/BlockAndValueMapping.h
+++ b/mlir/include/mlir/IR/BlockAndValueMapping.h
@@ -76,6 +76,14 @@ class BlockAndValueMapping {
   /// Clears all mappings held by the mapper.
   void clear() { valueMap.clear(); }
 
+  /// Returns a new mapper containing the inverse mapping.
+  BlockAndValueMapping getInverse() const {
+    BlockAndValueMapping result;
+    for (const auto &pair : valueMap)
+      result.valueMap.try_emplace(pair.second, pair.first);
+    return result;
+  }
+
 private:
   /// Utility lookupOrValue that looks up an existing key or returns the
   /// provided value.

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ecbd57d947e8..de3c436cdfc1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -122,6 +122,9 @@ struct ConversionValueMapping {
   /// Drop the last mapping for the given value.
   void erase(Value value) { mapping.erase(value); }
 
+  /// Returns the inverse raw value mapping (without recursive query support).
+  BlockAndValueMapping getInverse() const { return mapping.getInverse(); }
+
 private:
   /// Current value mappings.
   BlockAndValueMapping mapping;
@@ -2131,7 +2134,8 @@ struct OperationConverter {
   legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
                             TypeConverter *replConverter,
                             ConversionPatternRewriter &rewriter,
-                            ConversionPatternRewriterImpl &rewriterImpl);
+                            ConversionPatternRewriterImpl &rewriterImpl,
+                            const BlockAndValueMapping &inverseMapping);
 
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
@@ -2221,6 +2225,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
     return failure();
 
+  if (rewriterImpl.operationsWithChangedResults.empty())
+    return success();
+
+  Optional<BlockAndValueMapping> inverseMapping;
+
   // Process requested operation replacements.
   for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
        i != e; ++i) {
@@ -2241,11 +2250,15 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
       if (result.getType() == newValue.getType())
         continue;
 
+      // Compute the inverse mapping only if it is really needed.
+      if (!inverseMapping)
+        inverseMapping = rewriterImpl.mapping.getInverse();
+
       // Legalize this result.
       rewriter.setInsertionPoint(repl.first);
       if (failed(legalizeChangedResultType(repl.first, result, newValue,
                                            repl.second.converter, rewriter,
-                                           rewriterImpl)))
+                                           rewriterImpl, *inverseMapping)))
         return failure();
 
       // Update the end iterator for this loop in the case it was updated
@@ -2305,16 +2318,32 @@ LogicalResult OperationConverter::legalizeErasedResult(
   return success();
 }
 
+/// Finds a user of the given value, or of any other value that the given value
+/// replaced, that was not replaced in the conversion process.
+static Operation *
+findLiveUserOfReplaced(Value value, ConversionPatternRewriterImpl &rewriterImpl,
+                       const BlockAndValueMapping &inverseMapping) {
+  do {
+    // Walk the users of this value to see if there are any live users that
+    // weren't replaced during conversion.
+    auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
+      return rewriterImpl.isOpIgnored(user);
+    });
+    if (liveUserIt != value.user_end())
+      return *liveUserIt;
+    value = inverseMapping.lookupOrNull(value);
+  } while (value != nullptr);
+  return nullptr;
+}
+
 LogicalResult OperationConverter::legalizeChangedResultType(
     Operation *op, OpResult result, Value newValue,
     TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl) {
-  // Walk the users of this value to see if there are any live users that
-  // weren't replaced during conversion.
-  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
-    return rewriterImpl.isOpIgnored(user);
-  });
-  if (liveUserIt == result.user_end())
+    ConversionPatternRewriterImpl &rewriterImpl,
+    const BlockAndValueMapping &inverseMapping) {
+  Operation *liveUser =
+      findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
+  if (!liveUser)
     return success();
 
   // If the replacement has a type converter, attempt to materialize a
@@ -2340,8 +2369,8 @@ LogicalResult OperationConverter::legalizeChangedResultType(
                               << result.getResultNumber() << " of operation '"
                               << op->getName()
                               << "' that remained live after conversion";
-    diag.attachNote(liveUserIt->getLoc())
-        << "see existing live user here: " << *liveUserIt;
+    diag.attachNote(liveUser->getLoc())
+        << "see existing live user here: " << *liveUser;
     return failure();
   }
 

diff  --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 1ea8ddc2660c..9ce69519006a 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -45,6 +45,26 @@ func @test_invalid_result_materialization() {
 
 // -----
 
+// CHECK-LABEL: @test_transitive_use_materialization
+func @test_transitive_use_materialization() {
+  // CHECK: %[[V:.*]] = "test.type_producer"() : () -> f64
+  // CHECK: %[[C:.*]] = "test.cast"(%[[V]]) : (f64) -> f32
+  %result = "test.another_type_producer"() : () -> f32
+  // CHECK: "foo.return"(%[[C]])
+  "foo.return"(%result) : (f32) -> ()
+}
+
+// -----
+
+func @test_transitive_use_invalid_materialization() {
+  // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
+  %result = "test.another_type_producer"() : () -> f16
+  // expected-note@below {{see existing live user here}}
+  "foo.return"(%result) : (f16) -> ()
+}
+
+// -----
+
 func @test_invalid_result_legalization() {
   // expected-error@below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after conversion}}
   %result = "test.type_producer"() : () -> i16

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a36983864556..f5df4ac62df2 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1281,6 +1281,8 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 def TestTypeProducerOp : TEST_Op<"type_producer">,
   Results<(outs AnyType)>;
+def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
+  Results<(outs AnyType)>;
 def TestTypeConsumerOp : TEST_Op<"type_consumer">,
   Arguments<(ins AnyType)>;
 def TestValidOp : TEST_Op<"valid", [Terminator]>,

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 1b2fab124c86..fe14b9698832 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -801,6 +801,17 @@ struct TestTypeConsumerForward
   }
 };
 
+struct TestTypeConversionAnotherProducer
+    : public OpRewritePattern<TestAnotherTypeProducerOp> {
+  using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op,
+                                PatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType());
+    return success();
+  }
+};
+
 struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -865,6 +876,7 @@ struct TestTypeConversionDriver
     OwningRewritePatternList patterns;
     patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
                     TestSignatureConversionUndo>(converter, &getContext());
+    patterns.insert<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
 


        


More information about the Mlir-commits mailing list