[Mlir-commits] [mlir] [mlir][tensor] Add shape inference support for `tensor.concat` op. (PR #140168)

Aaron St George llvmlistbot at llvm.org
Fri May 16 14:00:59 PDT 2025


https://github.com/AaronStGeorge updated https://github.com/llvm/llvm-project/pull/140168

>From ee357701917540a995152f2332e69de96bb53615 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Thu, 15 May 2025 23:45:27 +0000
Subject: [PATCH 1/3] `tensor.concat` cast propagation

Adds canonicalization patterns which propagate inferred static shapes to
`tensor.concat` operands and result types. Static shape information is
propagated to other canonicalization patterns through casts.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 112 ++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir |  26 +++++
 2 files changed, 135 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 815806f06b472..d66fc7ad8a905 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result and ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///        tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        refinedTypes.push_back({operandIdx, inferredOperandType});
+      }
+    }
+
+    if (refinedTypes.empty()) {
+      return failure();
+    }
+
+    // Use refined types for operands, insert casts for original type.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///   tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///        -> tensor<?x12xi32>
+///   %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return llvm::success();
+  }
+};
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..cdcd7f305d2d9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 
 // -----
 
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  return %0 : tensor<?x12xi32>
+  // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+  // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index

>From 0bcd656cc5fd3f9b4b2e17717c2f19d36dad0ca7 Mon Sep 17 00:00:00 2001
From: Aaron St George <aaronstgeorge at gmail.com>
Date: Fri, 16 May 2025 13:03:48 -0700
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d66fc7ad8a905..2022551a8e22f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -806,7 +806,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
 
     int64_t dim = concatOp.getDim();
     ArrayRef<int64_t> inferredResultShape =
-        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
 
     // Find operands for which a more static shape can be inferred.
     SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
@@ -861,7 +861,7 @@ struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
                                 PatternRewriter &rewriter) const override {
     int64_t dim = concatOp.getDim();
     RankedTensorType inferredResultType =
-        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 
     // The result type should be at least as static as inferred result type.
     if (preservesStaticInformation(inferredResultType,
@@ -874,7 +874,7 @@ struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
     rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                         newConcatOp);
 
-    return llvm::success();
+    return success();
   }
 };
 } // namespace

>From b0d1405e6267abb35ce55259e75c1b4c171b0001 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Fri, 16 May 2025 20:49:23 +0000
Subject: [PATCH 3/3] Code review comment regarding setOperand

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 29 +++++++++++-------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2022551a8e22f..4a68e943fe3ce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -33,6 +33,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
@@ -809,7 +810,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
         ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
 
     // Find operands for which a more static shape can be inferred.
-    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    LogicalResult matched = failure();
     for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
       // Compute inferred type for operand.
       SmallVector<int64_t> inferredOperandShape(inferredResultShape);
@@ -819,24 +820,20 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
 
       // Check if inferred type is more static.
       if (!preservesStaticInformation(inferredOperandType, operandType)) {
-        refinedTypes.push_back({operandIdx, inferredOperandType});
+        matched = success();
+
+        // Use refined operand type and create cast from original operand.
+        auto castOp =
+            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+                                    concatOp.getOperand(operandIdx));
+        rewriter.modifyOpInPlace(
+            concatOp, [=, operandIdx = (size_t)operandIdx] {
+              concatOp->setOperand(operandIdx, castOp->getResult(0));
+            });
       }
     }
 
-    if (refinedTypes.empty()) {
-      return failure();
-    }
-
-    // Use refined types for operands, insert casts for original type.
-    SmallVector<Value> newOperands = concatOp.getOperands();
-    for (auto [operandIdx, refinedType] : refinedTypes) {
-      newOperands[operandIdx] = rewriter.create<CastOp>(
-          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
-    }
-    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
-                                          dim, newOperands);
-
-    return success();
+    return matched;
   }
 };
 



More information about the Mlir-commits mailing list