[Mlir-commits] [mlir] ea3eeb4 - [mlir][sparse] fuse concat and extract_slice op if possible. (#89825)
llvmlistbot at llvm.org
Wed Apr 24 13:51:45 PDT 2024
Author: Peiming Liu
Date: 2024-04-24T13:51:41-07:00
New Revision: ea3eeb483fbbe09b9a66ed4c032cc7168f0265dd
URL: https://github.com/llvm/llvm-project/commit/ea3eeb483fbbe09b9a66ed4c032cc7168f0265dd
DIFF: https://github.com/llvm/llvm-project/commit/ea3eeb483fbbe09b9a66ed4c032cc7168f0265dd.diff
LOG: [mlir][sparse] fuse concat and extract_slice op if possible. (#89825)
Added:
mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
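In short: the new pattern folds a tensor.extract_slice of a tensor.concat back to the corresponding concat input whenever the slice offsets, sizes, and (unit) strides exactly reproduce that input. As an illustrative sketch, not part of this commit, here is the fold on plain dense tensors with non-unit sizes along the concatenation dimension; the per-input offsets are running sums of the preceding input sizes (0 and 2 below):

  %concat = tensor.concat dim(0) %a, %b
      : (tensor<2x8xf32>, tensor<3x8xf32>) -> tensor<5x8xf32>
  // Folds to %a: offsets [0, 0], sizes [2, 8], unit strides.
  %s0 = tensor.extract_slice %concat[0, 0] [2, 8] [1, 1]
      : tensor<5x8xf32> to tensor<2x8xf32>
  // Folds to %b: offset 2 along dim 0, i.e. the size of %a.
  %s1 = tensor.extract_slice %concat[2, 0] [3, 8] [1, 1]
      : tensor<5x8xf32> to tensor<3x8xf32>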
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..5a39dfc6207707 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
namespace {
+/// TODO: move this to the tensor dialect instead.
+///
+/// Fold `tensor.concat` and `tensor.extract_slice`
+///
+/// %concat = tensor.concat dim(2) %t0, %t1
+/// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+/// %extracted0, %extracted1 = %t0, %t1
+struct FuseExtractSliceWithConcat
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+ if (!concatOp)
+ return failure();
+
+ Location loc = extractOp.getLoc();
+ int64_t dim = concatOp.getDim();
+ int64_t rank = extractOp.getResultType().getRank();
+
+ SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
+
+ // Compute the partial sums for the slice offsets.
+ AffineExpr sum = rewriter.getAffineDimExpr(0);
+ SmallVector<AffineExpr> partialSums = {sum};
+ SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+ for (auto [idx, input] :
+ llvm::enumerate(concatOp.getInputs().drop_back())) {
+ sum = sum + rewriter.getAffineDimExpr(idx + 1);
+ partialSums.push_back(sum);
+ offsetStrides.push_back(
+ rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
+ }
+ auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+ partialSums, rewriter.getContext());
+ SmallVector<OpFoldResult> dimOffsets =
+ affine::makeComposedFoldedMultiResultAffineApply(
+ rewriter, loc, partialSumMap, offsetStrides);
+
+ auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
+ for (auto [l, r] : llvm::zip(lhs, rhs)) {
+ std::optional<int64_t> staticVal = getConstantIntValue(l);
+ if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
+ return false;
+ }
+ return lhs.size() == rhs.size();
+ };
+
+ for (auto [i, input, offset] :
+ llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
+ SmallVector<OpFoldResult> srcSizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ srcOffsets[dim] = offset;
+
+ SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
+ SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
+ SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
+
+      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
+          allEqual(srcStrides, dstStrides)) {
+        Value operand = concatOp.getOperand(i);
+        if (operand.getType() != extractOp.getResultType())
+          break;
+        rewriter.replaceOp(extractOp, operand);
+        return success();
+      }
+    }
+    return failure(); // No input matched the slice; do not claim progress.
+ }
+};
+
/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
- patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
- GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
- patterns.getContext());
+ patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+ FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+ GenSemiRingSelect, PrintRewriter>(patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir
new file mode 100644
index 00000000000000..5d93301bc8ca76
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_concat_with_extract_slice.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --pre-sparsification-rewrite | FileCheck %s
+
+#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>
+
+
+
+// CHECK-LABEL: func.func @fuse_concat_with_extract(
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>,
+// CHECK-SAME: %[[VAL_2:.*2]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>)
+// CHECK-NOT: tensor.concat
+// CHECK-NOT: tensor.extract_slice
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+// CHECK: }
+func.func @fuse_concat_with_extract(%t0 : tensor<128x32x32x1xf32, #CCCD>,
+ %t1 : tensor<128x32x32x1xf32, #CCCD>,
+ %t2 : tensor<128x32x32x1xf32, #CCCD>) -> (tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>) {
+ %concat = tensor.concat dim(3) %t0, %t1, %t2 : (tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>) -> tensor<128x32x32x3xf32, #CCCD>
+ %r0 = tensor.extract_slice %concat[0, 0, 0, 0] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
+ %r1 = tensor.extract_slice %concat[0, 0, 0, 1] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
+ %r2 = tensor.extract_slice %concat[0, 0, 0, 2] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
+ return %r0, %r1, %r2 : tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>
+}
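The fold only applies when a slice exactly reproduces a single concat operand, hence the "if possible" in the title. As a hedged counterexample, not from the commit and shown on dense tensors for brevity, the pattern leaves the following untouched because the slice sizes match no single input:

  %concat = tensor.concat dim(3) %t0, %t1
      : (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
      -> tensor<128x32x32x2xf32>
  // Not folded: sizes [64, 32, 32, 1] differ from every input's
  // [128, 32, 32, 1], so the pattern reports failure.
  %r = tensor.extract_slice %concat[0, 0, 0, 0] [64, 32, 32, 1] [1, 1, 1, 1]
      : tensor<128x32x32x2xf32> to tensor<64x32x32x1xf32>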