[Mlir-commits] [mlir] [mlir][Transforms] Make lookup without type converter unambiguous (PR #151747)

Matthias Springer llvmlistbot at llvm.org
Wed Aug 6 04:44:58 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/151747

From 45f831dc4c9644f41913b0285a2368605745654d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 3 Aug 2025 11:03:21 +0000
Subject: [PATCH] make lookup unambiguous

---
 mlir/docs/DialectConversion.md                |  67 +++++-
 .../Transforms/Utils/DialectConversion.cpp    | 203 +++++++++++-------
 mlir/test/Transforms/test-legalizer.mlir      |  17 ++
 mlir/test/lib/Dialect/Test/TestOps.td         |   4 +
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  49 ++++-
 5 files changed, 255 insertions(+), 85 deletions(-)

diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..556e73c2d56c7 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -202,17 +202,62 @@ struct MyConversionPattern : public ConversionPattern {
 
 #### Type Safety
 
-The types of the remapped operands provided to a conversion pattern must be of a
-type expected by the pattern. The expected types of a pattern are determined by
-a provided [TypeConverter](#type-converter). If no type converter is provided,
-the types of the remapped operands are expected to match the types of the
-original operands. If a type converter is provided, the types of the remapped
-operands are expected to be legal as determined by the converter. If the
-remapped operand types are not of an expected type, and a materialization to the
-expected type could not be performed, the pattern fails application before the
-`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
-explicitly ensure type safety, or sanitize the types of the incoming remapped
-operands. More information on type conversion is detailed in the
+The types of the remapped operands provided to a conversion pattern (through
+the adaptor or `ArrayRef` of operands) depend on type conversion rules.
+
+If the pattern was initialized with a [type converter](#type-converter), the
+conversion driver passes values whose types match the legalized types of the
+operands of the matched operation as per the type converter. To that end, the
+conversion driver may insert target materializations to convert the most
+recently mapped values to the expected legalized types. The driver tries to
+reuse existing materializations on a best-effort basis, but this is not
+guaranteed by the infrastructure. If the operand types of the matched op could
+not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
+is invoked.
+
+Example:
+```c++
+// Type converter that converts all FloatTypes to IntegerTypes.
+TypeConverter converter;
+converter.addConversion([](FloatType t) {
+    return IntegerType::get(t.getContext(), t.getWidth());
+});
+
+// Assuming that `MyConversionPattern` was initialized with `converter`.
+struct MyConversionPattern : public ConversionPattern {
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands, /* ... */) const {
+//                                               ^^^^^^^^
+//      If `op` has a FloatType operand, the respective value in `operands`
+//      is guaranteed to have the legalized IntegerType. If another pattern
+//      previously replaced the operand SSA value with an SSA value of the
+//      legalized type (via "replaceOp" or "applySignatureConversion"), you
+//      will get that SSA value directly (unless the replacement value was
+//      also replaced). Otherwise, you will get a materialization to the
+//      legalized type.
+```
+
+If the pattern was initialized without a type converter, the conversion driver
+passes the most recently mapped values to the pattern, excluding any
+materializations. If a value with the same type as the original operand is
+desired, users can directly take the respective operand from the matched
+operation.
+
+Example: When initializing the pattern from the above example without a type
+converter, `operands` contains the most recent replacement values, regardless
+of their types.
+
+Note: When running without a type converter, materializations are intentionally
+excluded from the lookup process because their presence may depend on other
+patterns. Passing materializations would make the conversion infrastructure
+fragile and unpredictable. Moreover, there could be multiple materializations
+to different types. (This can be the case when multiple patterns are running
+with different type converters.) In such a case, it would be unclear which
+materialization to pass.
+
+The above rules ensure that patterns do not have to explicitly ensure type
+safety, or sanitize the types of the incoming remapped operands. More
+information on type conversion is detailed in the
 [dedicated section](#type-conversion) below.
 
 ## Type Conversion
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..dedc84f1adde9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup the most recently mapped values with the desired types in the
-  /// mapping.
-  ///
-  /// Special cases:
-  /// - If the desired type range is empty, simply return the most recently
-  ///   mapped values.
-  /// - If there is no mapping to the desired types, also return the most
-  ///   recently mapped values.
-  /// - If there is no mapping for the given values at all, return the given
-  ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  /// Lookup a value in the mapping.
+  ValueVector lookup(const ValueVector &from) const;
 
   template <typename T>
   struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +176,31 @@ struct ConversionValueMapping {
 };
 } // namespace
 
-ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
-                                        TypeRange desiredTypes) const {
-  // Try to find the deepest values that have the desired types. If there is no
-  // such mapping, simply return the deepest values.
-  ValueVector desiredValue;
-  ValueVector current{from};
-  do {
-    // Store the current value if the types match.
-    if (TypeRange(ValueRange(current)) == desiredTypes)
-      desiredValue = current;
-
-    // If possible, Replace each value with (one or multiple) mapped values.
-    ValueVector next;
-    for (Value v : current) {
-      auto it = mapping.find({v});
-      if (it != mapping.end()) {
-        llvm::append_range(next, it->second);
-      } else {
-        next.push_back(v);
-      }
-    }
-    if (next != current) {
-      // If at least one value was replaced, continue the lookup from there.
-      current = std::move(next);
-      continue;
-    }
-
-    // Otherwise: Check if there is a mapping for the entire vector. Such
-    // mappings are materializations. (N:M mapping are not supported for value
-    // replacements.)
-    //
-    // Note: From a correctness point of view, materializations do not have to
-    // be stored (and looked up) in the mapping. But for performance reasons,
-    // we choose to reuse existing IR (when possible) instead of creating it
-    // multiple times.
-    auto it = mapping.find(current);
-    if (it == mapping.end()) {
-      // No mapping found: The lookup stops here.
-      break;
-    }
-    current = it->second;
-  } while (true);
+/// Marker attribute for pure type conversions. I.e., mappings whose only
+/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
+/// the replacement values of a "replaceOp" call, etc., are not pure type
+/// conversions.)
+static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+
+/// A vector of values is a pure type conversion if all values are defined by
+/// the same operation and the operation has the `kPureTypeConversionMarker`
+/// attribute.
+static bool isPureTypeConversion(const ValueVector &values) {
+  assert(!values.empty() && "expected non-empty value vector");
+  Operation *op = values.front().getDefiningOp();
+  for (Value v : llvm::drop_begin(values))
+    if (v.getDefiningOp() != op)
+      return false;
+  return op && op->hasAttr(kPureTypeConversionMarker);
+}
 
-  // If the desired values were found use them, otherwise default to the leaf
-  // values.
-  // Note: If `desiredTypes` is empty, this function always returns `current`.
-  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
+ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
+  auto it = mapping.find(from);
+  if (it == mapping.end()) {
+    // No mapping found: The lookup stops here.
+    return {};
+  }
+  return it->second;
 }
 
 //===----------------------------------------------------------------------===//
@@ -930,7 +898,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  ///
+  /// If `skipPureTypeConversions` is "true", materializations that are pure
+  /// type conversions are not considered.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+                              bool skipPureTypeConversions = false) const;
 
   /// Lookup the given value within the map, or return an empty vector if the
   /// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +965,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
   /// the results of the unresolved materialization in the conversion value
   /// mapping.
+  ///
+  /// If `isPureTypeConversion` is "true", the materialization is created only
+  /// to resolve a type mismatch. That means it is not a regular value
+  /// replacement issued by the user. (Replacement values that are created
+  /// "out of thin air" appear like unresolved materializations because they are
+  /// unrealized_conversion_cast ops. However, they must be treated like
+  /// regular value replacements.)
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
       ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
-      UnrealizedConversionCastOp *castOp = nullptr);
+      UnrealizedConversionCastOp *castOp = nullptr,
+      bool isPureTypeConversion = true);
 
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1244,77 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 //===----------------------------------------------------------------------===//
 
-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
-                                               TypeRange desiredTypes) const {
-  return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+    Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+  // Helper function that looks up each value in `values` individually and then
+  // composes the results. If that fails, it tries to look up the entire vector
+  // at once.
+  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
+    // If possible, replace each value with (one or multiple) mapped values.
+    ValueVector next;
+    for (Value v : values) {
+      ValueVector r = mapping.lookup({v});
+      if (!r.empty()) {
+        llvm::append_range(next, r);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != values) {
+      // At least one value was replaced.
+      return next;
+    }
+
+    // Otherwise: Check if there is a mapping for the entire vector. Such
+    // mappings are materializations. (N:M mappings are not supported for value
+    // replacements.)
+    //
+    // Note: From a correctness point of view, materializations do not have to
+    // be stored (and looked up) in the mapping. But for performance reasons,
+    // we choose to reuse existing IR (when possible) instead of creating it
+    // multiple times.
+    ValueVector r = mapping.lookup(values);
+    if (r.empty()) {
+      // No mapping found: The lookup stops here.
+      return {};
+    }
+    return r;
+  };
+
+  // Try to find the deepest values that have the desired types. If there is no
+  // such mapping, simply return the deepest values.
+  ValueVector desiredValue;
+  ValueVector current{from};
+  ValueVector lastNonMaterialization{from};
+  do {
+    // Store the current value if the types match.
+    bool match = TypeRange(ValueRange(current)) == desiredTypes;
+    if (skipPureTypeConversions) {
+      // Skip pure type conversions, if requested.
+      bool pureConversion = isPureTypeConversion(current);
+      match &= !pureConversion;
+      // Keep track of the last mapped value that was not a pure type
+      // conversion.
+      if (!pureConversion)
+        lastNonMaterialization = current;
+    }
+    if (match)
+      desiredValue = current;
+
+    // Look up the next value in the mapping.
+    ValueVector next = composedLookup(current);
+    if (next.empty())
+      break;
+    current = std::move(next);
+  } while (true);
+
+  // If the desired values were found use them, otherwise default to the leaf
+  // values. (Skip pure type conversions, if requested.)
+  if (!desiredTypes.empty())
+    return desiredValue;
+  if (skipPureTypeConversions)
+    return lastNonMaterialization;
+  return current;
 }
 
 ValueVector
@@ -1324,10 +1371,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
     if (!currentTypeConverter) {
-      // The current pattern does not have a type converter. I.e., it does not
-      // distinguish between legal and illegal types. For each operand, simply
-      // pass through the most recently mapped values.
-      remapped.push_back(lookupOrDefault(operand));
+      // The current pattern does not have a type converter. Pass the most
+      // recently mapped values, excluding materializations. Materializations
+      // are intentionally excluded because their presence may depend on other
+      // patterns. Including materializations would make the lookup fragile
+      // and unpredictable.
+      remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+                                         /*skipPureTypeConversions=*/true));
       continue;
     }
 
@@ -1356,7 +1406,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     }
 
     // Create a materialization for the most recently mapped values.
-    repl = lookupOrDefault(operand);
+    repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+                           /*skipPureTypeConversions=*/true);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
         /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1533,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               OpBuilder::InsertPoint(newBlock, newBlock->begin()),
               origArg.getLoc(),
               /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
-              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
+              /*castOp=*/nullptr, /*isPureTypeConversion=*/false)
               .front();
       replaceUsesOfBlockArgument(origArg, mat, converter);
       continue;
@@ -1523,7 +1575,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
     ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
-    UnrealizedConversionCastOp *castOp) {
+    UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   assert(TypeRange(inputs) != outputTypes &&
@@ -1535,6 +1587,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+  if (isPureTypeConversion)
+    convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
   if (!valuesToMap.empty())
     mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
@@ -1650,7 +1704,8 @@ void ConversionPatternRewriterImpl::replaceOp(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
           /*outputTypes=*/result.getType(), /*originalType=*/Type(),
-          currentTypeConverter);
+          currentTypeConverter, /*castOp=*/nullptr,
+          /*isPureTypeConversion=*/false);
       continue;
     }
 
@@ -2902,6 +2957,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> remainingCastOps;
   reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
 
+  // Drop markers.
+  for (UnrealizedConversionCastOp castOp : remainingCastOps)
+    castOp->removeAttr(kPureTypeConversionMarker);
+
   // Try to legalize all unresolved materializations.
   if (config.buildMaterializations) {
     IRRewriter rewriter(rewriterImpl.context, config.listener);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e4406e60ffead..5630d1540e4d5 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_lookup_without_converter
+//       CHECK:   %[[producer:.*]] = "test.valid_producer"() : () -> i16
+//       CHECK:   %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
+//       CHECK:   "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+//       CHECK:   "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+  %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+  "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+  // Make sure that the second "replace_with_valid_consumer" lowering does not
+  // look up the materialization that was created for the above op.
+  "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+  // expected-remark at +1 {{op 'func.return' is not legalizable}}
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..843bd30a51aff 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 def TestTypeProducerOp : TEST_Op<"type_producer">,
   Results<(outs AnyType)>;
+def TestValidProducerOp : TEST_Op<"valid_producer">,
+  Results<(outs AnyType)>;
+def TestValidConsumerOp : TEST_Op<"valid_consumer">,
+  Arguments<(ins AnyType)>;
 def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
   Results<(outs AnyType)>;
 def TestTypeConsumerOp : TEST_Op<"type_consumer">,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..7150401bdbdce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern {
   }
 };
 
+/// Pattern that replaces test.replace_with_valid_producer with
+/// test.valid_producer and the specified type.
+class TestReplaceWithValidProducer : public ConversionPattern {
+public:
+  TestReplaceWithValidProducer(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto attr = op->getAttrOfType<TypeAttr>("type");
+    if (!attr)
+      return failure();
+    rewriter.replaceOpWithNewOp<TestValidProducerOp>(op, attr.getValue());
+    return success();
+  }
+};
+
+/// Pattern that replaces test.replace_with_valid_consumer with
+/// test.valid_consumer. Can be used with and without a type converter.
+class TestReplaceWithValidConsumer : public ConversionPattern {
+public:
+  TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.replace_with_valid_consumer", 1,
+                          ctx) {}
+  TestReplaceWithValidConsumer(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // with_converter present: pattern must have been initialized with a type
+    // converter.
+    // with_converter absent: pattern must have been initialized without a type
+    // converter.
+    if (op->hasAttr("with_converter") != static_cast<bool>(getTypeConverter()))
+      return failure();
+    rewriter.replaceOpWithNewOp<TestValidConsumerOp>(op, operands[0]);
+    return success();
+  }
+};
+
 /// This pattern matches a test.convert_block_args op. It either:
 /// a) Duplicates all block arguments,
 /// b) or: drops all block arguments and replaces each with 2x the first
@@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter {
   TestTypeConverter() {
     addConversion(convertType);
     addSourceMaterialization(materializeCast);
+    addTargetMaterialization(materializeCast);
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver
         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
         TestUndoPropertiesModification, TestEraseOp,
+        TestReplaceWithValidProducer, TestReplaceWithValidConsumer,
         TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
-                 TestBlockArgReplace>(&getContext(), converter);
+                 TestBlockArgReplace, TestReplaceWithValidConsumer>(
+        &getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp, OneRegionOp>();
+                      TerminatorOp, OneRegionOp, TestValidProducerOp,
+                      TestValidConsumerOp>();
     target.addLegalOp(OperationName("test.legal_op", &getContext()));
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
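
The lookup rule this patch introduces for patterns without a type converter can be illustrated with a small standalone sketch. This is not MLIR code: `ToyValue`, `ToyMapping`, and `lookupMostRecent` are hypothetical names made up for illustration, values are modeled as plain name/type pairs, only 1:1 mappings are handled, and the `pureTypeConversion` flag stands in for the `__pure_type_conversion__` marker attribute that the patch attaches to materializations.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for an SSA value (not an MLIR class).
struct ToyValue {
  std::string name;
  std::string type;
  // True if the value is the result of a cast that exists only to resolve a
  // type mismatch (models the marker attribute from the patch).
  bool pureTypeConversion = false;
};

// Hypothetical stand-in for the conversion value mapping, restricted to 1:1
// mappings and assumed to be acyclic.
struct ToyMapping {
  std::map<std::string, ToyValue> next;
  void map(const ToyValue &from, const ToyValue &to) { next[from.name] = to; }
};

// Follow the mapping chain starting at `from`. Without `skipPureTypeConversions`
// the last value in the chain is returned. With it, the last value that is not
// a pure type conversion is returned instead.
ToyValue lookupMostRecent(const ToyMapping &m, ToyValue from,
                          bool skipPureTypeConversions) {
  ToyValue current = from;
  ToyValue lastNonMaterialization = from;
  while (true) {
    if (!current.pureTypeConversion)
      lastNonMaterialization = current;
    auto it = m.next.find(current.name);
    if (it == m.next.end())
      break; // No further mapping: the lookup stops here.
    current = it->second;
  }
  return skipPureTypeConversions ? lastNonMaterialization : current;
}
```

Modeled on `test_lookup_without_converter`: if `%0` is mapped to `%producer` (a regular replacement) and `%producer` is mapped to `%cast` (a materialization created for a pattern with a type converter), a lookup of `%0` with `skipPureTypeConversions` set yields `%producer`, whereas a lookup without it yields `%cast`.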


