[Mlir-commits] [mlir] [mlir][draft] Support 1:N dialect conversion (PR #112141)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Oct 13 07:48:34 PDT 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 9f24c145494ee238e65e25205a4dcb4451f009ae 7ec251bc0e69b4611d5acf8884be39d5461eb17b --extensions h,cpp -- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h mlir/include/mlir/Transforms/DialectConversion.h mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/test/lib/Dialect/Test/TestPatterns.cpp mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp mlir/unittests/ExecutionEngine/Invoke.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5285074ec6..f6cb99cfa9 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,55 +153,54 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
-/*
- // Argument materializations convert from the new block argument types
- // (multiple SSA values that make up a memref descriptor) back to the
- // original block argument type. The dialect conversion framework will then
- // insert a target materialization from the original block argument type to
- // a legal type.
- addArgumentMaterialization(
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
- if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
+ /*
+ // Argument materializations convert from the new block argument types
+ // (multiple SSA values that make up a memref descriptor) back to the
+ // original block argument type. The dialect conversion framework will then
+ // insert a target materialization from the original block argument type to
+ // a legal type.
+ addArgumentMaterialization(
+ [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange
+ inputs, Location loc) -> std::optional<Value> { if (inputs.size() == 1) {
+ // Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ return std::nullopt;
+ }
+ Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
+ resultType, inputs);
+ // An argument materialization must return a value of type
+ // `resultType`, so insert a cast from the memref descriptor type
+ // (!llvm.struct) to the original memref type.
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType,
+ desc) .getResult(0);
+ });
+ addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ Value desc;
+ if (inputs.size() == 1) {
+ // This is a bare pointer. We allow bare pointers only for function
+ entry
+ // blocks.
+ BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+ if (!barePtr)
return std::nullopt;
- }
- Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
- resultType, inputs);
- // An argument materialization must return a value of type
- // `resultType`, so insert a cast from the memref descriptor type
- // (!llvm.struct) to the original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
- Value desc;
- if (inputs.size() == 1) {
- // This is a bare pointer. We allow bare pointers only for function entry
- // blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return std::nullopt;
- Block *block = barePtr.getOwner();
- if (!block->isEntryBlock() ||
- !isa<FunctionOpInterface>(block->getParentOp()))
- return std::nullopt;
- desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
- inputs[0]);
- } else {
- desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- }
- // An argument materialization must return a value of type `resultType`,
- // so insert a cast from the memref descriptor type (!llvm.struct) to the
- // original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
-
-*/
+ Block *block = barePtr.getOwner();
+ if (!block->isEntryBlock() ||
+ !isa<FunctionOpInterface>(block->getParentOp()))
+ return std::nullopt;
+ desc = MemRefDescriptor::fromStaticShape(builder, loc, *this,
+ resultType, inputs[0]); } else { desc = MemRefDescriptor::pack(builder, loc,
+ *this, resultType, inputs);
+ }
+ // An argument materialization must return a value of type `resultType`,
+ // so insert a cast from the memref descriptor type (!llvm.struct) to the
+ // original memref type.
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+ .getResult(0);
+ });
+
+ */
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
@@ -211,14 +210,14 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc, Type originalType) -> std::optional<Value> {
+ ValueRange inputs, Location loc,
+ Type originalType) -> std::optional<Value> {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc, Type originalType) -> std::optional<Value> {
+ ValueRange inputs, Location loc,
+ Type originalType) -> std::optional<Value> {
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
if (!originalType) {
llvm::errs() << " -- no orig\n";
@@ -228,8 +227,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
if (inputs.size() == 1) {
Value input = inputs.front();
- if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
- if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
+ if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+ if (castOp.getInputs().size() == 1 &&
+ isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
input = castOp.getInputs()[0];
}
}
@@ -243,23 +243,23 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
!isa<FunctionOpInterface>(block->getParentOp()))
return std::nullopt;
// Bare ptr
- return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
- input);
+ return MemRefDescriptor::fromStaticShape(builder, loc, *this,
+ memrefType, input);
}
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
}
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
- return std::nullopt;
+ // Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ return std::nullopt;
}
- return UnrankedMemRefDescriptor::pack(builder, loc, *this,
- memrefType, inputs);
+ return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
+ inputs);
}
- return std::nullopt;
+ return std::nullopt;
});
// Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index efc2d24ddb..2b78a09b87 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1275,7 +1275,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
- /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType, currentTypeConverter);
+ /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+ currentTypeConverter);
mapping.mapMaterialization(vals, castValues);
remapped.push_back(castValues);
@@ -1430,7 +1431,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter) {
+ ValueRange inputs, TypeRange outputTypes, Type originalType,
+ const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (TypeRange(inputs) == outputTypes)
return inputs;
@@ -1441,7 +1443,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, originalType);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ originalType);
return convertOp.getResults();
}
@@ -1495,7 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
repl = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
- /*outputTypes=*/result.getType(), /*originalType=*/Type(), currentTypeConverter);
+ /*outputTypes=*/result.getType(), /*originalType=*/Type(),
+ currentTypeConverter);
} else {
// Make sure that the user does not mess with unresolved materializations
// that were inserted by the conversion driver. We keep track of these
@@ -2735,8 +2739,8 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(newValues),
originalValue.getLoc(),
- /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(), /*originalType=*/Type(),
- converter)[0];
+ /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(),
+ /*originalType=*/Type(), converter)[0];
rewriterImpl.mapping.mapMaterialization(newValues, {castValue});
llvm::append_range(inverseMapping[castValue], newValues);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 291f6af84b..4465cfa4c1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -994,7 +994,8 @@ struct TestUpdateConsumerType : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- llvm::errs() << "TestUpdateConsumerType operand: " << operands.front() << "\n";
+ llvm::errs() << "TestUpdateConsumerType operand: " << operands.front()
+ << "\n";
// Verify that the incoming operand has been successfully remapped to F64.
if (!operands[0].getType().isF64())
return failure();
``````````
</details>
https://github.com/llvm/llvm-project/pull/112141
More information about the Mlir-commits mailing list