[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: add missing argument materialization. (PR #121200)

Benjamin Chetioui llvmlistbot at llvm.org
Fri Dec 27 02:35:36 PST 2024


https://github.com/bchetioui created https://github.com/llvm/llvm-project/pull/121200

When replacing a block argument, prior to #117513 we would automatically insert an N->1 argument materialization. After #117513, this is no longer the case for 1->1 mappings.

As a result, no materialization is added until `ReplaceBlockArgRewrite` is committed, at which point `findOrBuildReplacementValue` inserts a source materialization. The switch from an argument materialization to a source materialization causes legalization to fail.
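To make the difference concrete, here is a toy C++ model of which materialization kind ends up replacing an original block argument, with and without this patch. It is a sketch under stated assumptions, not MLIR code. All names (`ToyConverter`, `replaceBlockArg`, `withFix`) are hypothetical and do not correspond to the `DialectConversion` API. The model assumes the converted type differs from the original type, so some cast is always required. It also assumes a converter that can only legalize argument materializations, which stands in for whatever made the linked reproducer fail.

```cpp
// Toy model (hypothetical names, NOT the MLIR API) of how an original block
// argument is replaced during a block signature conversion.
#include <cassert>
#include <string>
#include <vector>

enum class MaterializationKind { Argument, Source };

struct Materialization {
  MaterializationKind kind;
  std::vector<std::string> inputs;  // replacement block argument names
  std::string outputType;           // type of the original block argument
};

// Assumption: a converter that can only legalize argument materializations.
struct ToyConverter {
  bool legalizes(const Materialization &m) const {
    return m.kind == MaterializationKind::Argument;
  }
};

// Picks the materialization that replaces one original block argument.
// `withFix` toggles between the behavior after #117513 (without this patch)
// and the behavior with this patch applied. Assumes the types differ.
Materialization replaceBlockArg(const std::vector<std::string> &replArgs,
                                const std::string &origType, bool withFix) {
  assert(!replArgs.empty() && "expected at least one replacement argument");
  if (replArgs.size() == 1 && !withFix) {
    // 1->1 without the fix: the argument is mapped directly, and a *source*
    // materialization is only built when ReplaceBlockArgRewrite commits.
    return {MaterializationKind::Source, replArgs, origType};
  }
  // N->1 (always) and 1->1 (with the fix): argument materialization up front.
  return {MaterializationKind::Argument, replArgs, origType};
}
```

For example, `replaceBlockArg({"%a"}, "i32", /*withFix=*/false)` yields a `Source` materialization that the toy converter rejects. Passing `withFix=true` yields an `Argument` materialization that it accepts, which matches the N->1 path.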

[Here is an example reproducer](https://github.com/openxla/xla/blob/eb9d08bae564680ff465d772ceb70f4d84542e8c/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir#L3016-L3031).

From bdeb48e6dd7aa9c01c0ec388cb8d53777a221c96 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui at google.com>
Date: Fri, 27 Dec 2024 10:27:58 +0000
Subject: [PATCH] [mlir][Transforms] Dialect conversion: add missing argument
 materialization.

When replacing a block argument, prior to #117513 we would
automatically insert an N->1 argument materialization. After #117513,
this is no longer the case for 1->1 mappings.

As a result, no materialization is added until `ReplaceBlockArgRewrite`
is committed, at which point `findOrBuildReplacementValue` inserts a
source materialization. The switch from an argument materialization to
a source materialization causes legalization to fail.
---
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 255b0ba2559ee6..5229c0f8d7f2ce 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1375,12 +1375,18 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
+    auto insertPoint = OpBuilder::InsertPoint(newBlock, newBlock->begin());
     if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
+      // We need an argument materialization to replace the block argument.
+      Value argMat = buildUnresolvedMaterialization(
+          MaterializationKind::Argument, insertPoint, origArg.getLoc(),
+          /*valueToMap=*/origArg, /*inputs=*/replArgs,
+          /*outputType=*/origArg.getType(), /*originalType=*/Type(), converter);
+      mapping.map(origArg, argMat);
     } else {
       insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+          insertPoint, origArg.getLoc(), /*replacements=*/replArgs,
+          /*originalValue=*/origArg, converter);
     }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
