[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Add 1:N op replacement test case (PR #121271)

Matthias Springer llvmlistbot at llvm.org
Sat Dec 28 11:34:11 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/121271

>From 92cf7322559542e66b03d5099d582e57f2c746dd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 28 Dec 2024 17:24:41 +0100
Subject: [PATCH] test double replacement

---
 mlir/test/Transforms/test-legalizer.mlir    | 24 +++++++++-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp | 51 +++++++++++++++++++--
 2 files changed, 70 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 2ca5f49637523f..297eb5acef21b7 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -450,7 +450,7 @@ func.func @fold_legalization() -> i32 {
 // -----
 
 // CHECK-LABEL: func @convert_detached_signature()
-//       CHECK:   "test.legal_op_with_region"() ({
+//       CHECK:   "test.legal_op"() ({
 //       CHECK:   ^bb0(%arg0: f64):
 //       CHECK:     "test.return"() : () -> ()
 //       CHECK:   }) : () -> ()
@@ -483,3 +483,25 @@ func.func @test_1_to_n_block_signature_conversion() {
   "test.return"() : () -> ()
 }
 
+// -----
+
+// CHECK: notifyOperationInserted: test.step_1
+// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationInserted: test.legal_op
+// CHECK: notifyOperationReplaced: test.step_1
+// CHECK: notifyOperationErased: test.step_1
+
+// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
+//       CHECK:   %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
+// TODO: There should be a single cast (i.e., a single target materialization).
+// This is currently not possible due to 1:N limitations of the conversion
+// mapping. Instead, we get 3 argument materializations (the test.cast ops).
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
+//       CHECK:   %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
+//       CHECK:   "test.valid"(%[[cast3]]) : (f16) -> ()
+func.func @test_multiple_1_to_n_replacement() {
+  %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
+  "test.invalid"(%0) : (f16) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a470497fdbb560..826c222990be4f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const final {
     if (op->getNumRegions() != 1)
       return failure();
-    OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+    OperationState state(op->getLoc(), "test.legal_op", operands,
                          op->getResultTypes(), {}, BlockRange());
     Region *newRegion = state.addRegion();
     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,49 @@ class TestRepetitive1ToNConsumer : public ConversionPattern {
   }
 };
 
+/// A pattern that tests two back-to-back 1 -> 2 replacements of each op result.
+class TestMultiple1ToNReplacement : public ConversionPattern {
+public:
+  TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
+                          ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Helper function that replaces the given op with a new op of the given
+    // name and doubles each result (1 -> 2 replacement of each result).
+    auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+      SmallVector<Type> types;
+      for (Type t : op->getResultTypes()) {
+        types.push_back(t);
+        types.push_back(t);
+      }
+      OperationState state(op->getLoc(), name,
+                           /*operands=*/{}, types, op->getAttrs());
+      auto *newOp = rewriter.create(state);
+      SmallVector<ValueRange> repls;
+      for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
+        repls.push_back(newOp->getResults().slice(2 * i, 2));
+      rewriter.replaceOpWithMultiple(op, repls);
+      return newOp;
+    };
+
+    // Replace test.multiple_1_to_n_replacement with test.step_1.
+    Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
+    // Now replace test.step_1 with test.legal_op.
+    // TODO: Ideally, it should not be necessary to reset the insertion point
+    // here. Based on the API calls, it looks like test.step_1 is entirely
+    // erased. But that's not the case: an argument materialization will
+    // survive. And that argument materialization will be used by the users of
+    // `op`. If we don't reset the insertion point here, we get dominance
+    // errors. This will be fixed when we have 1:N support in the conversion
+    // value mapping.
+    rewriter.setInsertionPoint(repl1);
+    replaceWithDoubleResults(repl1, "test.legal_op");
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1319,7 +1362,8 @@ struct TestLegalizePatternDriver
              TestUndoPropertiesModification, TestEraseOp,
              TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp>(&getContext(), converter);
+                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
+        &getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1330,8 +1374,7 @@ struct TestLegalizePatternDriver
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                       TerminatorOp, OneRegionOp>();
-    target.addLegalOp(
-        OperationName("test.legal_op_with_region", &getContext()));
+    target.addLegalOp(OperationName("test.legal_op", &getContext()));
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {



More information about the Mlir-commits mailing list