[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: add missing argument materialization. (PR #121200)
llvmlistbot at llvm.org
Fri Dec 27 02:36:12 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Benjamin Chetioui (bchetioui)
<details>
<summary>Changes</summary>
When replacing a block argument, prior to #117513, we would automatically insert an N->1 argument materialization. After #117513, this is no longer the case for 1->1 mappings.
As a result, no materialization is added until `ReplaceBlockArgRewrite` is committed, at which point `findOrBuildReplacementValue` inserts a source materialization. This switch from an argument materialization to a source materialization causes legalization to fail.
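The behavioural change can be pictured with a toy model. This sketch is not MLIR code: `MaterializationKind`, `MappedValue`, `mapBlockArgument` and `kindSeenAtCommit` are hypothetical stand-ins, strings stand in for block arguments, and type conversion is ignored entirely. It only tracks which materialization kind ends up in place for a replaced block argument, following the description above.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins; not MLIR's real classes or API.
enum class MaterializationKind { Argument, Source };

struct MappedValue {
  // Set if an (unresolved) materialization was built at mapping time.
  std::optional<MaterializationKind> materialization;
  // The new block arguments that replace the original one.
  std::vector<std::string> inputs;
};

// Toy model of the per-argument branch in applySignatureConversion.
// `withThisPatch` selects the behaviour after this PR.
MappedValue mapBlockArgument(const std::vector<std::string> &replArgs,
                             bool withThisPatch) {
  if (replArgs.size() == 1 && !withThisPatch) {
    // After #117513, before this PR: direct 1->1 mapping, no materialization.
    return {std::nullopt, replArgs};
  }
  // N->1 mappings always, and 1->1 mappings with this PR:
  // an argument materialization is built up front.
  return {MaterializationKind::Argument, replArgs};
}

// Toy model of commit time (ReplaceBlockArgRewrite): if nothing was
// materialized yet, a source materialization is built instead.
MaterializationKind kindSeenAtCommit(const MappedValue &v) {
  return v.materialization.value_or(MaterializationKind::Source);
}
```

In this model, a single replacement argument yields `MaterializationKind::Source` without the patch and `MaterializationKind::Argument` with it, while two replacement arguments yield `MaterializationKind::Argument` either way, mirroring the two branches of the diff.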
[Here is an example reproducer](https://github.com/openxla/xla/blob/eb9d08bae564680ff465d772ceb70f4d84542e8c/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir#L3016-L3031).
---
Full diff: https://github.com/llvm/llvm-project/pull/121200.diff
1 File Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+9-3)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 255b0ba2559ee6..5229c0f8d7f2ce 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1375,12 +1375,18 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
+ auto insertPoint = OpBuilder::InsertPoint(newBlock, newBlock->begin());
if (replArgs.size() == 1) {
- mapping.map(origArg, replArgs.front());
+ // We need an argument materialization to replace the block argument.
+ Value argMat = buildUnresolvedMaterialization(
+ MaterializationKind::Argument, insertPoint, origArg.getLoc(),
+ /*valueToMap=*/origArg, /*inputs=*/replArgs,
+ /*outputType=*/origArg.getType(), /*originalType=*/Type(), converter);
+ mapping.map(origArg, argMat);
} else {
insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ insertPoint, origArg.getLoc(), /*replacements=*/replArgs,
+ /*originalValue=*/origArg, converter);
}
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/121200