[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Build unresolved materialization for replaced ops (PR #101514)
Matthias Springer
llvmlistbot at llvm.org
Thu Aug 15 01:54:53 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/101514
>From 9bd5beec743f2e60891147bff9d3db2d97d822d6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 1 Aug 2024 18:24:56 +0200
Subject: [PATCH 1/2] [mlir][Transforms] Dialect conversion: Build unresolved
materialization for replaced ops
When inserting an argument/source/target materialization, the dialect conversion framework first inserts a "dummy" `unrealized_conversion_cast` op (during the rewrite process) and then (in the "finalize" phase) replaces these cast ops with the IR generated by the type converter callback.
This is the case for all materializations, except when ops are being replaced with values that have a different type. In that case, the dialect conversion currently emits a source materialization directly. This commit changes the implementation so that a temporary `unrealized_conversion_cast` is also inserted in this case.
This commit simplifies the code base: all materializations now happen in `legalizeUnresolvedMaterializations`. This commit makes it possible to decouple source/target/argument materializations from the dialect conversion (to reduce the complexity of the code base). Such materializations can then also be optional. This will be implemented in a follow-up commit.
---
.../Transforms/Utils/DialectConversion.cpp | 126 +++++++-----------
.../VectorToSPIRV/vector-to-spirv.mlir | 4 +-
.../Transforms/finalizing-bufferize.mlir | 3 +-
.../test-legalize-type-conversion.mlir | 11 +-
4 files changed, 57 insertions(+), 87 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8f9b21b7ee1e5b..b0b5a8247b53f4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2348,6 +2348,12 @@ struct OperationConverter {
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);
+ /// Legalize the types of converted op results.
+ LogicalResult legalizeConvertedOpResultTypes(
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ DenseMap<Value, SmallVector<Value>> &inverseMapping);
+
/// Legalize any unresolved type materializations.
LogicalResult legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
@@ -2359,14 +2365,6 @@ struct OperationConverter {
legalizeErasedResult(Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl);
- /// Legalize an operation result that was replaced with a value of a different
- /// type.
- LogicalResult legalizeChangedResultType(
- Operation *op, OpResult result, Value newValue,
- const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- const DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
/// Dialect conversion configuration.
ConversionConfig config;
@@ -2459,10 +2457,42 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
+ if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
+ inverseMapping)))
+ return failure();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();
+ return success();
+}
+/// Finds a user of the given value, or of any other value that the given value
+/// replaced, that was not replaced in the conversion process.
+static Operation *findLiveUserOfReplaced(
+ Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+ SmallVector<Value> worklist(1, initialValue);
+ while (!worklist.empty()) {
+ Value value = worklist.pop_back_val();
+
+ // Walk the users of this value to see if there are any live users that
+ // weren't replaced during conversion.
+ auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
+ return rewriterImpl.isOpIgnored(user);
+ });
+ if (liveUserIt != value.user_end())
+ return *liveUserIt;
+ auto mapIt = inverseMapping.find(value);
+ if (mapIt != inverseMapping.end())
+ worklist.append(mapIt->second);
+ }
+ return nullptr;
+}
+
+LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// Process requested operation replacements.
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
@@ -2485,14 +2515,21 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
if (result.getType() == newValue.getType())
continue;
+ Operation *liveUser =
+ findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
+ if (!liveUser)
+ continue;
+
// Legalize this result.
- rewriter.setInsertionPoint(op);
- if (failed(legalizeChangedResultType(
- op, result, newValue, opReplacement->getConverter(), rewriter,
- rewriterImpl, inverseMapping)))
- return failure();
+ Value castValue = rewriterImpl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
+ /*inputs=*/newValue, /*outputType=*/result.getType(),
+ opReplacement->getConverter());
+ rewriterImpl.mapping.map(result, castValue);
+ inverseMapping[castValue].push_back(result);
}
}
+
return success();
}
@@ -2502,7 +2539,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
- // `legalizeChangedResultType`.
+ // `legalizeConvertedOpResultTypes`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
@@ -2832,67 +2869,6 @@ LogicalResult OperationConverter::legalizeErasedResult(
return success();
}
-/// Finds a user of the given value, or of any other value that the given value
-/// replaced, that was not replaced in the conversion process.
-static Operation *findLiveUserOfReplaced(
- Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
- const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- SmallVector<Value> worklist(1, initialValue);
- while (!worklist.empty()) {
- Value value = worklist.pop_back_val();
-
- // Walk the users of this value to see if there are any live users that
- // weren't replaced during conversion.
- auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
- return rewriterImpl.isOpIgnored(user);
- });
- if (liveUserIt != value.user_end())
- return *liveUserIt;
- auto mapIt = inverseMapping.find(value);
- if (mapIt != inverseMapping.end())
- worklist.append(mapIt->second);
- }
- return nullptr;
-}
-
-LogicalResult OperationConverter::legalizeChangedResultType(
- Operation *op, OpResult result, Value newValue,
- const TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- Operation *liveUser =
- findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
- if (!liveUser)
- return success();
-
- // Functor used to emit a conversion error for a failed materialization.
- auto emitConversionError = [&] {
- InFlightDiagnostic diag = op->emitError()
- << "failed to materialize conversion for result #"
- << result.getResultNumber() << " of operation '"
- << op->getName()
- << "' that remained live after conversion";
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- };
-
- // If the replacement has a type converter, attempt to materialize a
- // conversion back to the original type.
- if (!replConverter)
- return emitConversionError();
-
- // Materialize a conversion for this live result value.
- Type resultType = result.getType();
- Value convertedValue = replConverter->materializeSourceConversion(
- rewriter, op->getLoc(), resultType, newValue);
- if (!convertedValue)
- return emitConversionError();
-
- rewriterImpl.mapping.map(result, convertedValue);
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index dd0ed77470a259..d8570bdaf4247f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -560,8 +560,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
-// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
-// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
+// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: return %[[CAST0]], %[[CAST1]]
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d928..a192434c5accf8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -78,9 +78,8 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
// memref.cast.
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
%0 = bufferization.to_tensor %m : memref<?xf32>
- // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
+ // expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
%1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
- // expected-note @+1 {{see existing live user here}}
return %1 : memref<?xf32, strided<[1], offset: ?>>
}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index d0563fed8e5d94..252b990210a180 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -20,20 +20,16 @@ func.func @test_valid_arg_materialization(%arg0: i64) {
// -----
func.func @test_invalid_result_materialization() {
- // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
+ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16
-
- // expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
// -----
func.func @test_invalid_result_materialization() {
- // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
+ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.type_producer"() : () -> f16
-
- // expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
@@ -51,9 +47,8 @@ func.func @test_transitive_use_materialization() {
// -----
func.func @test_transitive_use_invalid_materialization() {
- // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
+ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f16' that remained live after conversion}}
%result = "test.another_type_producer"() : () -> f16
- // expected-note@below {{see existing live user here}}
"foo.return"(%result) : (f16) -> ()
}
>From 7d9ce368fdb4b0fd830b820bf07fd9305e1500a5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 12 Aug 2024 09:37:10 +0200
Subject: [PATCH 2/2] Update mlir/lib/Transforms/Utils/DialectConversion.cpp
Co-authored-by: Jakub Kuderski <jakub at nod-labs.com>
---
mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b0b5a8247b53f4..11e593cebc09b3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2471,7 +2471,7 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
static Operation *findLiveUserOfReplaced(
Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- SmallVector<Value> worklist(1, initialValue);
+ SmallVector<Value> worklist = {initialValue};
while (!worklist.empty()) {
Value value = worklist.pop_back_val();
More information about the Mlir-commits
mailing list