[Mlir-commits] [mlir] [mlir][Tensor] Move concat operation decomposition as a method of the concat operation. (PR #116004)

Han-Chung Wang llvmlistbot at llvm.org
Wed Nov 13 12:43:03 PST 2024


================
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
   return success();
 }
 
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+  size_t numInputs = getInputs().size();
+  uint64_t concatDim = getDim();
+
+  SmallVector<SmallVector<OpFoldResult>> inputShapes;
+  inputShapes.reserve(numInputs);
+  SmallVector<OpFoldResult> concatOffsets;
+  concatOffsets.reserve(numInputs);
+  SmallVector<OpFoldResult> outputShape;
+
+  AffineExpr addExpr =
+      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+  OpFoldResult zero = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    SmallVector<OpFoldResult> inputShape =
+        tensor::getMixedSizes(builder, input.getLoc(), input);
+    if (index == 0) {
+      outputShape = inputShape;
+      concatOffsets.push_back(zero);
+    } else {
+      concatOffsets.push_back(outputShape[concatDim]);
+      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+          builder, loc, addExpr,
+          {outputShape[concatDim], inputShape[concatDim]});
+    }
+    inputShapes.emplace_back(std::move(inputShape));
+  }
+
+  Value replacement = builder.create<tensor::EmptyOp>(
+      loc, outputShape, getType().getElementType());
+
+  int64_t rank = getType().getRank();
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> offsets(rank, zero);
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    offsets[concatDim] = concatOffsets[index];
+    auto insertSlice = builder.create<tensor::InsertSliceOp>(
+        loc, input, replacement, offsets, inputShapes[index], strides);
+    replacement = insertSlice.getResult();
+  }
+  if (replacement.getType() != getType()) {
+    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+  }
----------------
hanhanW wrote:

I see, good point. Have you considered using `ConcatOp::reifyResultShapes` to create the `outputShape` for the `tensor.empty` op? That approach might create more operations, though, and we would pay that cost. (I don't have a preference; I just want to make sure the idea has been evaluated.)

The current implementation looks okay to me because the root issue is that the op does not infer static shapes when possible. We'll end up with these `tensor.cast` ops even if the shape inference is implemented.
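
To make the shape discussion concrete, here is a minimal standalone sketch of the offset and output-shape accumulation that the first loop in the hunk performs. It is not MLIR code. `Dim`, `Shape`, `ConcatLayout`, `computeConcatLayout`, and `needsCast` are hypothetical names made up for illustration, and a dynamic dimension is modeled as `std::nullopt` rather than as an SSA value produced by an affine apply.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model: a dimension is a known static size, or dynamic (nullopt).
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

struct ConcatLayout {
  std::vector<Dim> offsets; // Offset of each input along the concat dimension.
  Shape outputShape;        // Result shape accumulated from the inputs.
};

// Mirrors the structure of the first loop in the hunk: the first input seeds
// the output shape, and each later input is placed at the running size of the
// concat dimension before that size is extended.
ConcatLayout computeConcatLayout(const std::vector<Shape> &inputs,
                                 std::size_t concatDim) {
  ConcatLayout layout;
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    const Shape &in = inputs[i];
    if (i == 0) {
      layout.outputShape = in;
      layout.offsets.push_back(Dim{0});
      continue;
    }
    Dim &acc = layout.outputShape[concatDim];
    layout.offsets.push_back(acc);
    // Assumption of this model: static + static stays static; anything
    // involving a dynamic size becomes dynamic.
    if (acc && in[concatDim])
      acc = *acc + *in[concatDim];
    else
      acc = std::nullopt;
  }
  return layout;
}

// Models the `replacement.getType() != getType()` check: a cast is needed
// whenever the computed shape differs from the declared result shape.
bool needsCast(const Shape &computed, const Shape &declared) {
  return computed != declared;
}
```

In this model, concatenating a 2x4 input and a 3x4 input along dimension 0 gives offsets 0 and 2 and a computed shape of 5x4. If the declared result shape is ?x4, `needsCast` returns true, which corresponds to the case where the decomposition has to emit a `tensor.cast`.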



https://github.com/llvm/llvm-project/pull/116004

