[Mlir-commits] [mlir] 28ebb0b - [mlir][sparse] migrate sparse rewriting to sparse transformations pass
Aart Bik
llvmlistbot at llvm.org
Mon Jul 18 09:29:30 PDT 2022
Author: Aart Bik
Date: 2022-07-18T09:29:22-07:00
New Revision: 28ebb0b61d110e4b108fc1ebcbc43d50fff8f087
URL: https://github.com/llvm/llvm-project/commit/28ebb0b61d110e4b108fc1ebcbc43d50fff8f087
DIFF: https://github.com/llvm/llvm-project/commit/28ebb0b61d110e4b108fc1ebcbc43d50fff8f087.diff
LOG: [mlir][sparse] migrate sparse rewriting to sparse transformations pass
The rules in the linalg file were very specific to sparse tensors, so they will
find a better home under the sparse tensor dialect than the linalg dialect. Also
moved some rewriting from sparsification into this new "pre-rewriting" file.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D129910
Added:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index b7300dda22fae..71afddcb49245 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -134,12 +134,19 @@ void populateSparseTensorConversionPatterns(
const SparseTensorConversionOptions &options =
SparseTensorConversionOptions());
-std::unique_ptr<Pass> createDenseBufferizationPass(
- const bufferization::OneShotBufferizationOptions &options);
std::unique_ptr<Pass> createSparseTensorConversionPass();
std::unique_ptr<Pass>
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
+//===----------------------------------------------------------------------===//
+// Other rewriting rules and passes.
+//===----------------------------------------------------------------------===//
+
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createDenseBufferizationPass(
+ const bufferization::OneShotBufferizationOptions &options);
+
//===----------------------------------------------------------------------===//
// Registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5bc2740afbe07..a8112dbe50b84 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,7 +22,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
LinalgStrategyPasses.cpp
NamedOpConversions.cpp
Promotion.cpp
- SparseTensorRewriting.cpp
Split.cpp
SplitReduction.cpp
Tiling.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index b958046be2c48..51f016495fc30 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1717,12 +1717,8 @@ struct LinalgElementwiseOpFusionPass
// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
-
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
- // Add the sparse tensor rewriting patterns.
- populateSparseTensorRewriting(patterns);
-
// General canonicalization patterns.
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 59e62209c2bd1..8c7639baaffcd 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -52,11 +52,6 @@ void mlir::sparse_tensor::buildSparseCompiler(
OpPassManager &pm, const SparseCompilerOptions &options) {
// TODO(wrengr): ensure the original `pm` is for ModuleOp
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
- // TODO(springerm): Reactivate element-wise op fusion pass. This pass does not
- // fit well with bufferization because it replaces unused "out" operands of
- // LinalgOps with InitTensorOps. This would result in additional buffer
- // allocations during bufferization.
- // pm.addPass(createLinalgElementwiseOpFusionPass());
pm.addPass(
bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
/*analysisOnly=*/options.testBufferizationAnalysisOnly)));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 76bd31691dfd9..9d99d2f7a5c8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Sparsification.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
+ SparseTensorRewriting.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 9d94e5b72e933..5fcb44a98de3a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -49,13 +49,17 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
void runOnOperation() override {
auto *ctx = &getContext();
- RewritePatternSet patterns(ctx);
+ // Apply pre-rewriting.
+ RewritePatternSet prePatterns(ctx);
+ populateSparseTensorRewriting(prePatterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(prePatterns));
// Translate strategy flags to strategy options.
SparsificationOptions options(
sparseParallelizationStrategy(parallelization),
sparseVectorizationStrategy(vectorization), vectorLength,
enableSIMDIndex32, enableVLAVectorization);
- // Apply rewriting.
+ // Apply sparsification and vector cleanup rewriting.
+ RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
similarity index 79%
rename from mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
rename to mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 8494012315907..3dc97915a683f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -6,20 +6,16 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements linalg dialect rewriting specific to sparse tensors.
-//
-// Sparsity should be mostly transparent to the linalg dialect optimizations
-// (i.e., the dense and sparse take the same path). However, in some cases,
-// optimizations only make sense in the context of sparse tensors. This file
-// implements such sparsity specific rewriting rules.
+// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
@@ -98,6 +94,7 @@ static bool isSumOfMul(GenericOp op) {
//===---------------------------------------------------------------------===//
namespace {
+
/// Rewriting rule that converts two kernels:
///
/// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -114,6 +111,7 @@ namespace {
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
@@ -194,13 +192,55 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
}
};
+
+/// Sparse rewriting rule for reshape operator.
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+public:
+ using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReshapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto encDst = getSparseTensorEncoding(op.getResult().getType());
+ auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
+ // Since a pure dense expansion is very cheap (change of view), for
+ // a sparse2dense or dense2sparse, we can simply unfuse a sparse
+ // conversion from the reshape operation itself.
+ // All other cases are handled elsewhere.
+ if (encDst && encSrc) {
+ return failure();
+ } else if (encSrc) {
+ RankedTensorType rtp =
+ op.getSrc().getType().template cast<RankedTensorType>();
+ auto denseTp =
+ RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+ auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
+ op->setOperand(0, convert);
+ return success();
+ } else if (encDst) {
+ RankedTensorType rtp =
+ op.getResult().getType().template cast<RankedTensorType>();
+ auto denseTp =
+ RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+ auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+ op.getReassociation());
+ Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
+ rewriter.replaceOp(op, convert);
+ return success();
+ }
+ return failure();
+ }
+};
+
} // namespace
//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//
-void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
- auto *context = patterns.getContext();
- patterns.add<FuseSparseMultiplyOverAdd>(context);
+void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+ // TODO(springerm): enable FuseSparseMultiplyOverAdd
+ patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 53182243ab84c..7121438ddad6b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1802,46 +1802,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
SparsificationOptions options;
};
-/// Sparse rewriting rule for reshape operator.
-template <typename ReshapeOp>
-struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
-public:
- using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ReshapeOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- auto encDst = getSparseTensorEncoding(op.getResult().getType());
- auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
- // Since a pure dense expansion is very cheap (change of view), for
- // a sparse2dense or dense2sparse, we can simply unfuse a sparse
- // conversion from the reshape operation itself.
- // All other cases are handled elsewhere.
- if (encDst && encSrc) {
- return failure();
- } else if (encSrc) {
- RankedTensorType rtp =
- op.getSrc().getType().template cast<RankedTensorType>();
- auto denseTp =
- RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
- op->setOperand(0, convert);
- return success();
- } else if (encDst) {
- RankedTensorType rtp =
- op.getResult().getType().template cast<RankedTensorType>();
- auto denseTp =
- RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
- op.getReassociation());
- Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
- rewriter.replaceOp(op, convert);
- return success();
- }
- return failure();
- }
-};
-
} // namespace
/// Populates the given patterns list with rewriting rules required for
@@ -1849,6 +1809,4 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
void mlir::populateSparsificationPatterns(
RewritePatternSet &patterns, const SparsificationOptions &options) {
patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
- patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5058e45a70368..73a17d1c60bb1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2115,6 +2115,7 @@ cc_library(
":SparseTensorDialect",
":SparseTensorPassIncGen",
":SparseTensorUtils",
+ ":Support",
":TensorDialect",
":Transforms",
":VectorDialect",
More information about the Mlir-commits
mailing list