[Mlir-commits] [mlir] a0a55df - [mlir][tensor][NFC] Code cleanup around shape inference support for `tensor.concat` op (#140616)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 19 18:54:16 PDT 2025


Author: Aaron St George
Date: 2025-05-19T18:54:13-07:00
New Revision: a0a55df385a41c5bfa4107f83a5fa62b89d68914

URL: https://github.com/llvm/llvm-project/commit/a0a55df385a41c5bfa4107f83a5fa62b89d68914
DIFF: https://github.com/llvm/llvm-project/commit/a0a55df385a41c5bfa4107f83a5fa62b89d68914.diff

LOG: [mlir][tensor][NFC] Code cleanup around shape inference support for `tensor.concat` op (#140616)

Addresses code review feedback on
https://github.com/llvm/llvm-project/pull/140168 that arrived after it
was merged.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c32476d8656f..9a0d5d7e16960 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    auto operandTensorTypes =
-        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
-          return llvm::cast<RankedTensorType>(type);
-        });
-
     int64_t dim = concatOp.getDim();
-    ArrayRef<int64_t> inferredResultShape =
-        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 
     // Find operands for which a more static shape can be inferred.
     LogicalResult matched = failure();
-    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+    // Inferred operand shapes are identical in every dimension except the
+    // concatenation dimension.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+    for (auto [operandIdx, operandType] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
       // Compute inferred type for operand.
-      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
-      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      inferredOperandShape[dim] =
+          cast<RankedTensorType>(operandType).getDimSize(dim);
       auto inferredOperandType = RankedTensorType::get(
-          inferredOperandShape, operandType.getElementType());
+          inferredOperandShape, inferredResultType.getElementType());
 
       // Check if inferred type is more static.
       if (!preservesStaticInformation(inferredOperandType, operandType)) {


        


More information about the Mlir-commits mailing list