[Mlir-commits] [mlir] [mlir][Transforms] Make lookup without type converter unambiguous (PR #151747)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 11:44:53 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

When a conversion pattern is initialized without a type converter, the driver implementation currently looks up the most recently mapped value. This is undesirable because the most recently mapped value could be a materialization; i.e., the type of the value being looked up could depend on which other patterns have run before. This behavior makes the type conversion infrastructure fragile and unpredictable.
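The order dependence can be sketched with a toy model. This is not the MLIR API. `ValueMap`, `lookupMostRecent`, and `seenWithoutConverter` are hypothetical names. Values are plain strings, and a single map holds both real replacements and materializations, which roughly approximates the pre-patch lookup. A pattern without a converter sees a different value depending on whether a pattern with a type converter ran first.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical toy model (not the MLIR API): values are named by strings and
// one map records "old value -> new value" edges for both real replacements
// and materializations, roughly as the pre-patch lookup treated them.
using ValueMap = std::map<std::string, std::string>;

// Follow the chain of mappings to the most recently mapped value.
std::string lookupMostRecent(const ValueMap &mapping, std::string v) {
  for (auto it = mapping.find(v); it != mapping.end(); it = mapping.find(v)) {
    assert(it->second != v && "a self-mapping would loop forever");
    v = it->second;
  }
  return v;
}

// What a pattern without a type converter receives for "%0", depending on
// whether a pattern *with* a converter ran earlier and inserted a target
// materialization for the replacement value.
std::string seenWithoutConverter(bool converterPatternRanFirst) {
  ValueMap mapping;
  mapping["%0"] = "%producer";      // replacement: i64 value -> i16 value
  if (converterPatternRanFirst)
    mapping["%producer"] = "%cast"; // materialization: i16 -> f64
  return lookupMostRecent(mapping, "%0");
}
```

Calling `seenWithoutConverter(false)` yields `"%producer"`, while `seenWithoutConverter(true)` yields `"%cast"`. The same pattern therefore observes different operand types depending on pattern order.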

The current implementation also contradicts the documentation in the markdown file. According to that documentation, the values provided by the adaptor should match the types of the operands of the matched operation when running without a type converter. The documented behavior is not desirable either, for two reasons:

1. Some patterns have started to rely on receiving the most recently mapped value. Changing the behavior to the documented behavior will cause regressions. (And there would be no easy way to fix those without forcing the use of a type converter or extending the `getRemappedValue` API.)
2. It is more useful to receive the most recently mapped value. A value of the original operand type can be retrieved by using the operand of the matched operation. The adaptor is not needed at all in that case.
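Point 2 can be illustrated with a minimal sketch that uses hypothetical toy types (`ToyValue`, `ToyOp`, `selectFirstOperand`) rather than the real `ConversionPattern` API. The converter-less pattern receives the most recently mapped values. A value of the original operand type remains reachable through the matched operation itself. In actual pattern code, the analogous choice inside `matchAndRewrite` would be between `operands[0]` and `op->getOperand(0)`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy types; not the MLIR API.
struct ToyValue {
  std::string name;
  std::string type;
};

struct ToyOp {
  std::vector<ToyValue> operands; // original operands of the matched op
};

// `remapped` plays the role of the adaptor / ArrayRef<Value> operands that a
// converter-less pattern receives: the most recently mapped values. The
// original-typed value is still available from the matched op.
ToyValue selectFirstOperand(const ToyOp &op,
                            const std::vector<ToyValue> &remapped,
                            bool wantOriginalType) {
  assert(!op.operands.empty() && op.operands.size() == remapped.size());
  return wantOriginalType ? op.operands[0] : remapped[0];
}
```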


---

Patch is 20.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151747.diff


5 Files Affected:

- (modified) mlir/docs/DialectConversion.md (+23-11) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+114-53) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+17) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+4) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+47-2) 


``````````diff
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..89c8c78749957 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -202,17 +202,29 @@ struct MyConversionPattern : public ConversionPattern {
 
 #### Type Safety
 
-The types of the remapped operands provided to a conversion pattern must be of a
-type expected by the pattern. The expected types of a pattern are determined by
-a provided [TypeConverter](#type-converter). If no type converter is provided,
-the types of the remapped operands are expected to match the types of the
-original operands. If a type converter is provided, the types of the remapped
-operands are expected to be legal as determined by the converter. If the
-remapped operand types are not of an expected type, and a materialization to the
-expected type could not be performed, the pattern fails application before the
-`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
-explicitly ensure type safety, or sanitize the types of the incoming remapped
-operands. More information on type conversion is detailed in the
+The types of the remapped operands provided to a conversion pattern (through
+the adaptor or `ArrayRef` of operands) depend on type conversion rules.
+
+If the pattern was initialized with a [type converter](#type-converter), the
+conversion driver passes values whose types match the legalized types of the
+operands of the matched operation as per the type converter. To that end, the
+conversion driver may insert target materializations to convert the most
+recently mapped values to the expected legalized types. The driver tries to
+reuse existing materializations on a best-effort basis, but this is not
+guaranteed by the infrastructure. If the operand types of the matched op could
+not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
+is invoked.
+
+If the pattern was initialized without a type converter, the conversion driver
+passes the most recently mapped values to the pattern, excluding any
+materializations. Materializations are intentionally excluded because their
+presence may depend on other patterns. If a value of the same type as an
+operand is desired, users can directly take the respective operand from the
+matched operation.
+
+The above rules ensure that patterns do not have to explicitly ensure type
+safety, or sanitize the types of the incoming remapped operands. More
+information on type conversion is detailed in the
 [dedicated section](#type-conversion) below.
 
 ## Type Conversion
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 08803e082b057..c533a2f9d323f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -130,7 +130,16 @@ struct ConversionValueMapping {
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  ///
+  /// If `skipMaterializations` is true, materializations are not considered.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+                              bool skipMaterializations = false) const;
+
+  /// Lookup a single value from the mapping. Look for actual replacements
+  /// first, then for materializations. The materializations lookup can be
+  /// skipped.
+  ValueVector lookupSingle(const ValueVector &from,
+                           bool skipMaterializations = false) const;
 
   template <typename T>
   struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -138,7 +147,7 @@ struct ConversionValueMapping {
   /// Map a value vector to the one provided.
   template <typename OldVal, typename NewVal>
   std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
-  map(OldVal &&oldVal, NewVal &&newVal) {
+  map(OldVal &&oldVal, NewVal &&newVal, bool isMaterialization = false) {
     LLVM_DEBUG({
       ValueVector next(newVal);
       while (true) {
@@ -151,42 +160,105 @@ struct ConversionValueMapping {
     });
     mappedTo.insert_range(newVal);
 
-    mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+    if (isMaterialization) {
+      // This is a materialization.
+      materializations[std::forward<OldVal>(oldVal)] =
+          std::forward<NewVal>(newVal);
+    } else {
+      mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+    }
   }
 
   /// Map a value vector or single value to the one provided.
   template <typename OldVal, typename NewVal>
   std::enable_if_t<!IsValueVector<OldVal>::value ||
                    !IsValueVector<NewVal>::value>
-  map(OldVal &&oldVal, NewVal &&newVal) {
+  map(OldVal &&oldVal, NewVal &&newVal, bool isMaterialization = false) {
     if constexpr (IsValueVector<OldVal>{}) {
-      map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+      map(std::forward<OldVal>(oldVal), ValueVector{newVal}, isMaterialization);
     } else if constexpr (IsValueVector<NewVal>{}) {
-      map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+      map(ValueVector{oldVal}, std::forward<NewVal>(newVal), isMaterialization);
     } else {
-      map(ValueVector{oldVal}, ValueVector{newVal});
+      map(ValueVector{oldVal}, ValueVector{newVal}, isMaterialization);
     }
   }
 
-  void map(Value oldVal, SmallVector<Value> &&newVal) {
+  void map(Value oldVal, SmallVector<Value> &&newVal,
+           bool isMaterialization = false) {
-    map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
+    map(ValueVector{oldVal}, ValueVector(std::move(newVal)), isMaterialization);
   }
 
   /// Drop the last mapping for the given values.
-  void erase(const ValueVector &value) { mapping.erase(value); }
+  void erase(const ValueVector &value) {
+    mapping.erase(value);
+    materializations.erase(value);
+  }
 
 private:
-  /// Current value mappings.
+  /// Mapping of actual replacements.
   DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
 
+  /// Mapping of materializations that are created only to resolve type
+  /// mismatches.
+  DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> materializations;
+
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
 ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
-                                        TypeRange desiredTypes) const {
+ConversionValueMapping::lookupSingle(const ValueVector &from,
+                                     bool skipMaterializations) const {
+  ValueVector next;
+  for (Value v : from) {
+    auto it = mapping.find({v});
+    if (it != mapping.end()) {
+      llvm::append_range(next, it->second);
+      continue;
+    }
+    if (skipMaterializations) {
+      next.push_back(v);
+      continue;
+    }
+
+    it = materializations.find({v});
+    if (it != materializations.end()) {
+      llvm::append_range(next, it->second);
+      continue;
+    }
+
+    next.push_back(v);
+  }
+
+  if (next != from)
+    return next;
+
+  // Otherwise: Check if there is a mapping for the entire vector. Such
+  // mappings are materializations. (N:M mappings are not supported for value
+  // replacements.)
+  //
+  // Note: From a correctness point of view, materializations do not have to
+  // be stored (and looked up) in the mapping. But for performance reasons,
+  // we choose to reuse existing IR (when possible) instead of creating it
+  // multiple times.
+  auto it = mapping.find(from);
+  if (it != mapping.end())
+    return it->second;
+
+  if (skipMaterializations)
+    return {};
+
+  it = materializations.find(from);
+  if (it != materializations.end())
+    return it->second;
+
+  return {};
+}
+
+ValueVector
+ConversionValueMapping::lookupOrDefault(Value from, TypeRange desiredTypes,
+                                        bool skipMaterializations) const {
   // Try to find the deepest values that have the desired types. If there is no
   // such mapping, simply return the deepest values.
   ValueVector desiredValue;
@@ -196,36 +268,13 @@ ConversionValueMapping::lookupOrDefault(Value from,
     if (TypeRange(ValueRange(current)) == desiredTypes)
       desiredValue = current;
 
-    // If possible, Replace each value with (one or multiple) mapped values.
-    ValueVector next;
-    for (Value v : current) {
-      auto it = mapping.find({v});
-      if (it != mapping.end()) {
-        llvm::append_range(next, it->second);
-      } else {
-        next.push_back(v);
-      }
-    }
-    if (next != current) {
-      // If at least one value was replaced, continue the lookup from there.
-      current = std::move(next);
-      continue;
-    }
-
-    // Otherwise: Check if there is a mapping for the entire vector. Such
-    // mappings are materializations. (N:M mapping are not supported for value
-    // replacements.)
-    //
-    // Note: From a correctness point of view, materializations do not have to
-    // be stored (and looked up) in the mapping. But for performance reasons,
-    // we choose to reuse existing IR (when possible) instead of creating it
-    // multiple times.
-    auto it = mapping.find(current);
-    if (it == mapping.end()) {
+    ValueVector next = lookupSingle(current, skipMaterializations);
+    if (next.empty()) {
       // No mapping found: The lookup stops here.
       break;
     }
-    current = it->second;
+
+    current = std::move(next);
   } while (true);
 
   // If the desired values were found use them, otherwise default to the leaf
@@ -929,7 +978,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  ///
+  /// If `skipMaterializations` is true, materializations are not considered.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+                              bool skipMaterializations = false) const;
 
   /// Lookup the given value within the map, or return an empty vector if the
   /// value is not mapped. If it is mapped, this follows the same behavior
@@ -992,11 +1044,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
   /// the results of the unresolved materialization in the conversion value
   /// mapping.
+  ///
+  /// If `isMaterialization` is true, the materialization is created to resolve
+  /// a type mismatch. (Replacement values that are created "out of thin air"
+  /// are treated like unresolved materializations, but `isMaterialization` is
+  /// set to "false" in that case.)
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
       ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
-      UnrealizedConversionCastOp *castOp = nullptr);
+      UnrealizedConversionCastOp *castOp = nullptr,
+      bool isMaterialization = false);
 
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
@@ -1258,10 +1316,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 //===----------------------------------------------------------------------===//
 
-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
-                                               TypeRange desiredTypes) const {
-  return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+    Value from, TypeRange desiredTypes, bool skipMaterializations) const {
+  return mapping.lookupOrDefault(from, desiredTypes, skipMaterializations);
 }
 
 ValueVector
@@ -1318,10 +1375,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
     if (!currentTypeConverter) {
-      // The current pattern does not have a type converter. I.e., it does not
-      // distinguish between legal and illegal types. For each operand, simply
-      // pass through the most recently mapped values.
-      remapped.push_back(lookupOrDefault(operand));
+      // The current pattern does not have a type converter. Pass the most
+      // recently mapped values, excluding materializations. Materializations
+      // are intentionally excluded because their presence may depend on other
+      // patterns. Including materializations would make the lookup fragile
+      // and unpredictable.
+      remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+                                         /*skipMaterializations=*/true));
       continue;
     }
 
@@ -1350,11 +1410,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     }
 
     // Create a materialization for the most recently mapped values.
-    repl = lookupOrDefault(operand);
+    repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+                           /*skipMaterializations=*/true);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
         /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
-        /*originalType=*/origType, currentTypeConverter);
+        /*originalType=*/origType, currentTypeConverter, nullptr, true);
     remapped.push_back(castValues);
   }
   return success();
@@ -1517,7 +1578,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
     ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
-    UnrealizedConversionCastOp *castOp) {
+    UnrealizedConversionCastOp *castOp, bool isMaterialization) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   assert(TypeRange(inputs) != outputTypes &&
@@ -1530,7 +1591,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   auto convertOp =
       UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
   if (!valuesToMap.empty())
-    mapping.map(valuesToMap, convertOp.getResults());
+    mapping.map(valuesToMap, convertOp.getResults(), isMaterialization);
   if (castOp)
     *castOp = convertOp;
   unresolvedMaterializations[convertOp] =
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 68c863cff69bf..6b4c2c78cd556 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -557,3 +557,20 @@ func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_lookup_without_converter
+//       CHECK:   %[[producer:.*]] = "test.valid_producer"() : () -> i16
+//       CHECK:   %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
+//       CHECK:   "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+//       CHECK:   "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+  %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+  "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+  // Make sure that the second "replace_with_valid_consumer" lowering does not
+  // lookup the materialization that was created for the above op.
+  "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+  // expected-remark @+1 {{op 'func.return' is not legalizable}}
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..843bd30a51aff 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 def TestTypeProducerOp : TEST_Op<"type_producer">,
   Results<(outs AnyType)>;
+def TestValidProducerOp : TEST_Op<"valid_producer">,
+  Results<(outs AnyType)>;
+def TestValidConsumerOp : TEST_Op<"valid_consumer">,
+  Arguments<(ins AnyType)>;
 def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
   Results<(outs AnyType)>;
 def TestTypeConsumerOp : TEST_Op<"type_consumer">,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..7150401bdbdce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern {
   }
 };
 
+/// Pattern that replaces test.replace_with_valid_producer with
+/// test.valid_producer and the specified type.
+class TestReplaceWithValidProducer : public ConversionPattern {
+public:
+  TestReplaceWithValidProducer(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto attr = op->getAttrOfType<TypeAttr>("type");
+    if (!attr)
+      return failure();
+    rewriter.replaceOpWithNewOp<TestValidProducerOp>(op, attr.getValue());
+    return success();
+  }
+};
+
+/// Pattern that replaces test.replace_with_valid_consumer with
+/// test.valid_consumer. Can be used with and without a type converter.
+class TestReplaceWithValidConsumer : public ConversionPattern {
+public:
+  TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.replace_with_valid_consumer", 1,
+                          ctx) {}
+  TestReplaceWithValidConsumer(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // with_converter present: pattern must have been initialized with a type
+    // converter.
+    // with_converter absent: pattern must have been initialized without a type
+    // converter.
+    if (op->hasAttr("with_converter") != static_cast<bool>(getTypeConverter()))
+      return failure();
+    rewriter.replaceOpWithNewOp<TestValidConsumerOp>(op, operands[0]);
+    return success();
+  }
+};
+
 /// This pattern matches a test.convert_block_args op. It either:
 /// a) Duplicates all block arguments,
 /// b) or: drops all block arguments and replaces each with 2x the first
@@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter {
   TestTypeConverter() {
     addConversion(convertType);
     addSourceMaterialization(materializeCast);
+    addTargetMaterialization(materializeCast);
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver
         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
         TestUndoPropertiesModification, TestEraseOp,
+        TestReplaceWithValidProducer, TestReplaceWithValidConsumer,
         TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
-                 TestBlockArgReplace>(&getContext(), converter);
+                 TestBlockArgReplace, TestReplaceWithValidConsumer>(
+        &getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp,...
[truncated]

``````````

</details>
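The two-map design introduced in `ConversionValueMapping` can be approximated by a self-contained toy. In this sketch, strings stand in for `mlir::Value` and `std::map` stands in for `DenseMap`. `ToyValueMapping` and `checkLookupWithoutConverter` are hypothetical names. The sketch omits the `desiredTypes` handling of `lookupOrDefault` and the `mappedTo` bookkeeping. It replays the `@test_lookup_without_converter` scenario:

- `%0` is replaced by the producer.
- A pattern with a converter records a target materialization (`%cast`).
- A lookup with `skipMaterializations` set still returns the producer.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy: strings stand in for mlir::Value, std::map for DenseMap.
using ValueVector = std::vector<std::string>;
using Map = std::map<ValueVector, ValueVector>;

struct ToyValueMapping {
  Map mapping;          // actual replacements
  Map materializations; // created only to resolve type mismatches

  void map(const ValueVector &from, const ValueVector &to,
           bool isMaterialization = false) {
    (isMaterialization ? materializations : mapping)[from] = to;
  }

  // One lookup step: remap values one by one, then try the whole vector.
  // Returns an empty vector if nothing is mapped.
  ValueVector lookupSingle(const ValueVector &from,
                           bool skipMaterializations) const {
    ValueVector next;
    for (const std::string &v : from) {
      if (auto it = mapping.find({v}); it != mapping.end()) {
        next.insert(next.end(), it->second.begin(), it->second.end());
        continue;
      }
      if (!skipMaterializations) {
        if (auto it = materializations.find({v});
            it != materializations.end()) {
          next.insert(next.end(), it->second.begin(), it->second.end());
          continue;
        }
      }
      next.push_back(v);
    }
    if (next != from)
      return next;
    if (auto it = mapping.find(from); it != mapping.end())
      return it->second;
    if (!skipMaterializations) {
      if (auto it = materializations.find(from); it != materializations.end())
        return it->second;
    }
    return {};
  }

  // Follow lookupSingle until no further mapping exists.
  ValueVector lookupOrDefault(const std::string &from,
                              bool skipMaterializations = false) const {
    ValueVector current{from};
    while (true) {
      ValueVector next = lookupSingle(current, skipMaterializations);
      if (next.empty())
        return current;
      current = std::move(next);
    }
  }
};

// Replays the test: %0 -> %producer (replacement), %producer -> %cast
// (materialization added by the pattern that has a type converter).
void checkLookupWithoutConverter() {
  ToyValueMapping m;
  m.map({"%0"}, {"%producer"});
  m.map({"%producer"}, {"%cast"}, /*isMaterialization=*/true);
  assert(m.lookupOrDefault("%0") == ValueVector{"%cast"});
  assert(m.lookupOrDefault("%0", /*skipMaterializations=*/true) ==
         ValueVector{"%producer"});
}
```

Storing materializations in a separate map leaves the replacement chain intact. This allows a caller that must not depend on other patterns to skip materializations, while a caller with a type converter can still reuse existing materializations.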


https://github.com/llvm/llvm-project/pull/151747
