[Mlir-commits] [mlir] [mlir][sparse] fuse concat and extract_slice op if possible. (PR #89825)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 23 13:36:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/89825.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+83-3) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..9e8998a8a07f35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
 
 namespace {
 
+/// TODO: move it to tensor dialect instead.
+///
+/// Fold a `tensor.extract_slice` that reads back exactly one operand of a
+/// `tensor.concat` into that operand:
+///
+/// %concat = tensor.concat dim(2) %t0, %t1
+///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+/// %extracted0, %extracted1 = %t0, %t1
+///
+/// Matching relies on static offsets/sizes/strides only, so the pattern never
+/// creates IR and reports failure whenever the slice is left untouched.
+struct FuseExtractSliceWithConcat
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+    if (!concatOp)
+      return failure();
+
+    const int64_t dim = concatOp.getDim();
+    SmallVector<OpFoldResult> sliceOffsets = extractOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sliceSizes = extractOp.getMixedSizes();
+
+    // Only a unit-stride slice can coincide with a whole concat operand.
+    for (OpFoldResult stride : extractOp.getMixedStrides())
+      if (getConstantIntValue(stride) != 1)
+        return failure();
+
+    // Offsets must be constant, and zero everywhere except along `dim`.
+    std::optional<int64_t> sliceStart;
+    for (auto [idx, offset] : llvm::enumerate(sliceOffsets)) {
+      std::optional<int64_t> cst = getConstantIntValue(offset);
+      if (static_cast<int64_t>(idx) == dim)
+        sliceStart = cst;
+      else if (cst != 0)
+        return failure();
+    }
+    if (!sliceStart)
+      return failure();
+
+    // True iff `type` is fully static and its shape equals the slice sizes.
+    auto sizesMatch = [&sliceSizes](RankedTensorType type) {
+      ArrayRef<int64_t> shape = type.getShape();
+      if (!type.hasStaticShape() || shape.size() != sliceSizes.size())
+        return false;
+      for (auto [size, extent] : llvm::zip(sliceSizes, shape))
+        if (getConstantIntValue(size) != extent)
+          return false;
+      return true;
+    };
+
+    // Walk the operands, tracking where each one starts along `dim`.
+    int64_t inputStart = 0;
+    for (Value input : concatOp.getInputs()) {
+      auto inputType = llvm::cast<RankedTensorType>(input.getType());
+      if (inputStart == *sliceStart && sizesMatch(inputType)) {
+        // Fold only when the types agree too (e.g. slice is not rank-reduced).
+        if (inputType != extractOp.getResultType())
+          return failure();
+        rewriter.replaceOp(extractOp, input);
+        return success();
+      }
+      // A dynamic extent makes every later start offset unknown.
+      if (inputType.isDynamicDim(dim))
+        return failure();
+      inputStart += inputType.getDimSize(dim);
+    }
+    return failure();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
-      patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,

``````````

</details>


https://github.com/llvm/llvm-project/pull/89825


More information about the Mlir-commits mailing list