[Mlir-commits] [mlir] [mlir][Tensor] Move concat operation decomposition as a method of the concat operation. (PR #116004)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Nov 13 12:30:23 PST 2024
================
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
return success();
}
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+ size_t numInputs = getInputs().size();
+ uint64_t concatDim = getDim();
+
+ SmallVector<SmallVector<OpFoldResult>> inputShapes;
+ inputShapes.reserve(numInputs);
+ SmallVector<OpFoldResult> concatOffsets;
+ concatOffsets.reserve(numInputs);
+ SmallVector<OpFoldResult> outputShape;
+
+ AffineExpr addExpr =
+ builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+ OpFoldResult zero = builder.getIndexAttr(0);
+ Location loc = getLoc();
+ for (auto [index, input] : llvm::enumerate(getInputs())) {
+ SmallVector<OpFoldResult> inputShape =
+ tensor::getMixedSizes(builder, input.getLoc(), input);
+ if (index == 0) {
+ outputShape = inputShape;
+ concatOffsets.push_back(zero);
+ } else {
+ concatOffsets.push_back(outputShape[concatDim]);
+ outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+ builder, loc, addExpr,
+ {outputShape[concatDim], inputShape[concatDim]});
+ }
+ inputShapes.emplace_back(std::move(inputShape));
+ }
+
+ Value replacement = builder.create<tensor::EmptyOp>(
+ loc, outputShape, getType().getElementType());
+
+ int64_t rank = getType().getRank();
+ OpFoldResult one = builder.getIndexAttr(1);
+ SmallVector<OpFoldResult> strides(rank, one);
+ SmallVector<OpFoldResult> offsets(rank, zero);
+ for (auto [index, input] : llvm::enumerate(getInputs())) {
+ offsets[concatDim] = concatOffsets[index];
+ auto insertSlice = builder.create<tensor::InsertSliceOp>(
+ loc, input, replacement, offsets, inputShapes[index], strides);
+ replacement = insertSlice.getResult();
+ }
+ if (replacement.getType() != getType()) {
+ replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+ }
----------------
hanhanW wrote:
I think the main difference is that `getOrCreateDestination` "infers" the static shape when possible, so the destination already has the result type and we can get rid of the `tensor.cast` ops if we use that method. Why do you prefer the current way?
https://github.com/llvm/llvm-project/pull/116004
More information about the Mlir-commits
mailing list