[Mlir-commits] [mlir] [mlir][sparse] fuse concat and extract_slice op if possible. (PR #89825)

Peiming Liu llvmlistbot at llvm.org
Wed Apr 24 13:51:30 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89825

>From 0d33540f592e348d47aeb663b0794c9340f0e4c8 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 23 Apr 2024 17:40:02 +0000
Subject: [PATCH] [mlir][sparse] fuse concat and extract_slice op if possible.

---
 .../Transforms/SparseTensorRewriting.cpp      | 86 ++++++++++++++++++-
 ...fuse_sparse_concat_with_extract_slice.mlir | 23 +++++
 2 files changed, 106 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..5a39dfc6207707 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
 
 namespace {
 
+/// TODO: move it to tensor dialect instead.
+///
+/// Fold `tensor.extract_slice` ops that extract a whole `tensor.concat` input.
+///
+/// %concat = tensor.concat dim(2) %t0, %t1
+///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+/// %extract0, %extract1 = %t0, %t1
+struct FuseExtractSliceWithConcat
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  /// Replaces `extractOp` by the `tensor.concat` input that it reads back in
+  /// full: the slice must have unit strides, zero offsets outside of the
+  /// concatenation dimension, and statically cover exactly one input whose
+  /// type equals the slice result type. Returns failure, without touching the
+  /// IR, when no such input can be determined from static information.
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+    if (!concatOp)
+      return failure();
+
+    // True iff `ofr` is statically known to be equal to `expected`.
+    auto isConstant = [](OpFoldResult ofr, int64_t expected) {
+      std::optional<int64_t> staticVal = getConstantIntValue(ofr);
+      return staticVal.has_value() && *staticVal == expected;
+    };
+
+    // Slice operands are indexed by the rank of the source tensor, which can
+    // exceed the result rank for rank-reducing slices.
+    SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
+    SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
+    SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
+    const int64_t rank = extractOp.getSourceType().getRank();
+    const int64_t dim = concatOp.getDim();
+
+    // Only unit-stride slices starting at zero in every dimension other than
+    // `dim` can coincide with a concat input.
+    for (int64_t d = 0; d < rank; ++d)
+      if (!isConstant(dstStrides[d], 1) ||
+          (d != dim && !isConstant(dstOffsets[d], 0)))
+        return failure();
+
+    // Walk the inputs while accumulating the static offset at which each one
+    // starts along `dim`. Everything is derived from types and attributes so
+    // that no IR is created unless the rewrite actually happens.
+    int64_t srcOffset = 0;
+    for (Value input : concatOp.getInputs()) {
+      auto srcType = cast<RankedTensorType>(input.getType());
+      // A dynamic extent along `dim` rules out this input and makes the
+      // offsets of all later inputs statically unknown.
+      if (srcType.isDynamicDim(dim))
+        break;
+      bool sameSlice = srcType.hasStaticShape() &&
+                       isConstant(dstOffsets[dim], srcOffset);
+      for (int64_t d = 0; sameSlice && d < rank; ++d)
+        sameSlice = isConstant(dstSizes[d], srcType.getDimSize(d));
+      if (sameSlice) {
+        // The first input covering the slice decides; it is only forwarded
+        // when its type is identical to the slice result type.
+        if (input.getType() != extractOp.getResultType())
+          break;
+        rewriter.replaceOp(extractOp, input);
+        return success();
+      }
+      srcOffset += srcType.getDimSize(dim);
+    }
+
+    // Nothing was replaced: report failure so the rewrite driver converges.
+    return failure();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
-      patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir
new file mode 100644
index 00000000000000..5d93301bc8ca76
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite | FileCheck %s
+
+#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+
+
+// CHECK-LABEL: func.func @fuse_concat_with_extract(
+// CHECK-SAME:    %[[VAL_0:.*0]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:    %[[VAL_1:.*1]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME:    %[[VAL_2:.*2]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>)
+// CHECK-NOT:     tensor.concat
+// CHECK-NOT:     tensor.extract_slice
+// CHECK:         return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+// CHECK:       }
+func.func @fuse_concat_with_extract(%t0 : tensor<128x32x32x1xf32, #CCCD>,
+                                    %t1 : tensor<128x32x32x1xf32, #CCCD>,
+                                    %t2 : tensor<128x32x32x1xf32, #CCCD>) -> (tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>) {
+  %concat = tensor.concat dim(3) %t0, %t1, %t2 : (tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>) -> tensor<128x32x32x3xf32, #CCCD>
+  %r0 = tensor.extract_slice %concat[0, 0, 0, 0] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
+  %r1 = tensor.extract_slice %concat[0, 0, 0, 1] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
+  %r2 = tensor.extract_slice %concat[0, 0, 0, 2] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
+  return %r0, %r1, %r2 : tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>
+}



More information about the Mlir-commits mailing list