[Mlir-commits] [mlir] a0a55df - [mlir][tensor][NFC] Code cleanup around shape inference support for `tensor.concat` op (#140616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 19 18:54:16 PDT 2025
Author: Aaron St George
Date: 2025-05-19T18:54:13-07:00
New Revision: a0a55df385a41c5bfa4107f83a5fa62b89d68914
URL: https://github.com/llvm/llvm-project/commit/a0a55df385a41c5bfa4107f83a5fa62b89d68914
DIFF: https://github.com/llvm/llvm-project/commit/a0a55df385a41c5bfa4107f83a5fa62b89d68914.diff
LOG: [mlir][tensor][NFC] Code cleanup around shape inference support for `tensor.concat` op (#140616)
Addresses code review feedback on
https://github.com/llvm/llvm-project/pull/140168 that arrived after the PR
was merged.
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c32476d8656f..9a0d5d7e16960 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
- auto operandTensorTypes =
- llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
- return llvm::cast<RankedTensorType>(type);
- });
-
int64_t dim = concatOp.getDim();
- ArrayRef<int64_t> inferredResultShape =
- ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+ RankedTensorType inferredResultType =
+ ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
// Find operands for which a more static shape can be inferred.
LogicalResult matched = failure();
- for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+ // Inferred operand shapes are identical in every dimension except the
+ // concatenation dimension.
+ SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+ for (auto [operandIdx, operandType] :
+ llvm::enumerate(concatOp->getOperandTypes())) {
// Compute inferred type for operand.
- SmallVector<int64_t> inferredOperandShape(inferredResultShape);
- inferredOperandShape[dim] = operandType.getDimSize(dim);
+ inferredOperandShape[dim] =
+ cast<RankedTensorType>(operandType).getDimSize(dim);
auto inferredOperandType = RankedTensorType::get(
- inferredOperandShape, operandType.getElementType());
+ inferredOperandShape, inferredResultType.getElementType());
// Check if inferred type is more static.
if (!preservesStaticInformation(inferredOperandType, operandType)) {
More information about the Mlir-commits
mailing list