[Mlir-commits] [mlir] [mlir][tensor] Add shape inference support for `tensor.concat` op. (PR #140168)
Aaron St George
llvmlistbot at llvm.org
Fri May 16 13:03:33 PDT 2025
================
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
return success();
}
};
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+/// tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+/// %2 = tensor.concat dim(0) %0, %cast :
+/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+ using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ auto operandTensorTypes =
+ llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+ return llvm::cast<RankedTensorType>(type);
+ });
+
+ int64_t dim = concatOp.getDim();
+ ArrayRef<int64_t> inferredResultShape =
+ concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+ // Find operands for which a more static shape can be inferred.
+ SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+ for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+ // Compute inferred type for operand.
+ SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+ inferredOperandShape[dim] = operandType.getDimSize(dim);
+ auto inferredOperandType = RankedTensorType::get(
+ inferredOperandShape, operandType.getElementType());
+
+ // Check if inferred type is more static.
+ if (!preservesStaticInformation(inferredOperandType, operandType)) {
+ refinedTypes.push_back({operandIdx, inferredOperandType});
+ }
+ }
+
+ if (refinedTypes.empty()) {
+ return failure();
+ }
+
+ // Use refined types for operands, insert casts for original type.
+ SmallVector<Value> newOperands = concatOp.getOperands();
+ for (auto [operandIdx, refinedType] : refinedTypes) {
+ newOperands[operandIdx] = rewriter.create<CastOp>(
+ concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+ }
+ rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+ dim, newOperands);
----------------
AaronStGeorge wrote:
I'll give that a shot. I was under the impression that `concatOp.setOperand()` didn't update the operand types appropriately, but I'm not sure of that. Making the mutation in the first loop will require storing off a `bool` or something to determine whether the pattern successfully matched, but it would be lighter weight than `refinedTypes` and `newOperands`.
https://github.com/llvm/llvm-project/pull/140168
More information about the Mlir-commits
mailing list