[Mlir-commits] [mlir] [mlir][sparse] fuse concat and extract_slice op if possible. (PR #89825)

Peiming Liu llvmlistbot at llvm.org
Tue Apr 23 13:35:44 PDT 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/89825

None

>From 69fd4bfa415e6e68ed6854ea98cfdc6cad8a1373 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 23 Apr 2024 17:40:02 +0000
Subject: [PATCH] [mlir][sparse] fuse concat and extract_slice op if possible.

---
 .../Transforms/SparseTensorRewriting.cpp      | 86 ++++++++++++++++++-
 1 file changed, 83 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..9e8998a8a07f35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
 
 namespace {
 
+/// TODO: move it to tensor dialect instead.
+///
+/// Fold a `tensor.extract_slice` that reads back exactly one input of a
+/// `tensor.concat` into that input.
+///
+/// %concat = tensor.concat dim(2) %t0, %t1
+///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+/// %extracted0, %extracted1 = %t0, %t1
+///
+/// The match is decided from static type information and constant slice
+/// operands only, so a failed match never creates IR.
+struct FuseExtractSliceWithConcat
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  /// Replaces `extractOp` with the concat input it covers and returns
+  /// success; returns failure, leaving the IR untouched, when the source is
+  /// not a `tensor.concat`, when no input is provably covered, or when the
+  /// covered input's type differs from the slice result type.
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+    if (!concatOp)
+      return failure();
+
+    int64_t dim = concatOp.getDim();
+    // Offsets/strides are indexed in the source space. Using the result rank
+    // here would index out of bounds for rank-reducing slices.
+    int64_t rank = extractOp.getSourceType().getRank();
+
+    SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
+    SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
+    SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
+
+    // True iff `mixed` consists of constants equal to `expected`, pairwise.
+    auto allEqual = [](ArrayRef<OpFoldResult> mixed,
+                       ArrayRef<int64_t> expected) {
+      if (mixed.size() != expected.size())
+        return false;
+      for (auto [ofr, value] : llvm::zip(mixed, expected))
+        if (getConstantIntValue(ofr) != value)
+          return false;
+      return true;
+    };
+
+    // A slice can only be an entire input if it has unit strides.
+    SmallVector<int64_t> srcStrides(rank, 1);
+    if (!allEqual(dstStrides, srcStrides))
+      return failure();
+
+    // Walk the inputs while tracking the running offset along `dim`.
+    SmallVector<int64_t> srcOffsets(rank, 0);
+    for (Value input : concatOp.getInputs()) {
+      auto inputType = cast<RankedTensorType>(input.getType());
+      if (inputType.hasStaticShape() &&
+          allEqual(dstSizes, inputType.getShape()) &&
+          allEqual(dstOffsets, srcOffsets)) {
+        // The first input covered by the slice decides the outcome; it is
+        // only reusable when its type is exactly the slice result type.
+        if (inputType != extractOp.getResultType())
+          return failure();
+        rewriter.replaceOp(extractOp, input);
+        return success();
+      }
+      // Past a dynamically sized input, later offsets are not constants.
+      if (inputType.isDynamicDim(dim))
+        return failure();
+      srcOffsets[dim] += inputType.getDimSize(dim);
+    }
+    return failure();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
-      patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,



More information about the Mlir-commits mailing list