[Mlir-commits] [mlir] [mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization (PR #91352)
llvmlistbot at llvm.org
Tue May 7 08:59:31 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir
Author: None (Max191)
Changes:
In some cases this pattern can ignore static information from the tensor.cast source because the insert_slice size operands are dynamic, e.g.:
```
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[...] [%s0, %s1] [...]
: tensor<?x?xf32> into tensor<?x?xf32>
```
This can be rewritten into:
```
%1 = tensor.insert_slice %arg0 into %arg1[...] [1, %s1] [...]
: tensor<1x?xf32> into tensor<?x?xf32>
```
This PR updates the matching in the pattern to allow rewrites like this.
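The substitution the updated pattern performs can be summarized outside of MLIR: walk the insert_slice sizes alongside the cast source shape (skipping any rank-reduced unit dims) and copy over the source's static extents wherever they are known. Below is a minimal standalone sketch of that step, using a hypothetical `kDynamic` sentinel and plain `std::vector` shapes instead of the MLIR types; the actual pattern change is in the diff further down.
```cpp
#include <cstdint>
#include <limits>
#include <set>
#include <vector>

// Hypothetical stand-in for the dynamic-extent sentinel used by shaped types.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Given the insert_slice sizes and the tensor.cast source shape (already
// matched up to rank-reducing unit dims, whose indices are in `droppedDims`),
// copy the source's static extents over the corresponding slice sizes.
std::vector<int64_t> refineSliceSizes(std::vector<int64_t> sliceSizes,
                                      const std::vector<int64_t> &srcShape,
                                      const std::set<unsigned> &droppedDims) {
  size_t srcIdx = 0;
  for (unsigned idx = 0; idx < sliceSizes.size(); ++idx) {
    if (droppedDims.count(idx))
      continue; // rank-reduced unit dim, no counterpart in srcShape
    if (srcShape[srcIdx] != kDynamic)
      sliceSizes[idx] = srcShape[srcIdx]; // take the static extent from the source
    ++srcIdx;
  }
  return sliceSizes;
}

int main() {
  // Sizes [%s0, %s1] (both dynamic) with cast source tensor<1x?xf32>:
  // the first size becomes the static 1, the second stays dynamic.
  std::vector<int64_t> refined =
      refineSliceSizes({kDynamic, kDynamic}, {1, kDynamic}, {});
  return refined[0] == 1 ? 0 : 1;
}
```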
---
Full diff: https://github.com/llvm/llvm-project/pull/91352.diff
4 Files Affected:
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+7-1)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+26-3)
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+12-12)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+4-5)
``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 2361cf1371237b..5579b138668d2b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -360,9 +360,15 @@ class VectorType::Builder {
/// which dimensions must be kept when e.g. compute MemRef strides under
/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
/// obtained by dropping only `1` entries in `originalShape`.
+/// If `matchDynamic` is true, then dynamic dims in `originalShape` and
+/// `reducedShape` will be considered matching with non-dynamic dims, unless
+/// the non-dynamic dim is from `originalShape` and equal to 1. For example,
+/// in ([1, 3, ?], [?, 5]), the mask would be {1, 0, 0}, since 3 and 5 will
+/// match with the corresponding dynamic dims.
std::optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
- ArrayRef<int64_t> reducedShape);
+ ArrayRef<int64_t> reducedShape,
+ bool matchDynamic = false);
/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..d560c11464f1c1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2711,15 +2711,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)
return failure();
+
+ // The tensor.cast source could have additional static information not seen
+ // in the insert slice op static sizes, so we ignore dynamic dims when
+ // computing the rank reduction mask.
+ SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
+ auto rankReductionMask = computeRankReductionMask(
+ staticSizes, srcType.getShape(), /*matchDynamic=*/true);
+ if (!rankReductionMask.has_value())
+ return failure();
+ // Replace dimensions in the insert slice op with corresponding static dims
+ // from the cast source type. If the insert slice sizes have static dims
+ // that are not static in the tensor.cast source (i.e., when the cast op
+ // casts a dynamic dim to static), the dim should not be replaced, and the
+ // pattern will fail later in `verifyInsertSliceOp`.
+ SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+ int64_t rankReducedIdx = 0;
+ for (auto [idx, size] : enumerate(staticSizes)) {
+ if (!rankReductionMask.value().contains(idx) &&
+ !srcType.isDynamicDim(rankReducedIdx)) {
+ mixedSizes[idx] = getAsIndexOpFoldResult(
+ rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
+ size = srcType.getDimSize(rankReducedIdx++);
+ }
+ }
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
- insertSliceOp.getStaticSizes(),
- insertSliceOp.getStaticStrides()) !=
+ staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+ mixedSizes, insertSliceOp.getMixedStrides());
// In the parallel case there is no result and so nothing to cast.
bool isParallelInsert =
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a2738946de410e..179797cb943a1a 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -408,24 +408,24 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
// MemRefType
//===----------------------------------------------------------------------===//
-/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the set of indices
-/// that specifies which of the entries of `originalShape` are dropped to obtain
-/// `reducedShape`. The returned mask can be applied as a projection to
-/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
-/// which dimensions must be kept when e.g. compute MemRef strides under
-/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
-/// obtained by dropping only `1` entries in `originalShape`.
std::optional<llvm::SmallDenseSet<unsigned>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
- ArrayRef<int64_t> reducedShape) {
+ ArrayRef<int64_t> reducedShape,
+ bool matchDynamic) {
size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
llvm::SmallDenseSet<unsigned> unusedDims;
unsigned reducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
// Greedily insert `originalIdx` if match.
- if (reducedIdx < reducedRank &&
- originalShape[originalIdx] == reducedShape[reducedIdx]) {
+ int64_t origSize = originalShape[originalIdx];
+ // if `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
+ if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
+ (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
+ ShapedType::isDynamic(origSize))) {
+ reducedIdx++;
+ continue;
+ }
+ if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
reducedIdx++;
continue;
}
@@ -433,7 +433,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
unusedDims.insert(originalIdx);
// If no match on `originalIdx`, the `originalShape` at this dimension
// must be 1, otherwise we bail.
- if (originalShape[originalIdx] != 1)
+ if (origSize != 1)
return std::nullopt;
}
// The whole reducedShape must be scanned, otherwise we bail.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..53c8a65d39e633 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1890,14 +1890,13 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
// -----
-// There was an issue in cast + insert_slice folding generating invalid ir.
-// https://github.com/llvm/llvm-project/issues/53099
// CHECK-LABEL: func @insert_slice_cast
func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
- // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
%0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
- // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
- // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+ // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+ // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+ // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
%1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: return %[[RES]] : tensor<?x?xf32>
return %1 : tensor<?x?xf32>
``````````
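As a self-contained illustration of the `matchDynamic` behavior described by the new `computeRankReductionMask` doc comment, here is a simplified sketch over plain shape vectors with a hypothetical `kDynamic` sentinel; it is not the upstream implementation, only the matching rule it documents.
```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <set>
#include <vector>

// Hypothetical dynamic-extent sentinel for this sketch.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Simplified analogue of computeRankReductionMask: return the indices of
// `originalShape` that must be dropped to obtain `reducedShape`, or nullopt
// if that is impossible by dropping only unit dims. With `matchDynamic`, a
// dynamic extent on either side matches the other side, except that a static
// 1 in `originalShape` never matches a dynamic reduced extent.
std::optional<std::set<unsigned>>
rankReductionMask(const std::vector<int64_t> &originalShape,
                  const std::vector<int64_t> &reducedShape,
                  bool matchDynamic) {
  std::set<unsigned> unusedDims;
  size_t reducedIdx = 0;
  for (unsigned idx = 0; idx < originalShape.size(); ++idx) {
    int64_t origSize = originalShape[idx];
    bool inBounds = reducedIdx < reducedShape.size();
    bool dynMatch = matchDynamic && inBounds && origSize != 1 &&
                    (reducedShape[reducedIdx] == kDynamic ||
                     origSize == kDynamic);
    if (inBounds && (dynMatch || origSize == reducedShape[reducedIdx])) {
      ++reducedIdx;
      continue;
    }
    // No match: the original dim can only be a dropped unit dim.
    if (origSize != 1)
      return std::nullopt;
    unusedDims.insert(idx);
  }
  // Every reduced dim must be consumed.
  if (reducedIdx != reducedShape.size())
    return std::nullopt;
  return unusedDims;
}

int main() {
  // ([1, 3, ?], [?, 5]) from the doc comment: only dim 0 is dropped, since 3
  // matches the dynamic reduced dim and the dynamic original dim matches 5.
  auto mask = rankReductionMask({1, 3, kDynamic}, {kDynamic, 5},
                                /*matchDynamic=*/true);
  return (mask && mask->count(0) && mask->size() == 1) ? 0 : 1;
}
```
For the ([1, 3, ?], [?, 5]) pair, only index 0 is dropped, which corresponds to the {1, 0, 0} indicator mask mentioned in the doc comment.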
https://github.com/llvm/llvm-project/pull/91352