[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Update docs for `remapValues` (PR #110414)

llvmlistbot at llvm.org
Sun Sep 29 01:43:47 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Simplify the nesting structure of "if" checks in `remapValues` and update the code comments.
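After the change, every operand falls into one of four early-exit cases. The following is a minimal, hypothetical C++ model of that per-operand decision, not MLIR code. `RemapAction` and `classifyOperand` are invented names, types are modeled as strings, and an empty `std::optional` stands in for a failed `convertType` call.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of the per-operand decision in the flattened
// `remapValues` loop. Not MLIR API: types are modeled as strings.
enum class RemapAction {
  PassThrough, // use the most recently mapped value as-is
  Fail,        // type conversion failed: the pattern fails to match
  OneToOne,    // look up a value of the single legal type (cast if needed)
};

// `hasConverter` models `currentTypeConverter != nullptr`.
// `legalTypes` models the result of `convertType`; std::nullopt = failure.
RemapAction classifyOperand(
    bool hasConverter,
    const std::optional<std::vector<std::string>> &legalTypes) {
  assert((hasConverter || !legalTypes) &&
         "a conversion result requires a type converter");

  // Case 1: no type converter, so type legality is not considered.
  if (!hasConverter)
    return RemapAction::PassThrough;

  // Case 2: the converter cannot convert the type.
  if (!legalTypes)
    return RemapAction::Fail;

  // Case 3: 1->0 or 1->N conversion; not supported by this path yet.
  if (legalTypes->size() != 1)
    return RemapAction::PassThrough;

  // Case 4: 1->1 conversion.
  return RemapAction::OneToOne;
}
```

Compared to the old nested version, each case ends in its own `continue` or `return`, so the 1->1 path no longer needs the `currentTypeConverter && desiredType` guard.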

This is what the comment stated for the case where there is no type converter:
```
      // TODO: What we should do here is just set `desiredType` to `origType`
      // and then handle the necessary type conversions after the conversion
      // process has finished. Unfortunately a lot of patterns currently rely on
      // receiving the new operands even if the types change, so we keep the
      // original behavior here for now until all of the patterns relying on
      // this get updated.
```

However, without a type converter it is not possible to perform any materializations. Furthermore, the absence of a type converter indicates that the pattern does not care about type legality. Therefore, the current implementation is correct and this TODO can be removed.
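Without a type converter, the new code simply passes through the most recently mapped value via `mapping.lookupOrDefault(operand)`. Below is a rough model of that lookup. It assumes the mapping behaves like a chain of replacements that is followed until a value has no further mapping; `ValueChainModel` and the string value names are hypothetical, and this sketch is not the implementation of `ConversionValueMapping`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical model of a replacement chain: each entry records that a value
// was replaced by another value. Values are modeled by their names, and the
// chain is assumed to be acyclic.
class ValueChainModel {
public:
  // Record that `from` was replaced by `to`.
  void map(const std::string &from, const std::string &to) {
    assert(from != to && "a value cannot be mapped to itself");
    mapping[from] = to;
  }

  // Follow the chain to the most recently mapped value. If `value` was never
  // replaced, return it unchanged (the "OrDefault" part).
  std::string lookupOrDefault(std::string value) const {
    for (auto it = mapping.find(value); it != mapping.end();
         it = mapping.find(value))
      value = it->second;
    return value;
  }

private:
  std::unordered_map<std::string, std::string> mapping;
};
```

Because there is no converter, nothing in this path checks whether the returned value has a legal type; the pattern receives whatever value currently stands in for the operand.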

This TODO is outdated:
```
      // TODO: There currently isn't any mechanism to do 1->N type conversion
      // via the PatternRewriter replacement API, so for now we just ignore it.
```
1->N type conversions are already possible as part of block signature conversions, so simply ignoring such cases is incorrect. However, because of infrastructure limitations, there is currently no better way to handle 1->N conversions in this function. The updated comments now state this explicitly.
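To illustrate why 1->N conversions fit naturally into block signature conversions (one original argument is replaced by several new arguments), here is a hypothetical model of a signature conversion. It is loosely inspired by MLIR's `TypeConverter::SignatureConversion`, but `convertSignature`, `ArgRange`, `ConvertFn`, and the string-based types are assumptions made for illustration, not the real API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: where the replacement arguments of one original block
// argument live in the new signature.
struct ArgRange {
  std::size_t start; // index of the first new argument
  std::size_t size;  // number of new arguments (0, 1, or N)
};

struct ConvertedSignature {
  std::vector<std::string> newTypes; // the converted block signature
  std::vector<ArgRange> ranges;      // one entry per original argument
};

// Models a type converter: returns the legal types for one original type,
// or std::nullopt if the type cannot be converted.
using ConvertFn =
    std::function<std::optional<std::vector<std::string>>(const std::string &)>;

std::optional<ConvertedSignature>
convertSignature(const std::vector<std::string> &origTypes,
                 const ConvertFn &convert) {
  assert(convert && "a type converter is required");
  ConvertedSignature result;
  for (const std::string &type : origTypes) {
    std::optional<std::vector<std::string>> legal = convert(type);
    if (!legal)
      return std::nullopt; // one unconvertible argument fails the signature
    // A 1->N conversion simply occupies N consecutive argument slots.
    result.ranges.push_back({result.newTypes.size(), legal->size()});
    result.newTypes.insert(result.newTypes.end(), legal->begin(), legal->end());
  }
  return result;
}
```

An operand slot in a `ConversionPattern` adaptor, by contrast, holds a single value, which is why the updated TODO says the adaptor handling would need to change before `remapValues` could forward N values for one operand.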


---
Full diff: https://github.com/llvm/llvm-project/pull/110414.diff


1 File Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+38-32) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4693edadfb5eec..b5aab2416c3eb9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1092,44 +1092,50 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     SmallVectorImpl<Value> &remapped) {
   remapped.reserve(llvm::size(values));
 
-  SmallVector<Type, 1> legalTypes;
   for (const auto &it : llvm::enumerate(values)) {
     Value operand = it.value();
     Type origType = operand.getType();
+    Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
-    // If a converter was provided, get the desired legal types for this
-    // operand.
-    Type desiredType;
-    if (currentTypeConverter) {
-      // If there is no legal conversion, fail to match this pattern.
-      legalTypes.clear();
-      if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
-        Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-        notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
-          diag << "unable to convert type for " << valueDiagTag << " #"
-               << it.index() << ", type was " << origType;
-        });
-        return failure();
-      }
-      // TODO: There currently isn't any mechanism to do 1->N type conversion
-      // via the PatternRewriter replacement API, so for now we just ignore it.
-      if (legalTypes.size() == 1)
-        desiredType = legalTypes.front();
-    } else {
-      // TODO: What we should do here is just set `desiredType` to `origType`
-      // and then handle the necessary type conversions after the conversion
-      // process has finished. Unfortunately a lot of patterns currently rely on
-      // receiving the new operands even if the types change, so we keep the
-      // original behavior here for now until all of the patterns relying on
-      // this get updated.
+    if (!currentTypeConverter) {
+      // The current pattern does not have a type converter. I.e., it does not
+      // distinguish between legal and illegal types. For each operand, simply
+      // pass through the most recently mapped value.
+      remapped.push_back(mapping.lookupOrDefault(operand));
+      continue;
+    }
+
+    // If there is no legal conversion, fail to match this pattern.
+    SmallVector<Type, 1> legalTypes;
+    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+      notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+        diag << "unable to convert type for " << valueDiagTag << " #"
+             << it.index() << ", type was " << origType;
+      });
+      return failure();
     }
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
 
-    // Handle the case where the conversion was 1->1 and the new operand type
-    // isn't legal.
-    Type newOperandType = newOperand.getType();
-    if (currentTypeConverter && desiredType && newOperandType != desiredType) {
-      Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
+    if (legalTypes.size() != 1) {
+      // TODO: Parts of the dialect conversion infrastructure do not support
+      // 1->N type conversions yet. Therefore, if a type is converted to 0 or
+      // multiple types, the only thing that we can do for now is passing
+      // through the most recently mapped value. Fixing this requires
+      // improvements to the `ConversionValueMapping` (to be able to store 1:N
+      // mappings) and to the `ConversionPattern` adaptor handling (to be able
+      // to pass multiple remapped values for a single operand to the adaptor).
+      remapped.push_back(mapping.lookupOrDefault(operand));
+      continue;
+    }
+
+    // Handle 1->1 type conversions.
+    Type desiredType = legalTypes.front();
+    // Try to find a mapped value with the desired type. (Or the operand itself
+    // if the value is not mapped at all.)
+    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
+    if (newOperand.getType() != desiredType) {
+      // If the looked up value's type does not have the desired type, it means
+      // that the value was replaced with a value of different type and no
+      // source materialization was created yet.
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
           operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,

``````````

</details>
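For the 1->1 case, the new code asks the mapping for a value of the desired type and builds an unresolved target materialization only when the lookup returns a value of a different type. A rough, hypothetical model of that lookup is sketched below. It assumes, as a reading of `ConversionValueMapping::lookupOrDefault(Value, Type)` rather than a documented guarantee, that the lookup walks the replacement chain, prefers the deepest value whose type matches, and otherwise falls back to the most recently mapped value. `ModelValue`, `TypedChainModel`, and `needsTargetMaterialization` are illustrative names.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical model value: a name plus a type (both strings).
struct ModelValue {
  std::string name;
  std::string type;
};

class TypedChainModel {
public:
  // Record that `from` was replaced by `to`.
  void map(const ModelValue &from, const ModelValue &to) {
    assert(from.name != to.name && "a value cannot be mapped to itself");
    mapping[from.name] = to;
  }

  // Walk the (assumed acyclic) replacement chain starting at `value`. Return
  // the deepest value whose type equals `desiredType`; if none matches,
  // return the most recently mapped value.
  ModelValue lookupOrDefault(ModelValue value,
                             const std::string &desiredType) const {
    std::optional<ModelValue> match;
    while (true) {
      if (value.type == desiredType)
        match = value;
      auto it = mapping.find(value.name);
      if (it == mapping.end())
        break;
      value = it->second;
    }
    return match ? *match : value;
  }

private:
  std::unordered_map<std::string, ModelValue> mapping;
};

// Mirrors the `newOperand.getType() != desiredType` check in the diff.
bool needsTargetMaterialization(const TypedChainModel &mapping,
                                const ModelValue &operand,
                                const std::string &desiredType) {
  return mapping.lookupOrDefault(operand, desiredType).type != desiredType;
}
```

When no value in the chain has the desired type, the model returns the leaf value with the wrong type; that is the situation in which the real code calls `buildUnresolvedMaterialization` with `MaterializationKind::Target`.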


https://github.com/llvm/llvm-project/pull/110414

