[Mlir-commits] [mlir] [mlir][tensor::ConcatOp] `tensor.concat` cast propagation (PR #140168)
Aaron St George
llvmlistbot at llvm.org
Thu May 15 17:53:20 PDT 2025
https://github.com/AaronStGeorge created https://github.com/llvm/llvm-project/pull/140168
## description
`tensor.concat` requires operands and the result to match on all dimensions except the concatenation dimension. If one operand is already static in those dimensions, the other operands and the result type may safely be refined to that same static shape. This PR adds canonicalization patterns that refine `tensor.concat` operand and result types. They insert `tensor.cast` ops so that the refined static shapes can be picked up by other canonicalization patterns.
## example
```mlir
// Second operand dim 1 has dynamic shape constrained by dim 1 of first
// operand.
%2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
```
becomes:
```mlir
%cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
%2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
```
>From 27233afecf4804b8418776df375227087bb7bbe8 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Thu, 15 May 2025 23:45:27 +0000
Subject: [PATCH] `tensor.concat` cast propagation
Adds canonicalization patterns which propagate inferred static shapes to
`tensor.concat` operands and result types. Static shape information is
propagated to other canonicalization patterns through casts.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 114 ++++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 26 +++++
2 files changed, 137 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 815806f06b472..633e502db5a3a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
/// 1. source and result and ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
@@ -773,11 +774,118 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
return success();
}
};
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape. Each refined operand is fed through a `tensor.cast` from its original
+/// type, and the op is rebuilt with its original result type. The pattern
+/// fails without creating any IR when no operand can be refined.
+///
+/// Example:
+///
+/// ```mlir
+/// // Second operand dim 1 has dynamic shape constrained by dim 1 of first
+/// // operand.
+/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x?xi32>)
+///     -> tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
+///     -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // All operands share the inferred shape in every non-concatenation
+    // dimension, so build the buffer once and only patch `dim` per operand.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    bool refinedAny = false;
+    for (auto [operandIdx, type] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
+      auto operandType = llvm::cast<RankedTensorType>(type);
+      // Keep the operand's own size along the concatenation dimension.
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Skip operands that are already at least as static as inferred.
+      if (preservesStaticInformation(inferredOperandType, operandType))
+        continue;
+
+      // Use the refined type, casting from the original operand value.
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), inferredOperandType,
+          concatOp.getOperand(operandIdx));
+      refinedAny = true;
+    }
+
+    if (!refinedAny)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// When the inferred result type is more static than the declared one, the op
+/// is rebuilt with the inferred result type. A `tensor.cast` back to the
+/// declared result type then replaces the old op, so existing users keep
+/// seeing the type they were built against.
+///
+/// Example:
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///     -> tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %c = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///     -> tensor<?x12xi32>
+/// %2 = tensor.cast %c : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
} // namespace
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleInputConcatOp>(context);
+  // Register the single-input fold together with the two shape-refinement
+  // patterns (operand types and result type).
+  results.add<SingleInputConcatOp, InferConcatOperandTypes,
+              InferConcatResultType>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..cdcd7f305d2d9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
// -----
+// Operand-type refinement: dim 1 of %arg1 is dynamic but must equal dim 1 of
+// %arg0 (12), so %arg1 is cast to tensor<?x12xi32> before the concat and the
+// concat result type is left unchanged.
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+ // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+ %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+ // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+ return %0 : tensor<?x12xi32>
+ // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// Result-type refinement: the type inferred from the operands,
+// tensor<?x12xi32>, is more static than the declared tensor<?x?xi32>. The
+// concat is therefore rebuilt with the inferred type, and a tensor.cast
+// restores the declared type for the return.
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+ %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+ // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+ // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+ // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list