[Mlir-commits] [mlir] [mlir][Tensor] NFC: Move concat operation decomposition as a method of the concat operation. (PR #116004)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 12 22:51:48 PST 2024
https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/116004
Currently, the implementation lives inside a rewrite pattern, so it cannot be used without a pattern rewriter. Move the decomposition into a method of the operation to make it usable outside of pattern rewrites.
>From 6fb6c277f09b8c689bf34f713e49bd733a475734 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Tue, 12 Nov 2024 22:47:11 -0800
Subject: [PATCH] [mlir][Tensor] NFC: Move concat operation decomposition into a
method of the concat operation.
Currently, the implementation lives inside a rewrite pattern, so it cannot
be used without a pattern rewriter. Move the decomposition into a method of
the operation to make it usable outside of pattern rewrites.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 3 ++
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 45 ++++++++++++++++
.../Tensor/Transforms/ConcatOpPatterns.cpp | 53 +++----------------
3 files changed, 54 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..b73da8bb6af59c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
int64_t getRank() {
return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
}
+
+ // Decomposes the operation into a `tensor.empty` followed by a chain of
+ // `tensor.insert_slice` ops (one per input), creating the new ops with
+ // `builder`. The operation itself is left untouched; the returned values are
+ // meant to replace its results, which is the caller's responsibility.
+ FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8e0d0104397468..dd6c7ebf1d0919 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -615,6 +615,51 @@ LogicalResult ConcatOp::verify() {
return success();
}
+/// Decomposes this `tensor.concat` into a `tensor.empty` of the concatenated
+/// shape followed by one `tensor.insert_slice` per input, each placed at the
+/// running offset along the concatenation dimension.
+///
+/// The `tensor.empty` shape is computed from the (possibly folded) sizes of
+/// the inputs: non-concatenated dimensions come from the first input and the
+/// concatenated dimension is the folded sum of all inputs' sizes. That type can
+/// therefore be more or less static than the declared result type, so a
+/// `tensor.cast` to the result type is appended whenever the two differ. This
+/// keeps the returned value directly usable as a replacement for the op.
+///
+/// Returns one replacement value per result of the operation.
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+  size_t numInputs = getInputs().size();
+  uint64_t concatDim = getDim();
+
+  // Mixed sizes of every input; reused as the sizes of the insert_slice ops.
+  SmallVector<SmallVector<OpFoldResult>> inputShapes;
+  inputShapes.reserve(numInputs);
+  // Offset of each input along `concatDim` inside the concatenated tensor.
+  SmallVector<OpFoldResult> concatOffsets;
+  concatOffsets.reserve(numInputs);
+  SmallVector<OpFoldResult> outputShape;
+
+  AffineExpr addExpr =
+      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+  OpFoldResult zero = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    SmallVector<OpFoldResult> inputShape =
+        tensor::getMixedSizes(builder, input.getLoc(), input);
+    if (index == 0) {
+      // Non-concatenated output dimensions are taken from the first input.
+      outputShape = inputShape;
+      concatOffsets.push_back(zero);
+    } else {
+      // This input starts where the previous ones end; then grow the
+      // concatenated dimension by this input's size.
+      concatOffsets.push_back(outputShape[concatDim]);
+      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+          builder, loc, addExpr,
+          {outputShape[concatDim], inputShape[concatDim]});
+    }
+    inputShapes.emplace_back(std::move(inputShape));
+  }
+
+  Value replacement = builder.create<tensor::EmptyOp>(
+      loc, outputShape, getType().getElementType());
+
+  int64_t rank = getType().getRank();
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> offsets(rank, zero);
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    offsets[concatDim] = concatOffsets[index];
+    auto insertSlice = builder.create<tensor::InsertSliceOp>(
+        loc, input, replacement, offsets, inputShapes[index], strides);
+    replacement = insertSlice.getResult();
+  }
+
+  // The type derived from the input sizes may differ in static-ness from the
+  // declared result type (the previous pattern always emitted a cast for this
+  // reason). Cast back so the replacement type matches the op's result type.
+  if (replacement.getType() != getType())
+    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+  return SmallVector<Value>{replacement};
+}
+
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 7c8403c9609d84..a2a860fcb38abb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
- Location loc = concatOp.getLoc();
- FailureOr<Value> dest =
- tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
- if (failed(dest))
- return failure();
-
- auto empty = dest->getDefiningOp<tensor::EmptyOp>();
- if (!empty)
- return failure();
-
- int64_t dim = concatOp.getDim();
- Value dimValue =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-
- int64_t rank = concatOp.getResultType().getRank();
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-
- // Compute the partial sums for the slice offsets.
- AffineExpr sum = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> partialSums = {sum};
- SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
- for (auto [idx, input] :
- llvm::enumerate(concatOp.getInputs().drop_back())) {
- sum = sum + rewriter.getAffineDimExpr(idx + 1);
- partialSums.push_back(sum);
- offsetStrides.push_back(
- rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+    // Build the replacement IR through the rewriter (passed as the OpBuilder)
+    // so that it is notified of every op the decomposition creates.
+    FailureOr<SmallVector<Value>> decomposed =
+        concatOp.decomposeOperation(rewriter);
+    if (failed(decomposed)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "failed to get the decomposed insert slices");
}
- auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
- partialSums, rewriter.getContext());
- SmallVector<OpFoldResult> dimOffsets =
- affine::makeComposedFoldedMultiResultAffineApply(
- rewriter, loc, partialSumMap, offsetStrides);
-
- // Construct the chain of insert_slice ops into the destination.
- Value result = *dest;
- for (auto [input, offset] :
- llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
- SmallVector<OpFoldResult> sizes =
- tensor::getMixedSizes(rewriter, loc, input);
- offsets[dim] = offset;
- result = rewriter.createOrFold<tensor::InsertSliceOp>(
- loc, input, result, offsets, sizes, strides);
- }
-
- rewriter.replaceOpWithNewOp<tensor::CastOp>(
- concatOp, concatOp.getResultType(), result);
+    // Replace all results with the returned range instead of indexing element
+    // 0, so a count mismatch between values and results is caught by
+    // `replaceOp` rather than silently ignored.
+    rewriter.replaceOp(concatOp, *decomposed);
return success();
}
};
More information about the Mlir-commits
mailing list