[Mlir-commits] [mlir] [mlir][tensor][NFC] Code cleanup around shape inference support for `tensor.concat` op (PR #140616)
Aaron St George
llvmlistbot at llvm.org
Mon May 19 13:45:30 PDT 2025
https://github.com/AaronStGeorge created https://github.com/llvm/llvm-project/pull/140616
Addresses code-review feedback on https://github.com/llvm/llvm-project/pull/140168 that arrived after that PR was merged.
From 7c09a8f2db4cca403c7cc279b9622906ed0d6626 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Mon, 19 May 2025 20:31:44 +0000
Subject: [PATCH] Address CR
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 21 ++++++++++-----------
1 file changed, 10 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c32476d8656f..9a0d5d7e16960 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
- auto operandTensorTypes =
- llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
- return llvm::cast<RankedTensorType>(type);
- });
-
int64_t dim = concatOp.getDim();
- ArrayRef<int64_t> inferredResultShape =
- ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+ RankedTensorType inferredResultType =
+ ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
// Find operands for which a more static shape can be inferred.
LogicalResult matched = failure();
- for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+ // Inferred operand shapes are identical in every dimension except the
+ // concatenation dimension.
+ SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+ for (auto [operandIdx, operandType] :
+ llvm::enumerate(concatOp->getOperandTypes())) {
// Compute inferred type for operand.
- SmallVector<int64_t> inferredOperandShape(inferredResultShape);
- inferredOperandShape[dim] = operandType.getDimSize(dim);
+ inferredOperandShape[dim] =
+ cast<RankedTensorType>(operandType).getDimSize(dim);
auto inferredOperandType = RankedTensorType::get(
- inferredOperandShape, operandType.getElementType());
+ inferredOperandShape, inferredResultType.getElementType());
// Check if inferred type is more static.
if (!preservesStaticInformation(inferredOperandType, operandType)) {
More information about the Mlir-commits
mailing list