[Mlir-commits] [mlir] [mlir][sparse] fuse concat and extract_slice op if possible. (PR #89825)
Peiming Liu
llvmlistbot at llvm.org
Tue Apr 23 13:35:44 PDT 2024
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/89825
None
>From 69fd4bfa415e6e68ed6854ea98cfdc6cad8a1373 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 23 Apr 2024 17:40:02 +0000
Subject: [PATCH] [mlir][sparse] fuse concat and extract_slice op if possible.
---
.../Transforms/SparseTensorRewriting.cpp | 86 ++++++++++++++++++-
1 file changed, 83 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..9e8998a8a07f35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
namespace {
+/// TODO: move it to tensor dialect instead.
+///
+/// Folds a `tensor.extract_slice` back to the matching input of a
+/// `tensor.concat` when the slice covers exactly that input:
+///
+///   %concat = tensor.concat dim(2) %t0, %t1
+///     : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+///   %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+///     : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///   %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+///     : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+///   %extracted0, %extracted1 = %t0, %t1
+struct FuseExtractSliceWithConcat
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+    if (!concatOp)
+      return failure();
+
+    Location loc = extractOp.getLoc();
+    int64_t dim = concatOp.getDim();
+    int64_t rank = extractOp.getResultType().getRank();
+
+    // A slice matches an input only with unit strides everywhere and zero
+    // offsets on every dimension except the concatenation dimension.
+    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
+
+    // Compute the partial sums for the slice offsets: input i starts at the
+    // sum of the concat-dim sizes of all preceding inputs.
+    AffineExpr sum = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> partialSums = {sum};
+    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+    for (auto [idx, input] :
+         llvm::enumerate(concatOp.getInputs().drop_back())) {
+      sum = sum + rewriter.getAffineDimExpr(idx + 1);
+      partialSums.push_back(sum);
+      offsetStrides.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
+    }
+    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+                                        partialSums, rewriter.getContext());
+    SmallVector<OpFoldResult> dimOffsets =
+        affine::makeComposedFoldedMultiResultAffineApply(
+            rewriter, loc, partialSumMap, offsetStrides);
+
+    // Two OpFoldResult lists are considered equal iff they have the same
+    // length and every pair folds to the same static integer value.
+    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
+      for (auto [l, r] : llvm::zip(lhs, rhs)) {
+        std::optional<int64_t> staticVal = getConstantIntValue(l);
+        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
+          return false;
+      }
+      return lhs.size() == rhs.size();
+    };
+
+    for (auto [i, input, offset] :
+         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> srcSizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      srcOffsets[dim] = offset;
+
+      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
+      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
+      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
+
+      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
+          allEqual(srcStrides, dstStrides)) {
+        Value operand = concatOp.getOperand(i);
+        if (operand.getType() == extractOp.getResultType()) {
+          rewriter.replaceOp(extractOp, operand);
+          // Report success only when the IR was actually rewritten.
+          return success();
+        }
+        // The offsets/sizes matched but the types disagree (e.g. dynamic vs
+        // static dims); no other input can match this slice.
+        break;
+      }
+    }
+
+    // No input matched: return failure so the greedy driver does not treat
+    // the unchanged IR as a successful rewrite (unconditional success here
+    // would make the driver re-apply this pattern forever).
+    return failure();
+  }
+};
+
/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
-      patterns.getContext());
+  // Register the new concat/extract_slice fusion alongside the existing
+  // pre-sparsification rewrites.
+  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
 }
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
More information about the Mlir-commits
mailing list