[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect conversion: Unify materialization of value replacements (PR #108381)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Sep 12 05:56:13 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
PR #106760 aligned the handling of dropped block arguments and dropped op results. As a result, the two helper functions that insert source materializations for replaced block arguments / op results whose uses survived the conversion (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`) are now almost identical. This PR merges the two functions and moves the implementation directly into `finalize`.
This PR simplifies the code base and slightly improves efficiency: previously, `finalize` iterated over `ConversionPatternRewriterImpl::rewrites` twice; now a single iteration suffices.
---
Full diff: https://github.com/llvm/llvm-project/pull/108381.diff
2 Files Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+42-92)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+2-2)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ed15b571f01883..0556b4ab833c30 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2336,17 +2336,6 @@ struct OperationConverter {
/// remaining artifacts and complete the conversion.
LogicalResult finalize(ConversionPatternRewriter &rewriter);
- /// Legalize the types of converted block arguments.
- LogicalResult
- legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl);
-
- /// Legalize the types of converted op results.
- LogicalResult legalizeConvertedOpResultTypes(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
/// Dialect conversion configuration.
ConversionConfig config;
@@ -2510,19 +2499,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}
-LogicalResult
-OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
- ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
- return failure();
- DenseMap<Value, SmallVector<Value>> inverseMapping =
- rewriterImpl.mapping.getInverse();
- if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
- inverseMapping)))
- return failure();
- return success();
-}
-
/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
static Operation *findLiveUserOfReplaced(
@@ -2546,87 +2522,61 @@ static Operation *findLiveUserOfReplaced(
return nullptr;
}
-LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- DenseMap<Value, SmallVector<Value>> &inverseMapping) {
- // Process requested operation replacements.
- for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
- auto *opReplacement =
- dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!opReplacement)
- continue;
- Operation *op = opReplacement->getOperation();
- for (OpResult result : op->getResults()) {
- // If the type of this op result changed and the result is still live,
- // we need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+ if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+ return std::make_pair(opRewrite->getOperation()->getResults(),
+ opRewrite->getConverter());
+ if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+ return std::make_pair(blockRewrite->getOrigBlock()->getArguments(),
+ blockRewrite->getConverter());
+ return std::make_pair(ValueRange(), nullptr);
+}
+
+LogicalResult
+OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ DenseMap<Value, SmallVector<Value>> inverseMapping =
+ rewriterImpl.mapping.getInverse();
+
+ // Process requested value replacements.
+ for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+ ValueRange replacedValues;
+ const TypeConverter *converter;
+ std::tie(replacedValues, converter) =
+ getReplacedValues(rewriterImpl.rewrites[i].get());
+ for (Value originalValue : replacedValues) {
+ // If the type of this value changed and the value is still live, we need
+ // to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(originalValue,
+ originalValue.getType()))
continue;
Operation *liveUser =
- findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
+ findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
if (!liveUser)
continue;
- // Legalize this result.
- Value newValue = rewriterImpl.mapping.lookupOrNull(result);
+ // Legalize this value replacement.
+ Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
- /*inputs=*/newValue, /*outputType=*/result.getType(),
- opReplacement->getConverter());
- rewriterImpl.mapping.map(result, castValue);
- inverseMapping[castValue].push_back(result);
- llvm::erase(inverseMapping[newValue], result);
+ MaterializationKind::Source, computeInsertPoint(newValue),
+ originalValue.getLoc(),
+ /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+ converter);
+ rewriterImpl.mapping.map(originalValue, castValue);
+ inverseMapping[castValue].push_back(originalValue);
+ llvm::erase(inverseMapping[newValue], originalValue);
}
}
return success();
}
-LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl) {
- // Functor used to check if all users of a value will be dead after
- // conversion.
- // TODO: This should probably query the inverse mapping, same as in
- // `legalizeConvertedOpResultTypes`.
- auto findLiveUser = [&](Value val) {
- auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
- return rewriterImpl.isOpIgnored(user);
- });
- return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
- };
- // Note: `rewrites` may be reallocated as the loop is running.
- for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
- ++i) {
- auto &rewrite = rewriterImpl.rewrites[i];
- if (auto *blockTypeConversionRewrite =
- dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
- // Process the remapping for each of the original arguments.
- for (Value origArg :
- blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
- // If the type of this argument changed and the argument is still live,
- // we need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
- assert(replacementValue && "replacement value not found");
- Value repl = rewriterImpl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(replacementValue),
- origArg.getLoc(), /*inputs=*/replacementValue,
- /*outputType=*/origArg.getType(),
- blockTypeConversionRewrite->getConverter());
- rewriterImpl.mapping.map(origArg, repl);
- }
- }
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index d8570bdaf4247f..25ec5d0159bd5d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -558,8 +558,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK-LABEL: func @deinterleave_scalar
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
-// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
-// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
+// CHECK-DAG: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
+// CHECK-DAG: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: return %[[CAST0]], %[[CAST1]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/108381
More information about the llvm-branch-commits
mailing list