[Mlir-commits] [mlir] [mlir][Tensor] Move concat operation decomposition as a method of the concat operation. (PR #116004)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Wed Nov 13 13:47:23 PST 2024
    
    
  
================
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
   return success();
 }
 
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+  size_t numInputs = getInputs().size();
+  uint64_t concatDim = getDim();
+
+  SmallVector<SmallVector<OpFoldResult>> inputShapes;
+  inputShapes.reserve(numInputs);
+  SmallVector<OpFoldResult> concatOffsets;
+  concatOffsets.reserve(numInputs);
+  SmallVector<OpFoldResult> outputShape;
+
+  AffineExpr addExpr =
+      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+  OpFoldResult zero = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    SmallVector<OpFoldResult> inputShape =
+        tensor::getMixedSizes(builder, input.getLoc(), input);
+    if (index == 0) {
+      outputShape = inputShape;
+      concatOffsets.push_back(zero);
+    } else {
+      concatOffsets.push_back(outputShape[concatDim]);
+      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+          builder, loc, addExpr,
+          {outputShape[concatDim], inputShape[concatDim]});
+    }
+    inputShapes.emplace_back(std::move(inputShape));
+  }
+
+  Value replacement = builder.create<tensor::EmptyOp>(
+      loc, outputShape, getType().getElementType());
+
+  int64_t rank = getType().getRank();
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> offsets(rank, zero);
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    offsets[concatDim] = concatOffsets[index];
+    auto insertSlice = builder.create<tensor::InsertSliceOp>(
+        loc, input, replacement, offsets, inputShapes[index], strides);
+    replacement = insertSlice.getResult();
+  }
+  if (replacement.getType() != getType()) {
+    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+  }
----------------
MaheshRavishankar wrote:
(Sorry, my phone was misbehaving and I was AFK.) That's exactly right. `reifyResultShapes` would create a lot of operations. We need proper cast propagation to resolve the static information that is outside of this decomposition.
https://github.com/llvm/llvm-project/pull/116004
    
    
More information about the Mlir-commits
mailing list