[Mlir-commits] [mlir] 7267c85 - [mlir][Func] Delete `DecomposeCallGraphTypes.cpp` (#117424)
llvmlistbot at llvm.org
Mon Dec 2 16:02:06 PST 2024
Author: Matthias Springer
Date: 2024-12-02T16:02:03-08:00
New Revision: 7267c85959aa2490e2950f7fb817a76af7e94043
URL: https://github.com/llvm/llvm-project/commit/7267c85959aa2490e2950f7fb817a76af7e94043
DIFF: https://github.com/llvm/llvm-project/commit/7267c85959aa2490e2950f7fb817a76af7e94043.diff
LOG: [mlir][Func] Delete `DecomposeCallGraphTypes.cpp` (#117424)
`DecomposeCallGraphTypes.cpp` was a workaround for missing 1:N
support in the dialect conversion framework. Now that 1:N support has
been added, the workaround can be deleted. The test remains in place as
an example of how to write such a transformation with the dialect
conversion framework.
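Note for LLVM integration: If you are using
`DecomposeCallGraphTypes.cpp` (i.e. `populateDecomposeCallGraphTypesPatterns`),
switch to the patterns used in `TestDecomposeCallGraphTypes.cpp`:
`populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>`,
`populateReturnOpTypeConversionPattern`, and
`populateCallOpTypeConversionPattern` (from
`mlir/Dialect/Func/Transforms/FuncConversions.h`).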
Added:
Modified:
mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
Removed:
mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
deleted file mode 100644
index 1be406bf3adf92..00000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Conversion patterns for decomposing types along call graph edges. That is,
-// decomposing types for calls, returns, and function args.
-//
-// TODO: Make this handle dialect-defined functions, calls, and returns.
-// Currently, the generic interfaces aren't sophisticated enough for the
-// types of mutations that we are doing here.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
-
-namespace mlir {
-
-/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `TypeConverter`.
-void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
- const TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index f8fb1f436a95b1..6384d25ee70273 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRFuncTransforms
- DecomposeCallGraphTypes.cpp
DuplicateFunctionElimination.cpp
FuncConversions.cpp
OneToNFuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
deleted file mode 100644
index 03be00328bda33..00000000000000
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForFuncArgs
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand function arguments according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForFuncArgs
- : public OpConversionPattern<func::FuncOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- auto functionType = op.getFunctionType();
-
- // Convert function arguments using the provided TypeConverter.
- TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
- for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
- SmallVector<Type, 2> decomposedTypes;
- if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
- return failure();
- if (!decomposedTypes.empty())
- conversion.addInputs(argType.index(), decomposedTypes);
- }
-
- // If the SignatureConversion doesn't apply, bail out.
- if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
- &conversion)))
- return failure();
-
- // Update the signature of the function.
- SmallVector<Type, 2> newResultTypes;
- if (failed(typeConverter->convertTypes(functionType.getResults(),
- newResultTypes)))
- return failure();
- rewriter.modifyOpInPlace(op, [&] {
- op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
- newResultTypes));
- });
- return success();
- }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForReturnOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand return operands according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForReturnOp
- : public OpConversionPattern<ReturnOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- SmallVector<Value, 2> newOperands;
- for (ValueRange operand : adaptor.getOperands())
- llvm::append_range(newOperands, operand);
- rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
- return success();
- }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand call op operands and results according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
-
- // Create the operands list of the new `CallOp`.
- SmallVector<Value, 2> newOperands;
- for (ValueRange operand : adaptor.getOperands())
- llvm::append_range(newOperands, operand);
-
- // Create the new result types for the new `CallOp` and track the number of
- // replacement types for each original op result.
- SmallVector<Type, 2> newResultTypes;
- SmallVector<unsigned> expandedResultSizes;
- for (Type resultType : op.getResultTypes()) {
- unsigned oldSize = newResultTypes.size();
- if (failed(typeConverter->convertType(resultType, newResultTypes)))
- return failure();
- expandedResultSizes.push_back(newResultTypes.size() - oldSize);
- }
-
- CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
- newResultTypes, newOperands);
-
- // Build a replacement value for each result to replace its uses.
- SmallVector<ValueRange> replacedValues;
- replacedValues.reserve(op.getNumResults());
- unsigned startIdx = 0;
- for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
- ValueRange repl =
- newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
- replacedValues.push_back(repl);
- startIdx += expandedResultSizes[i];
- }
- rewriter.replaceOpWithMultiple(op, replacedValues);
- return success();
- }
-};
-} // namespace
-
-void mlir::populateDecomposeCallGraphTypesPatterns(
- MLIRContext *context, const TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns
- .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
-}
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 9e7759bef6d8fd..a3638c8766a5c6 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -124,12 +124,10 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
using OpConversionPattern<ReturnOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- // For a return, all operands go to the results of the parent, so
- // rewrite them all.
- rewriter.modifyOpInPlace(op,
- [&] { op->setOperands(adaptor.getOperands()); });
+ rewriter.replaceOpWithNewOp<ReturnOp>(op,
+ flattenValues(adaptor.getOperands()));
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index de511c58ae6ee0..09c5b4b2a0ad50 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -9,7 +9,7 @@
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -142,7 +142,10 @@ struct TestDecomposeCallGraphTypes
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);
- populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits mailing list