[Mlir-commits] [mlir] [mlir][tensor] Add shape inference support for `tensor.concat` op. (PR #140168)
Aaron St George
llvmlistbot at llvm.org
Fri May 16 14:39:15 PDT 2025
https://github.com/AaronStGeorge updated https://github.com/llvm/llvm-project/pull/140168
From ee357701917540a995152f2332e69de96bb53615 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Thu, 15 May 2025 23:45:27 +0000
Subject: [PATCH 1/4] `tensor.concat` cast propagation
Adds canonicalization patterns which propagate inferred static shapes to
`tensor.concat` operands and result types. Static information is propagated to
other canonicalization patterns through casts.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 112 ++++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 26 +++++
2 files changed, 135 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 815806f06b472..d66fc7ad8a905 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
/// 1. source and result and ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
return success();
}
};
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+/// tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+/// %2 = tensor.concat dim(0) %0, %cast :
+/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+ using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ auto operandTensorTypes =
+ llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+ return llvm::cast<RankedTensorType>(type);
+ });
+
+ int64_t dim = concatOp.getDim();
+ ArrayRef<int64_t> inferredResultShape =
+ concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+ // Find operands for which a more static shape can be inferred.
+ SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+ for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+ // Compute inferred type for operand.
+ SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+ inferredOperandShape[dim] = operandType.getDimSize(dim);
+ auto inferredOperandType = RankedTensorType::get(
+ inferredOperandShape, operandType.getElementType());
+
+ // Check if inferred type is more static.
+ if (!preservesStaticInformation(inferredOperandType, operandType)) {
+ refinedTypes.push_back({operandIdx, inferredOperandType});
+ }
+ }
+
+ if (refinedTypes.empty()) {
+ return failure();
+ }
+
+ // Use refined types for operands, insert casts for original type.
+ SmallVector<Value> newOperands = concatOp.getOperands();
+ for (auto [operandIdx, refinedType] : refinedTypes) {
+ newOperands[operandIdx] = rewriter.create<CastOp>(
+ concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+ }
+ rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+ dim, newOperands);
+
+ return success();
+ }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+/// tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///        -> tensor<?x12xi32>
+/// %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+ using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ int64_t dim = concatOp.getDim();
+ RankedTensorType inferredResultType =
+ concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+ // The result type should be at least as static as inferred result type.
+ if (preservesStaticInformation(inferredResultType,
+ concatOp.getResultType())) {
+ return failure();
+ }
+
+ auto newConcatOp = rewriter.create<ConcatOp>(
+ concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+ rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+ newConcatOp);
+
+ return llvm::success();
+ }
+};
} // namespace
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleInputConcatOp>(context);
+ results
+ .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..cdcd7f305d2d9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
// -----
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+ // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+ %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+ // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+ return %0 : tensor<?x12xi32>
+ // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+ %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+ // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+ // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+ // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
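
An aside, not part of the patch: the new patterns are picked up wherever
`ConcatOp::getCanonicalizationPatterns` is queried, e.g. by the generic
canonicalizer that the tests above run through. A minimal C++ sketch of
driving them directly is below; the helper name is hypothetical, and the
greedy-driver entry point has been renamed across MLIR revisions, so treat
the exact spelling as an assumption.

```cpp
// Minimal sketch (not part of the patch): collect the tensor.concat
// canonicalization patterns and apply them greedily to an op (e.g. a module).
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper.
LogicalResult runConcatCanonicalization(Operation *op) {
  MLIRContext *ctx = op->getContext();
  RewritePatternSet patterns(ctx);
  // Registers SingleInputConcatOp, InferConcatOperandTypes, and
  // InferConcatResultType (the patterns added in this patch).
  tensor::ConcatOp::getCanonicalizationPatterns(patterns, ctx);
  // The casts introduced by the new patterns fold away with the
  // tensor.cast canonicalizations.
  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
  // Named applyPatternsAndFoldGreedily in older MLIR revisions.
  return applyPatternsGreedily(op, std::move(patterns));
}
```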
From 0bcd656cc5fd3f9b4b2e17717c2f19d36dad0ca7 Mon Sep 17 00:00:00 2001
From: Aaron St George <aaronstgeorge at gmail.com>
Date: Fri, 16 May 2025 13:03:48 -0700
Subject: [PATCH 2/4] Apply suggestions from code review
Co-authored-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d66fc7ad8a905..2022551a8e22f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -806,7 +806,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
int64_t dim = concatOp.getDim();
ArrayRef<int64_t> inferredResultShape =
- concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+ ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
// Find operands for which a more static shape can be inferred.
SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
@@ -861,7 +861,7 @@ struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
PatternRewriter &rewriter) const override {
int64_t dim = concatOp.getDim();
RankedTensorType inferredResultType =
- concatOp.inferResultType(dim, concatOp->getOperandTypes());
+ ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
// The result type should be at least as static as inferred result type.
if (preservesStaticInformation(inferredResultType,
@@ -874,7 +874,7 @@ struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
newConcatOp);
- return llvm::success();
+ return success();
}
};
} // namespace
From b0d1405e6267abb35ce55259e75c1b4c171b0001 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Fri, 16 May 2025 20:49:23 +0000
Subject: [PATCH 3/4] Address code review comment regarding setOperand
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 29 +++++++++++-------------
1 file changed, 13 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2022551a8e22f..4a68e943fe3ce 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
@@ -809,7 +810,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
// Find operands for which a more static shape can be inferred.
- SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+ LogicalResult matched = failure();
for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
// Compute inferred type for operand.
SmallVector<int64_t> inferredOperandShape(inferredResultShape);
@@ -819,24 +820,20 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
// Check if inferred type is more static.
if (!preservesStaticInformation(inferredOperandType, operandType)) {
- refinedTypes.push_back({operandIdx, inferredOperandType});
+ matched = success();
+
+ // Use refined operand type and create cast from original operand.
+ auto castOp =
+ rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+ concatOp.getOperand(operandIdx));
+ rewriter.modifyOpInPlace(
+ concatOp, [=, operandIdx = (size_t)operandIdx] {
+ concatOp->setOperand(operandIdx, castOp->getResult(0));
+ });
}
}
- if (refinedTypes.empty()) {
- return failure();
- }
-
- // Use refined types for operands, insert casts for original type.
- SmallVector<Value> newOperands = concatOp.getOperands();
- for (auto [operandIdx, refinedType] : refinedTypes) {
- newOperands[operandIdx] = rewriter.create<CastOp>(
- concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
- }
- rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
- dim, newOperands);
-
- return success();
+ return matched;
}
};
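
An aside, not part of the patch: when a pattern mutates the matched op rather
than replacing it, the mutation has to be wrapped in `modifyOpInPlace` so the
driver is notified of the change, which is the idiom this revision switches
to. A minimal sketch with hypothetical names:

```cpp
// Minimal sketch (not part of the patch) of the in-place update idiom: any
// mutation of existing IR inside a pattern must go through the rewriter, and
// modifyOpInPlace brackets the change so the greedy driver is notified.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper mirroring the rewrite above.
static void refineOperandInPlace(PatternRewriter &rewriter, Operation *op,
                                 unsigned operandIdx,
                                 RankedTensorType refinedType) {
  // Cast the original operand to the more static type...
  auto castOp = rewriter.create<tensor::CastOp>(op->getLoc(), refinedType,
                                                op->getOperand(operandIdx));
  // ...and splice it in as the new operand, under the rewriter's supervision.
  rewriter.modifyOpInPlace(
      op, [&] { op->setOperand(operandIdx, castOp->getResult(0)); });
}
```

Replacing the op wholesale (as the first revision of this PR did) is equally
valid; the in-place form simply avoids rebuilding the `tensor.concat`.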
From 19fd69108c6fc58723e28e82b1fce1d8b0f4a383 Mon Sep 17 00:00:00 2001
From: AaronStGeorge <aaronstgeorge at gmail.com>
Date: Fri, 16 May 2025 21:38:52 +0000
Subject: [PATCH 4/4] Cast is not necessary
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4a68e943fe3ce..6c32476d8656f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -826,10 +826,9 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
auto castOp =
rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
concatOp.getOperand(operandIdx));
- rewriter.modifyOpInPlace(
- concatOp, [=, operandIdx = (size_t)operandIdx] {
- concatOp->setOperand(operandIdx, castOp->getResult(0));
- });
+ rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
+ concatOp->setOperand(operandIdx, castOp->getResult(0));
+ });
}
}