[Mlir-commits] [mlir] [mlir][Transforms] Make lookup without type converter unambiguous (PR #151747)

Matthias Springer llvmlistbot at llvm.org
Sat Aug 2 06:37:56 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/151747

>From d087c2f8fb250f94654447792a9d8c91da8063ed Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 1 Aug 2025 17:15:23 +0000
Subject: [PATCH] [mlir][Transforms] Make lookup without type converter
 unambiguous

---
 mlir/docs/DialectConversion.md                |  34 ++--
 .../Transforms/Utils/DialectConversion.cpp    | 185 ++++++++++++------
 mlir/test/Transforms/test-legalizer.mlir      |  17 ++
 mlir/test/lib/Dialect/Test/TestOps.td         |   4 +
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  49 ++++-
 5 files changed, 220 insertions(+), 69 deletions(-)

diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..89c8c78749957 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -202,17 +202,29 @@ struct MyConversionPattern : public ConversionPattern {
 
 #### Type Safety
 
-The types of the remapped operands provided to a conversion pattern must be of a
-type expected by the pattern. The expected types of a pattern are determined by
-a provided [TypeConverter](#type-converter). If no type converter is provided,
-the types of the remapped operands are expected to match the types of the
-original operands. If a type converter is provided, the types of the remapped
-operands are expected to be legal as determined by the converter. If the
-remapped operand types are not of an expected type, and a materialization to the
-expected type could not be performed, the pattern fails application before the
-`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
-explicitly ensure type safety, or sanitize the types of the incoming remapped
-operands. More information on type conversion is detailed in the
+The types of the remapped operands provided to a conversion pattern (through
+the adaptor or `ArrayRef` of operands) depend on type conversion rules.
+
+If the pattern was initialized with a [type converter](#type-converter), the
+conversion driver passes values whose types match the legalized types of the
+operands of the matched operation as per the type converter. To that end, the
+conversion driver may insert target materializations to convert the most
+recently mapped values to the expected legalized types. The driver tries to
+reuse existing materializations on a best-effort basis, but this is not
+guaranteed by the infrastructure. If the operand types of the matched op could
+not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
+is invoked.
+
+If the pattern was initialized without a type converter, the conversion driver
+passes the most recently mapped values to the pattern, excluding any
+materializations. Materializations are intentionally excluded because their
+presence may depend on other patterns. If a value of the same type as an
+operand is desired, users can directly take the respective operand from the
+matched operation.
+
+The above rules ensure that patterns do not have to explicitly ensure type
+safety, or sanitize the types of the incoming remapped operands. More
+information on type conversion is detailed in the
 [dedicated section](#type-conversion) below.
 
 ## Type Conversion
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..7c13c98b7b0d0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -131,7 +131,16 @@ struct ConversionValueMapping {
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  ///
+  /// If `skipMaterializations` is true, materializations are not considered.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+                              bool skipMaterializations = false) const;
+
+  /// Lookup a value from the mapping. (Just once, not following the chain of
+  /// potential mappings.) Look for actual replacements first, then for
+  /// materializations. The materializations lookup can be skipped.
+  ValueVector lookupSingleStep(const ValueVector &from,
+                               bool skipMaterializations = false) const;
 
   template <typename T>
   struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -139,8 +148,8 @@ struct ConversionValueMapping {
   /// Map a value vector to the one provided.
   template <typename OldVal, typename NewVal>
   std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
-  map(OldVal &&oldVal, NewVal &&newVal) {
-    LLVM_DEBUG({
+  map(OldVal &&oldVal, NewVal &&newVal, bool isOnlyTypeConversion = false) {
+    auto checkCircularMapping = [&](auto &mapping) {
       ValueVector next(newVal);
       while (true) {
         assert(next != oldVal && "inserting cyclic mapping");
@@ -149,45 +158,117 @@ struct ConversionValueMapping {
           break;
         next = it->second;
       }
-    });
+    };
+    (void)checkCircularMapping;
+
     mappedTo.insert_range(newVal);
 
-    mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+    if (isOnlyTypeConversion) {
+      // This is a materialization.
+      LLVM_DEBUG({ checkCircularMapping(materializations); });
+      materializations[std::forward<OldVal>(oldVal)] =
+          std::forward<NewVal>(newVal);
+    } else {
+      // This is a regular value replacement.
+      LLVM_DEBUG({ checkCircularMapping(mapping); });
+      mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+    }
   }
 
   /// Map a value vector or single value to the one provided.
   template <typename OldVal, typename NewVal>
   std::enable_if_t<!IsValueVector<OldVal>::value ||
                    !IsValueVector<NewVal>::value>
-  map(OldVal &&oldVal, NewVal &&newVal) {
+  map(OldVal &&oldVal, NewVal &&newVal, bool isOnlyTypeConversion = false) {
     if constexpr (IsValueVector<OldVal>{}) {
-      map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+      map(std::forward<OldVal>(oldVal), ValueVector{newVal},
+          isOnlyTypeConversion);
     } else if constexpr (IsValueVector<NewVal>{}) {
-      map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+      map(ValueVector{oldVal}, std::forward<NewVal>(newVal),
+          isOnlyTypeConversion);
     } else {
-      map(ValueVector{oldVal}, ValueVector{newVal});
+      map(ValueVector{oldVal}, ValueVector{newVal}, isOnlyTypeConversion);
     }
   }
 
-  void map(Value oldVal, SmallVector<Value> &&newVal) {
+  void map(Value oldVal, SmallVector<Value> &&newVal,
+           bool isOnlyTypeConversion = false) {
     map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
   }
 
   /// Drop the last mapping for the given values.
-  void erase(const ValueVector &value) { mapping.erase(value); }
+  void erase(const ValueVector &value) {
+    mapping.erase(value);
+    materializations.erase(value);
+  }
 
 private:
-  /// Current value mappings.
+  /// Mapping of actual replacements.
   DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
 
+  /// Mapping of materializations that are created only to resolve type
+  /// mismatches.
+  DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> materializations;
+
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
 ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
-                                        TypeRange desiredTypes) const {
+ConversionValueMapping::lookupSingleStep(const ValueVector &from,
+                                         bool skipMaterializations) const {
+  // Continue the lookup on each value separately. (Each value could have been
+  // mapped to one or multiple other values.)
+  ValueVector next;
+  for (Value v : from) {
+    // First check regular value replacements.
+    auto it = mapping.find({v});
+    if (it != mapping.end()) {
+      llvm::append_range(next, it->second);
+      continue;
+    }
+    if (skipMaterializations) {
+      next.push_back(v);
+      continue;
+    }
+    // Then check materializations.
+    it = materializations.find({v});
+    if (it != materializations.end()) {
+      llvm::append_range(next, it->second);
+      continue;
+    }
+    next.push_back(v);
+  }
+
+  if (next != from)
+    return next;
+
+  // Otherwise: Check if there is a mapping for the entire vector. Such
+  // mappings are materializations. (N:M mappings are not supported for value
+  // replacements.)
+  //
+  // Note: From a correctness point of view, materializations do not have to
+  // be stored (and looked up) in the mapping. But for performance reasons,
+  // we choose to reuse existing IR (when possible) instead of creating it
+  // multiple times.
+  //
+  // First check regular value replacements.
+  auto it = mapping.find(from);
+  if (it != mapping.end())
+    return it->second;
+  if (skipMaterializations)
+    return {};
+  // Then check materializations.
+  it = materializations.find(from);
+  if (it != materializations.end())
+    return it->second;
+  return {};
+}
+
+ValueVector
+ConversionValueMapping::lookupOrDefault(Value from, TypeRange desiredTypes,
+                                        bool skipMaterializations) const {
   // Try to find the deepest values that have the desired types. If there is no
   // such mapping, simply return the deepest values.
   ValueVector desiredValue;
@@ -197,36 +278,13 @@ ConversionValueMapping::lookupOrDefault(Value from,
     if (TypeRange(ValueRange(current)) == desiredTypes)
       desiredValue = current;
 
-    // If possible, Replace each value with (one or multiple) mapped values.
-    ValueVector next;
-    for (Value v : current) {
-      auto it = mapping.find({v});
-      if (it != mapping.end()) {
-        llvm::append_range(next, it->second);
-      } else {
-        next.push_back(v);
-      }
-    }
-    if (next != current) {
-      // If at least one value was replaced, continue the lookup from there.
-      current = std::move(next);
-      continue;
-    }
-
-    // Otherwise: Check if there is a mapping for the entire vector. Such
-    // mappings are materializations. (N:M mapping are not supported for value
-    // replacements.)
-    //
-    // Note: From a correctness point of view, materializations do not have to
-    // be stored (and looked up) in the mapping. But for performance reasons,
-    // we choose to reuse existing IR (when possible) instead of creating it
-    // multiple times.
-    auto it = mapping.find(current);
-    if (it == mapping.end()) {
+    ValueVector next = lookupSingleStep(current, skipMaterializations);
+    if (next.empty()) {
       // No mapping found: The lookup stops here.
       break;
     }
-    current = it->second;
+
+    current = std::move(next);
   } while (true);
 
   // If the desired values were found use them, otherwise default to the leaf
@@ -930,7 +988,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+  ///
+  /// If `skipMaterializations` is true, materializations are not considered.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+                              bool skipMaterializations = false) const;
 
   /// Lookup the given value within the map, or return an empty vector if the
   /// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +1054,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// If `valuesToMap` is set to a non-null Value, then that value is mapped to
   /// the results of the unresolved materialization in the conversion value
   /// mapping.
+  ///
+  /// If `isOnlyTypeConversion` is "true", the materialization is created to
+  /// resolve a type mismatch, and not a regular value replacement issued by
+  /// the user. (Replacement values that are created "out of thin air"
+  /// appear like unresolved materializations, but are not just type
+  /// conversions.)
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
       ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
-      UnrealizedConversionCastOp *castOp = nullptr);
+      UnrealizedConversionCastOp *castOp = nullptr,
+      bool isOnlyTypeConversion = true);
 
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1332,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 //===----------------------------------------------------------------------===//
 
-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
-                                               TypeRange desiredTypes) const {
-  return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+    Value from, TypeRange desiredTypes, bool skipMaterializations) const {
+  return mapping.lookupOrDefault(from, desiredTypes, skipMaterializations);
 }
 
 ValueVector
@@ -1324,10 +1391,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
     if (!currentTypeConverter) {
-      // The current pattern does not have a type converter. I.e., it does not
-      // distinguish between legal and illegal types. For each operand, simply
-      // pass through the most recently mapped values.
-      remapped.push_back(lookupOrDefault(operand));
+      // The current pattern does not have a type converter. Pass the most
+      // recently mapped values, excluding materializations. Materializations
+      // are intentionally excluded because their presence may depend on other
+      // patterns. Including materializations would make the lookup fragile
+      // and unpredictable.
+      remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+                                         /*skipMaterializations=*/true));
       continue;
     }
 
@@ -1356,7 +1426,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     }
 
     // Create a materialization for the most recently mapped values.
-    repl = lookupOrDefault(operand);
+    repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+                           /*skipMaterializations=*/true);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
         /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1553,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               OpBuilder::InsertPoint(newBlock, newBlock->begin()),
               origArg.getLoc(),
               /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
-              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
+              /*castOp=*/nullptr, /*isOnlyTypeConversion=*/false)
               .front();
       replaceUsesOfBlockArgument(origArg, mat, converter);
       continue;
@@ -1523,7 +1595,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
     ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
-    UnrealizedConversionCastOp *castOp) {
+    UnrealizedConversionCastOp *castOp, bool isOnlyTypeConversion) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   assert(TypeRange(inputs) != outputTypes &&
@@ -1536,7 +1608,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   auto convertOp =
       UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
   if (!valuesToMap.empty())
-    mapping.map(valuesToMap, convertOp.getResults());
+    mapping.map(valuesToMap, convertOp.getResults(), isOnlyTypeConversion);
   if (castOp)
     *castOp = convertOp;
   unresolvedMaterializations[convertOp] =
@@ -1650,7 +1722,8 @@ void ConversionPatternRewriterImpl::replaceOp(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
           /*outputTypes=*/result.getType(), /*originalType=*/Type(),
-          currentTypeConverter);
+          currentTypeConverter, /*castOp=*/nullptr,
+          /*isOnlyTypeConversion=*/false);
       continue;
     }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e4406e60ffead..5630d1540e4d5 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_lookup_without_converter
+//       CHECK:   %[[producer:.*]] = "test.valid_producer"() : () -> i16
+//       CHECK:   %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
+//       CHECK:   "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+//       CHECK:   "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+  %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+  "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+  // Make sure that the second "replace_with_valid_consumer" lowering does not
+  // look up the materialization that was created for the above op.
+  "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+  // expected-remark at +1 {{op 'func.return' is not legalizable}}
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..843bd30a51aff 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 def TestTypeProducerOp : TEST_Op<"type_producer">,
   Results<(outs AnyType)>;
+def TestValidProducerOp : TEST_Op<"valid_producer">,
+  Results<(outs AnyType)>;
+def TestValidConsumerOp : TEST_Op<"valid_consumer">,
+  Arguments<(ins AnyType)>;
 def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
   Results<(outs AnyType)>;
 def TestTypeConsumerOp : TEST_Op<"type_consumer">,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..7150401bdbdce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern {
   }
 };
 
+/// Pattern that replaces test.replace_with_valid_producer with
+/// test.valid_producer and the specified type.
+class TestReplaceWithValidProducer : public ConversionPattern {
+public:
+  TestReplaceWithValidProducer(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto attr = op->getAttrOfType<TypeAttr>("type");
+    if (!attr)
+      return failure();
+    rewriter.replaceOpWithNewOp<TestValidProducerOp>(op, attr.getValue());
+    return success();
+  }
+};
+
+/// Pattern that replaces test.replace_with_valid_consumer with
+/// test.valid_consumer. Can be used with and without a type converter.
+class TestReplaceWithValidConsumer : public ConversionPattern {
+public:
+  TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.replace_with_valid_consumer", 1,
+                          ctx) {}
+  TestReplaceWithValidConsumer(MLIRContext *ctx)
+      : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // with_converter present: pattern must have been initialized with a type
+    // converter.
+    // with_converter absent: pattern must have been initialized without a type
+    // converter.
+    if (op->hasAttr("with_converter") != static_cast<bool>(getTypeConverter()))
+      return failure();
+    rewriter.replaceOpWithNewOp<TestValidConsumerOp>(op, operands[0]);
+    return success();
+  }
+};
+
 /// This pattern matches a test.convert_block_args op. It either:
 /// a) Duplicates all block arguments,
 /// b) or: drops all block arguments and replaces each with 2x the first
@@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter {
   TestTypeConverter() {
     addConversion(convertType);
     addSourceMaterialization(materializeCast);
+    addTargetMaterialization(materializeCast);
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver
         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
         TestUndoPropertiesModification, TestEraseOp,
+        TestReplaceWithValidProducer, TestReplaceWithValidConsumer,
         TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
-                 TestBlockArgReplace>(&getContext(), converter);
+                 TestBlockArgReplace, TestReplaceWithValidConsumer>(
+        &getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp, OneRegionOp>();
+                      TerminatorOp, OneRegionOp, TestValidProducerOp,
+                      TestValidConsumerOp>();
     target.addLegalOp(OperationName("test.legal_op", &getContext()));
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();



More information about the Mlir-commits mailing list