[Mlir-commits] [mlir] [mlir][tensor::ConcatOp] `tensor.concat` cast propagation (PR #140168)

Ian Wood llvmlistbot at llvm.org
Fri May 16 10:36:21 PDT 2025


================
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape. Only the shape is refined: each operand keeps its own element type
+/// and encoding, so operands carrying an encoding are not cast needlessly.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///        tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // All operands share the inferred result shape in every dimension except
+    // `dim`, so build that shape once and only patch `dim` per operand.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, type] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
+      auto operandType = llvm::cast<RankedTensorType>(type);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+
+      // Keep the operand's own element type and encoding; only the shape is
+      // being refined. `preservesStaticInformation` requires matching
+      // encodings, so dropping the encoding here would flag (and cast) even
+      // operands that are already fully refined.
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType(),
+          operandType.getEncoding());
+
+      // Only refine operands whose current type is strictly less static.
+      if (!preservesStaticInformation(inferredOperandType, operandType))
+        refinedTypes.emplace_back(operandIdx, inferredOperandType);
+    }
+
+    if (refinedTypes.empty())
+      return failure();
+
+    // Use refined types for operands, inserting casts from the original types.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///        tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %concat = tensor.concat dim(0) %0, %1 :
+///        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+///   %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return llvm::success();
----------------
IanWood1 wrote:

```suggestion
    return success();
```

https://github.com/llvm/llvm-project/pull/140168


More information about the Mlir-commits mailing list