[Mlir-commits] [mlir] 412e30b - [mlir][Transforms] Dialect Conversion: Add 1:N op replacement test case (#121271)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 29 03:35:37 PST 2024


Author: Matthias Springer
Date: 2024-12-29T12:35:33+01:00
New Revision: 412e30b2274a134c01c2140ac7c7579be70f0896

URL: https://github.com/llvm/llvm-project/commit/412e30b2274a134c01c2140ac7c7579be70f0896
DIFF: https://github.com/llvm/llvm-project/commit/412e30b2274a134c01c2140ac7c7579be70f0896.diff

LOG: [mlir][Transforms] Dialect Conversion: Add 1:N op replacement test case (#121271)

This commit adds a test case that performs two back-to-back 1:N
replacements: `(f16) -> (f16, f16) -> ((f16, f16), (f16, f16))`. For the
moment, 3 argument materializations are inserted. In the future (when
the conversion value mapping supports 1:N), a single target
materialization will be inserted. Addresses a
[comment](https://github.com/llvm/llvm-project/pull/116524#discussion_r1894629711)
in #116524.

Added: 
    

Modified: 
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 2ca5f49637523f..297eb5acef21b7 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -450,7 +450,7 @@ func.func @fold_legalization() -> i32 {
 // -----
 
 // CHECK-LABEL: func @convert_detached_signature()
-//       CHECK:   "test.legal_op_with_region"() ({
+//       CHECK:   "test.legal_op"() ({
 //       CHECK:   ^bb0(%arg0: f64):
 //       CHECK:     "test.return"() : () -> ()
 //       CHECK:   }) : () -> ()
@@ -483,3 +483,25 @@ func.func @test_1_to_n_block_signature_conversion() {
   "test.return"() : () -> ()
 }
 
+// -----
+
+// CHECK: notifyOperationInserted: test.step_1
+// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationInserted: test.legal_op
+// CHECK: notifyOperationReplaced: test.step_1
+// CHECK: notifyOperationErased: test.step_1
+
+// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
+//       CHECK:   %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
+// TODO: There should be a single cast (i.e., a single target materialization).
+// This is currently not possible due to 1:N limitations of the conversion
+// mapping. Instead, we have 3 argument materializations.
+//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
+//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
+//       CHECK:   %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
+//       CHECK:   "test.valid"(%[[cast3]]) : (f16) -> ()
+func.func @test_multiple_1_to_n_replacement() {
+  %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
+  "test.invalid"(%0) : (f16) -> ()
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a470497fdbb560..826c222990be4f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const final {
     if (op->getNumRegions() != 1)
       return failure();
-    OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+    OperationState state(op->getLoc(), "test.legal_op", operands,
                          op->getResultTypes(), {}, BlockRange());
     Region *newRegion = state.addRegion();
     rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,49 @@ class TestRepetitive1ToNConsumer : public ConversionPattern {
   }
 };
 
+/// Tests two chained 1 -> 2 replacements: op -> test.step_1 -> test.legal_op.
+class TestMultiple1ToNReplacement : public ConversionPattern {
+public:
+  TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
+                          ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Replaces `op` with a new operand-less op `name` that has two results of
+    // the same type per original result; attributes are copied. Returns new op.
+    auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+      SmallVector<Type> types;
+      for (Type t : op->getResultTypes()) {
+        types.push_back(t);
+        types.push_back(t);
+      }
+      OperationState state(op->getLoc(), name,
+                           /*operands=*/{}, types, op->getAttrs());
+      auto *newOp = rewriter.create(state);
+      SmallVector<ValueRange> repls;
+      for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
+        repls.push_back(newOp->getResults().slice(2 * i, 2));
+      rewriter.replaceOpWithMultiple(op, repls);
+      return newOp;
+    };
+
+    // Replace test.multiple_1_to_n_replacement with test.step_1.
+    Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
+    // Now replace test.step_1 with test.legal_op.
+    // TODO: Ideally, it should not be necessary to reset the insertion point
+    // here. Based on the API calls, it looks like test.step_1 is entirely
+    // erased. But that's not the case: an argument materialization will
+    // survive. And that argument materialization will be used by the users of
+    // `op`. If we don't reset the insertion point here, we get dominance
+    // errors. This will be fixed when we have 1:N support in the conversion
+    // value mapping.
+    rewriter.setInsertionPoint(repl1);
+    replaceWithDoubleResults(repl1, "test.legal_op");
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1319,7 +1362,8 @@ struct TestLegalizePatternDriver
              TestUndoPropertiesModification, TestEraseOp,
              TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp>(&getContext(), converter);
+                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
+        &getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1330,8 +1374,7 @@ struct TestLegalizePatternDriver
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
                       TerminatorOp, OneRegionOp>();
-    target.addLegalOp(
-        OperationName("test.legal_op_with_region", &getContext()));
+    target.addLegalOp(OperationName("test.legal_op", &getContext()));
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {


        


More information about the Mlir-commits mailing list