[Mlir-commits] [mlir] da944e0 - [mlir][tensor] Add shape inference support for `tensor.concat` op. (#140168)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 16 15:06:45 PDT 2025
Author: Aaron St George
Date: 2025-05-16T15:06:42-07:00
New Revision: da944e009955982927759c2f1fd47d43b236cc51
URL: https://github.com/llvm/llvm-project/commit/da944e009955982927759c2f1fd47d43b236cc51
DIFF: https://github.com/llvm/llvm-project/commit/da944e009955982927759c2f1fd47d43b236cc51.diff
LOG: [mlir][tensor] Add shape inference support for `tensor.concat` op. (#140168)
## description
`tensor.concat` requires operands and the result to match on all
dimensions except the concatenation dimension. If one operand is already
static in those dimensions, the other operands and result type may
safely be refined to that same static shape. This PR adds
canonicalization patterns to refine `tensor.concat` types and propagate
static shapes to other canonicalization patterns through casts.
## example
```mlir
%2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
```
becomes:
```mlir
%cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
%2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>,
tensor<?x12xi32>) -> tensor<?x12xi32>
```
---------
Co-authored-by: Ian Wood <ianwood2024 at u.northwestern.edu>
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 815806f06b472..6c32476d8656f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
/// 1. source and result and ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
return success();
}
};
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// For every operand that carries less static information than the shape
+/// inferred from all operands (keeping the operand's own size along the
+/// concatenation dimension), a `tensor.cast` to the refined type is created and
+/// the concat is updated in place to consume it. Returns failure when no
+/// operand could be refined.
+///
+/// Example:
+///
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///   tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+/// %2 = tensor.concat dim(0) %0, %cast :
+///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    // Hold the inferred type by value; its shape seeds the per-operand shape.
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // Inferred operand shapes agree with the inferred result shape in every
+    // dimension except `dim`, so build the vector once and only overwrite the
+    // `dim` entry for each operand.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+
+    // Find operands for which a more static shape can be inferred.
+    LogicalResult matched = failure();
+    for (auto [operandIdx, type] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
+      auto operandType = llvm::cast<RankedTensorType>(type);
+
+      // The concatenation dimension is never refined: keep the operand's size.
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Skip operands that already carry all of the inferred static sizes.
+      if (preservesStaticInformation(inferredOperandType, operandType))
+        continue;
+      matched = success();
+
+      // Use refined operand type and create cast from original operand.
+      auto castOp =
+          rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+                                  concatOp.getOperand(operandIdx));
+      // C++17 lambdas cannot capture structured bindings directly, hence the
+      // init-capture of `operandIdx`.
+      rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
+        concatOp->setOperand(operandIdx, castOp->getResult(0));
+      });
+    }
+
+    return matched;
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// If the operands determine static sizes that the declared result type leaves
+/// dynamic, the op is recreated with a refined result type, and a `tensor.cast`
+/// back to the declared type replaces the original result, so existing users
+/// still see the declared type. Static sizes present only on the declared
+/// result type are kept on the refined type as well.
+///
+/// Example:
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///   tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %concat = tensor.concat dim(0) %0, %1 :
+///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType resultType = concatOp.getResultType();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as inferred result type.
+    if (preservesStaticInformation(inferredResultType, resultType))
+      return failure();
+
+    // Join the inferred and declared shapes: wherever the inferred size is
+    // dynamic, take the declared size, so static sizes known only from the
+    // declared type are not dropped from the new op. This assumes the op
+    // verifier rejects declared sizes that conflict with the inferred ones
+    // (NOTE(review): confirm against ConcatOp::verify).
+    SmallVector<int64_t> refinedShape(inferredResultType.getShape());
+    for (auto [refinedSize, declaredSize] :
+         llvm::zip_equal(refinedShape, resultType.getShape())) {
+      if (ShapedType::isDynamic(refinedSize))
+        refinedSize = declaredSize;
+    }
+    auto refinedResultType =
+        RankedTensorType::get(refinedShape, inferredResultType.getElementType());
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), refinedResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, resultType, newConcatOp);
+
+    return success();
+  }
+};
} // namespace
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleInputConcatOp>(context);
+ // Alongside the pre-existing SingleInputConcatOp, register the two patterns
+ // that refine the operand types (InferConcatOperandTypes) and the result type
+ // (InferConcatResultType) of `tensor.concat` toward statically known shapes.
+ results
+ .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..cdcd7f305d2d9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
// -----
+// %arg1 is dynamic in dim 1 while %arg0 is static (12) there, so the operand
+// refinement pattern is expected to cast %arg1 to tensor<?x12xi32> and feed the
+// cast into the concat.
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+ // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+ %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+ // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+ return %0 : tensor<?x12xi32>
+ // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// Both operands are 12 wide in dim 1 but the declared result is
+// tensor<?x?xi32>: the result refinement pattern is expected to produce a
+// concat typed tensor<?x12xi32> (dim 0 stays dynamic because %arg1 is dynamic
+// along the concatenation dimension) followed by a tensor.cast back to the
+// declared type.
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+ %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+ // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+ // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+ // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list