[Mlir-commits] [flang] [mlir] [mlir][Transforms] Dialect conversion: Fix missing source materialization (PR #97903)
Matthias Springer
llvmlistbot at llvm.org
Sat Jul 13 06:49:38 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/97903
>From d4963ded9d778f8a2ba90f86d7a6be08fee61f78 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 13 Jul 2024 14:32:06 +0200
Subject: [PATCH 1/2] [flang] Remove materialization workaround in type
converter
This change is in preparation for #97903, which adds extra checks for materializations: it is now enforced that they produce an SSA value of the correct type, so the current workaround no longer works.
For `fir.has_value` the fix is simple: no target materializations on the operands are performed if the lowering pattern is initialized without a type converter. For `cg::XEmboxOp`, the existing workaround that skips `unrealized_conversion_cast` ops can be generalized. (This is still a workaround.)
Also remove the lowering pattern for `unrealized_conversion_cast`. This pattern has no effect because `unrealized_conversion_cast` ops that are inserted by the dialect conversion framework are never matched by the pattern driver.
---
flang/include/flang/Tools/CLOptions.inc | 5 ++
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 67 +++++++------------
flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 34 ----------
flang/test/Fir/basic-program.fir | 1 +
4 files changed, 30 insertions(+), 77 deletions(-)
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 7f2910c5cfd3c..7df5044949463 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -9,6 +9,7 @@
/// This file defines some shared command-line options that can be used when
/// debugging the test tools. This file must be included into the tool.
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Pass/PassManager.h"
@@ -223,6 +224,10 @@ inline void addFIRToLLVMPass(
options.forceUnifiedTBAATree = useOldAliasTags;
addPassConditionally(pm, disableFirToLlvmIr,
[&]() { return fir::createFIRToLLVMPass(options); });
+ // The dialect conversion framework may leave dead unrealized_conversion_cast
+ // ops behind, so run reconcile-unrealized-casts to clean them up.
+ addPassConditionally(pm, disableFirToLlvmIr,
+ [&]() { return mlir::createReconcileUnrealizedCastsPass(); });
}
inline void addLLVMDialectToLLVMPass(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 7483acfcd1ca7..11535f073c807 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -35,7 +35,6 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
-#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
@@ -2042,13 +2041,13 @@ struct ExtractValueOpConversion
/// InsertValue is the generalized instruction for the composition of new
/// aggregate type values.
struct InsertValueOpConversion
- : public fir::FIROpAndTypeConversion<fir::InsertValueOp>,
+ : public mlir::OpConversionPattern<fir::InsertValueOp>,
public ValueOpCommon {
- using FIROpAndTypeConversion::FIROpAndTypeConversion;
+ using OpConversionPattern::OpConversionPattern;
llvm::LogicalResult
- doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ matchAndRewrite(fir::InsertValueOp insertVal, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
mlir::ValueRange operands = adaptor.getOperands();
auto indices = collectIndices(rewriter, insertVal.getCoor());
toRowMajor(indices, operands[0].getType());
@@ -2669,8 +2668,9 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
};
/// Lower `fir.has_value` operation to `llvm.return` operation.
-struct HasValueOpConversion : public fir::FIROpConversion<fir::HasValueOp> {
- using FIROpConversion::FIROpConversion;
+struct HasValueOpConversion
+ : public mlir::OpConversionPattern<fir::HasValueOp> {
+ using OpConversionPattern::OpConversionPattern;
llvm::LogicalResult
matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
@@ -3515,29 +3515,6 @@ struct MustBeDeadConversion : public fir::FIROpConversion<FromOp> {
}
};
-struct UnrealizedConversionCastOpConversion
- : public fir::FIROpConversion<mlir::UnrealizedConversionCastOp> {
- using FIROpConversion::FIROpConversion;
-
- llvm::LogicalResult
- matchAndRewrite(mlir::UnrealizedConversionCastOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
- assert(op.getOutputs().getTypes().size() == 1 && "expect a single type");
- mlir::Type convertedType = convertType(op.getOutputs().getTypes()[0]);
- if (convertedType == adaptor.getInputs().getTypes()[0]) {
- rewriter.replaceOp(op, adaptor.getInputs());
- return mlir::success();
- }
-
- convertedType = adaptor.getInputs().getTypes()[0];
- if (convertedType == op.getOutputs().getType()[0]) {
- rewriter.replaceOp(op, adaptor.getInputs());
- return mlir::success();
- }
- return mlir::failure();
- }
-};
-
struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
using MustBeDeadConversion::MustBeDeadConversion;
};
@@ -3714,7 +3691,8 @@ class FIRToLLVMLowering
signalPassFailure();
}
- // Run pass to add comdats to functions that have weak linkage on relevant platforms
+ // Run pass to add comdats to functions that have weak linkage on relevant
+ // platforms
if (fir::getTargetTriple(mod).supportsCOMDAT()) {
mlir::OpPassManager comdatPM("builtin.module");
comdatPM.addPass(mlir::LLVM::createLLVMAddComdats());
@@ -3789,16 +3767,19 @@ void fir::populateFIRToLLVMConversionPatterns(
DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
- GlobalOpConversion, HasValueOpConversion, InsertOnRangeOpConversion,
- InsertValueOpConversion, IsPresentOpConversion, LenParamIndexOpConversion,
- LoadOpConversion, MulcOpConversion, NegcOpConversion,
- NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion,
- SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
- ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
- StoreOpConversion, StringLitOpConversion, SubcOpConversion,
- TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
- UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
- UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
- XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
- options);
+ GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
+ LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
+ NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
+ SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
+ ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
+ SliceOpConversion, StoreOpConversion, StringLitOpConversion,
+ SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
+ UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
+ UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
+ XReboxOpConversion, ZeroOpConversion>(converter, options);
+
+ // Patterns that are populated without a type converter do not trigger
+ // target materializations for the operands of the root op.
+ patterns.insert<HasValueOpConversion, InsertValueOpConversion>(
+ patterns.getContext());
}
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index ce86c625e082f..7b46a1a92142b 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -122,40 +122,6 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
// Convert it here to i1 just in case it survives.
return mlir::IntegerType::get(&getContext(), 1);
});
- // FIXME: https://reviews.llvm.org/D82831 introduced an automatic
- // materialization of conversion around function calls that is not working
- // well with fir lowering to llvm (incorrect llvm.mlir.cast are inserted).
- // Workaround until better analysis: register a handler that does not insert
- // any conversions.
- addSourceMaterialization(
- [&](mlir::OpBuilder &builder, mlir::Type resultType,
- mlir::ValueRange inputs,
- mlir::Location loc) -> std::optional<mlir::Value> {
- if (inputs.size() != 1)
- return std::nullopt;
- return inputs[0];
- });
- // Similar FIXME workaround here (needed for compare.fir/select-type.fir
- // as well as rebox-global.fir tests). This is needed to cope with the
- // the fact that codegen does not lower some operation results to the LLVM
- // type produced by this LLVMTypeConverter. For instance, inside FIR
- // globals, fir.box are lowered to llvm.struct, while the fir.box type
- // conversion translates it into an llvm.ptr<llvm.struct<>> because
- // descriptors are manipulated in memory outside of global initializers
- // where this is not possible. Hence, MLIR inserts
- // builtin.unrealized_conversion_cast after the translation of operations
- // producing fir.box in fir.global codegen. addSourceMaterialization and
- // addTargetMaterialization allow ignoring these ops and removing them
- // after codegen assuming the type discrepencies are intended (like for
- // fir.box inside globals).
- addTargetMaterialization(
- [&](mlir::OpBuilder &builder, mlir::Type resultType,
- mlir::ValueRange inputs,
- mlir::Location loc) -> std::optional<mlir::Value> {
- if (inputs.size() != 1)
- return std::nullopt;
- return inputs[0];
- });
}
// i32 is used here because LLVM wants i32 constants when indexing into struct
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 7bbfd709b0aaf..dda4f32872fef 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -119,4 +119,5 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated
// PASSES-NEXT: TargetRewrite
// PASSES-NEXT: FIRToLLVMLowering
+// PASSES-NEXT: ReconcileUnrealizedCasts
// PASSES-NEXT: LLVMIRLoweringPass
>From d8a0ebeda89b3fda310447616449db1cffd017f4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 6 Jul 2024 14:28:41 +0200
Subject: [PATCH 2/2] fix test
---
mlir/docs/DialectConversion.md | 3 +-
.../SCF/TransformOps/SCFTransformOps.td | 11 +++
.../mlir/Transforms/DialectConversion.h | 10 +--
.../Conversion/LLVMCommon/TypeConverter.cpp | 28 ++++--
.../Dialect/SCF/TransformOps/CMakeLists.txt | 1 +
.../SCF/TransformOps/SCFTransformOps.cpp | 13 ++-
.../Transforms/Utils/DialectConversion.cpp | 86 +++++++++----------
.../FuncToLLVM/func-memref-return.mlir | 4 +-
.../Transforms/test-block-legalization.mlir | 44 ++++++++++
9 files changed, 140 insertions(+), 60 deletions(-)
create mode 100644 mlir/test/Transforms/test-block-legalization.mlir
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index db26e6477d5fc..23e74470a835f 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -352,7 +352,8 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value.
+ /// a signature conversion of a single block argument, to a single SSA value
+ /// with the old argument type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 7bf914f6456ce..20880d94a83ca 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
+ "apply_conversion_patterns.scf.scf_to_control_flow",
+ [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+ let description = [{
+ Collects patterns that lower structured control flow ops to unstructured
+ control flow.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a22f198bdf252..a51b00271f0ae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -174,15 +174,15 @@ class TypeConverter {
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
- /// must return a Value of the converted type on success, an `std::nullopt` if
+ /// must return a Value of the type `T` on success, an `std::nullopt` if
/// it failed but other materialization can be attempted, and `nullptr` on
- /// unrecoverable failure. It will only be called for (sub)types of `T`.
- /// Materialization functions must be provided when a type conversion may
- /// persist after the conversion has finished.
+ /// unrecoverable failure. Materialization functions must be provided when a
+ /// type conversion may persist after the conversion has finished.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value.
+ /// a signature conversion of a single block argument, to a single SSA value
+ /// with the old block argument type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index f5620a6a7cd91..32d02d5e438bd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
- // Materialization for memrefs creates descriptor structs from individual
- // values constituting them, when descriptors are used, i.e. more than one
- // value represents a memref.
+ // Argument materializations convert from the new block argument types
+ // (multiple SSA values that make up a memref descriptor) back to the
+ // original block argument type. The dialect conversion framework will then
+ // insert a target materialization from the original block argument type to
+ // a legal type.
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
@@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
}
- return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
- inputs);
+ Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
+ resultType, inputs);
+ // An argument materialization must return a value of type
+ // `resultType`, so insert a cast from the memref descriptor type
+ // (!llvm.struct) to the original memref type.
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+ .getResult(0);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
+ Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
@@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return std::nullopt;
- return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+ desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
+ } else {
+ desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
}
- return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ // An argument materialization must return a value of type `resultType`,
+ // so insert a cast from the memref descriptor type (!llvm.struct) to the
+ // original memref type.
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+ .getResult(0);
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
index 1d6f9ebd153f0..06bccab80e7d8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
MLIRIR
MLIRLoopLikeInterface
MLIRSCFDialect
+ MLIRSCFToControlFlow
MLIRSCFTransforms
MLIRSCFUtils
MLIRTransformDialect
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 56ff2709a589e..c4a55c302d0a3 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
conversionTarget);
}
+void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ populateSCFToControlFlowConversionPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//
@@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp,
return 1;
};
- std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
- std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+ std::optional<int64_t> ubConstant =
+ getConstantIntValue(forOp.getUpperBound());
+ std::optional<int64_t> lbConstant =
+ getConstantIntValue(forOp.getLowerBound());
DenseMap<Operation *, unsigned> opCycles;
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
for (Operation &op : forOp.getBody()->getOperations()) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e6c0ee2ab2949..4a3c11f398e7d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target,
- Type origOutputType = nullptr)
+ MaterializationKind kind = MaterializationKind::Target)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), origOutputType(origOutputType) {}
+ converterAndKind(converter, kind) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}
- /// Return the original illegal output type of the input values.
- Type getOrigOutputType() const { return origOutputType; }
-
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
converterAndKind;
-
- /// The original output type. This is only used for argument conversions.
- Type origOutputType;
};
} // namespace
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Block *insertBlock,
Block::iterator insertPt, Location loc,
ValueRange inputs, Type outputType,
- Type origOutputType,
const TypeConverter *converter);
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
- Type origOutputType,
Type outputType,
const TypeConverter *converter);
@@ -1388,20 +1379,27 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
+ mapping.map(origArg, newArg);
} else {
- Type origOutputType = origArg.getType();
-
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
-
- newArg = buildUnresolvedArgumentMaterialization(
- newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter);
+ // Build argument materialization: new block arguments -> old block
+ // argument type.
+ Value argMat = buildUnresolvedArgumentMaterialization(
+ newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
+ mapping.map(origArg, argMat);
+
+ // Build target materialization: old block argument type -> legal type.
+ // Note: This function returns an "empty" type if no valid conversion to
+ // a legal type exists. In that case, we continue the conversion with the
+ // original block argument type.
+ if (Type legalOutputType = converter->convertType(origArg.getType())) {
+ newArg = buildUnresolvedTargetMaterialization(
+ origArg.getLoc(), argMat, legalOutputType, converter);
+ mapping.map(argMat, newArg);
+ } else {
+ newArg = argMat;
+ }
}
- mapping.map(origArg, newArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
@@ -1424,7 +1422,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
- Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+ Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1435,16 +1433,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
OpBuilder builder(insertBlock, insertPt);
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- origOutputType);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
- Block *block, Location loc, ValueRange inputs, Type origOutputType,
- Type outputType, const TypeConverter *converter) {
+ Block *block, Location loc, ValueRange inputs, Type outputType,
+ const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
- origOutputType, converter);
+ converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
@@ -1456,7 +1453,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
return buildUnresolvedMaterialization(MaterializationKind::Target,
insertBlock, insertPt, loc, input,
- outputType, outputType, converter);
+ outputType, converter);
}
//===----------------------------------------------------------------------===//
@@ -2672,6 +2669,9 @@ static void computeNecessaryMaterializations(
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping,
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
+ // Helper function to check if the given value or a not yet materialized
+ // replacement of the given value is live.
+ // Note: `inverseMapping` maps from replaced values to original values.
auto isLive = [&](Value value) {
auto findFn = [&](Operation *user) {
auto matIt = materializationOps.find(user);
@@ -2679,12 +2679,18 @@ static void computeNecessaryMaterializations(
return !necessaryMaterializations.count(matIt->second);
return rewriterImpl.isOpIgnored(user);
};
- // This value may be replacing another value that has a live user.
- for (Value inv : inverseMapping.lookup(value))
- if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
+ // A worklist is needed because a value may have gone through a chain of
+ // replacements and each of the replaced values may have live users.
+ SmallVector<Value> worklist;
+ worklist.push_back(value);
+ while (!worklist.empty()) {
+ Value next = worklist.pop_back_val();
+ if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
return true;
- // Or have live users itself.
- return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
+ // This value may be replacing another value that has a live user.
+ llvm::append_range(worklist, inverseMapping.lookup(next));
+ }
+ return false;
};
llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2850,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
switch (mat.getMaterializationKind()) {
case MaterializationKind::Argument:
// Try to materialize an argument conversion.
- // FIXME: The current argument materialization hook expects the original
- // output type, even though it doesn't use that as the actual output type
- // of the generated IR. The output type is just used as an indicator of
- // the type of materialization to do. This behavior is really awkward in
- // that it diverges from the behavior of the other hooks, and can be
- // easily misunderstood. We should clean up the argument hooks to better
- // represent the desired invariants we actually care about.
newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+ rewriter, op->getLoc(), outputType, inputOperands);
if (newMaterialization)
break;
-
// If an argument materialization failed, fallback to trying a target
// materialization.
[[fallthrough]];
@@ -2865,6 +2863,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
break;
}
if (newMaterialization) {
+ assert(newMaterialization.getType() == outputType &&
+ "materialization callback produced value of incorrect type");
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
inverseMapping);
return success();
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 91ef571cb3bf7..6b9df32fe02dd 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s
-// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
-// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
// These tests were separated from func-memref.mlir because applying
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
diff --git a/mlir/test/Transforms/test-block-legalization.mlir b/mlir/test/Transforms/test-block-legalization.mlir
new file mode 100644
index 0000000000000..d739f95a56947
--- /dev/null
+++ b/mlir/test/Transforms/test-block-legalization.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @complex_block_signature_conversion(
+// CHECK: %[[cst:.*]] = complex.constant
+// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
+// Note: Some blocks are omitted.
+// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]]
+// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
+// CHECK: llvm.br ^[[block2:.*]]
+// CHECK: ^[[block2]]:
+// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
+func.func @complex_block_signature_conversion() {
+ %cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
+ %true = arith.constant true
+ %0 = scf.if %true -> complex<f64> {
+ scf.yield %cst : complex<f64>
+ } else {
+ scf.yield %cst : complex<f64>
+ }
+
+ // Regression test to ensure that a source materialization is inserted.
+ // The operand of "test.consumer_of_complex" must not change.
+ "test.consumer_of_complex"(%0) : (complex<f64>) -> ()
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %toplevel_module
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_conversion_patterns to %func {
+ transform.apply_conversion_patterns.dialect_to_llvm "cf"
+ transform.apply_conversion_patterns.func.func_to_llvm
+ transform.apply_conversion_patterns.scf.scf_to_control_flow
+ } with type_converter {
+ transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+ } {
+ legal_dialects = ["llvm"],
+ partial_conversion
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list