[Mlir-commits] [mlir] [mlir][tensor][NFC] Code cleanup around shape inference support for `tensor.concat` op (PR #140616)

Aaron St George llvmlistbot at llvm.org
Mon May 19 13:45:30 PDT 2025


https://github.com/AaronStGeorge created https://github.com/llvm/llvm-project/pull/140616

Addresses code review feedback on https://github.com/llvm/llvm-project/pull/140168 that arrived after that PR was merged.

From 7c09a8f2db4cca403c7cc279b9622906ed0d6626 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Mon, 19 May 2025 20:31:44 +0000
Subject: [PATCH] Address CR

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c32476d8656f..9a0d5d7e16960 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    auto operandTensorTypes =
-        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
-          return llvm::cast<RankedTensorType>(type);
-        });
-
     int64_t dim = concatOp.getDim();
-    ArrayRef<int64_t> inferredResultShape =
-        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 
     // Find operands for which a more static shape can be inferred.
     LogicalResult matched = failure();
-    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+    // Inferred operand shapes are identical in every dimension except the
+    // concatenation dimension.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+    for (auto [operandIdx, operandType] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
       // Compute inferred type for operand.
-      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
-      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      inferredOperandShape[dim] =
+          cast<RankedTensorType>(operandType).getDimSize(dim);
       auto inferredOperandType = RankedTensorType::get(
-          inferredOperandShape, operandType.getElementType());
+          inferredOperandShape, inferredResultType.getElementType());
 
       // Check if inferred type is more static.
       if (!preservesStaticInformation(inferredOperandType, operandType)) {



More information about the Mlir-commits mailing list