[Mlir-commits] [mlir] [MLIR][Tensor] Canonicalize fully covering slice insertions into tensors with unit prefixes (PR #92912)
Andi Drebes
llvmlistbot at llvm.org
Tue May 21 06:00:43 PDT 2024
https://github.com/andidr created https://github.com/llvm/llvm-project/pull/92912
If the destination tensor of the insertion of a slice has the same number of elements as the slice, but with a shape that only differs by a prefix of unit-sized dimensions, and if the insertion happens at zero offsets, unit strides and with a size matching the size of the destination, the insertion covers all elements of the destination. The result of such an insertion is equivalent to the slice, with its shape expanded to the type of the destination.
Example:
```mlir
%0 = tensor.insert_slice %slice into
%x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```
folds into:
```mlir
%0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```
This PR adds a canonicalization pattern for `InsertSliceOp` that implements this folding.
>From 6b3a3e109f738c9199a483e4b4d457797a4d2ef2 Mon Sep 17 00:00:00 2001
From: Andi Drebes <andi at drebesium.org>
Date: Tue, 21 May 2024 14:42:39 +0200
Subject: [PATCH] [MLIR][Tensor] Canonicalize fully covering slice insertions
into tensors with unit prefixes
If the destination tensor of the insertion of a slice has the same
number of elements as the slice, but with a shape that only differs by
a prefix of unit-sized dimensions, and if the insertion happens at
zero offsets, unit strides and with a size matching the size of the
destination, the insertion covers all elements of the destination. The
result of such an insertion is equivalent to the slice, with its shape
expanded to the type of the destination.
Example:
```mlir
%0 = tensor.insert_slice %slice into
%x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```
folds into:
```mlir
%0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```
This commit adds a canonicalization pattern for `InsertSliceOp` that
implements this folding.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 88 +++++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 12 +++
2 files changed, 99 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8545c7b9af8f7..52d7005470232 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2835,6 +2835,91 @@ struct InsertSliceOpSourceCastInserter final
return success();
}
};
+
+/// If the destination tensor of the insertion of a slice has the same
+/// number of elements as the slice, but with a shape that only
+/// differs by a prefix of unit-sized dimensions, and if the insertion
+/// happens at zero offsets, unit strides and with a size matching the
+/// size of the destination, the insertion covers all elements of the
+/// destination. The result of such an insertion is equivalent to the
+/// slice, with its shape expanded to the type of the destination.
+///
+/// Example:
+/// ```mlir
+/// %0 = tensor.insert_slice %slice into
+/// %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
+/// tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
+/// tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
+/// ```
+struct InsertSliceOpFullRewriteCanonicalizer final
+ : public OpRewritePattern<InsertSliceOp> {
+ using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ RankedTensorType sourceType = insertSliceOp.getSourceType();
+ RankedTensorType resultType = insertSliceOp.getType();
+
+ if (sourceType != resultType && sourceType.hasStaticShape() &&
+ resultType.hasStaticShape() &&
+ isSameSizedSuffixShape(resultType.getShape(), sourceType.getShape()) &&
+ succeeded(foldIdentityOffsetSizeAndStrideOpInterface(insertSliceOp,
+ resultType))) {
+ SmallVector<ReassociationIndices> reassocIndices;
+
+ // Number of leading dimensions with unit size that are not
+ // shared with the source type
+ size_t unitPrefixLength =
+ resultType.getShape().size() - sourceType.getShape().size();
+
+ // Compose mapping of leading dimensions with unit size and the
+ // first common dimension to the first dimension of the source
+ // tensor
+ ReassociationIndices unitPrefixExpansion;
+
+ size_t dim;
+ for (dim = 0; dim < unitPrefixLength; dim++)
+ unitPrefixExpansion.push_back(dim);
+
+ unitPrefixExpansion.push_back(unitPrefixLength);
+ reassocIndices.push_back(unitPrefixExpansion);
+
+ // Map remaining common dimensions of the source to the target
+ for (dim = dim + 1; dim < resultType.getShape().size(); dim++) {
+ reassocIndices.push_back({static_cast<int64_t>(dim)});
+ }
+
+ rewriter.replaceOpWithNewOp<ExpandShapeOp>(
+ insertSliceOp, insertSliceOp.getType(), insertSliceOp.getSource(),
+ reassocIndices);
+
+ return mlir::success();
+ }
+
+ return mlir::failure();
+ }
+
+private:
+ /// Checks if `suffix` is a suffix of `shape` and all preceding
+ /// elements in `shape` are ones.
+ static bool isSameSizedSuffixShape(ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> suffix) {
+ if (shape.size() >= suffix.size()) {
+ ArrayRef<int64_t> prefix = shape.take_front(shape.size() - suffix.size());
+ ArrayRef<int64_t> remainder = shape.take_back(suffix.size());
+
+ return llvm::all_of(prefix, [](int64_t d) { return d == 1; }) &&
+ remainder == suffix;
+ }
+
+ return false;
+ }
+};
} // namespace
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
@@ -2845,7 +2930,8 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
InsertSliceOpCastFolder<InsertSliceOp>,
- InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
+ InsertSliceOpSourceCastInserter<InsertSliceOp>,
+ InsertSliceOpFullRewriteCanonicalizer>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 914e5e8b8c4b8..8e66ef9f89c74 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
// -----
+// CHECK-LABEL: func @trivial_insert_slice_unit_prefix
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: tensor.insert_slice
+// CHECK: %[[EXPANDED:.[a-z0-9A-Z_]+]] = tensor.expand_shape %[[ARG0]] {{\[\[0, 1, 2, 3\], \[4\], \[5\], \[6\]\] output}}_shape {{\[1, 1, 1, 4, 6, 16, 32\]}} : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
+// CHECK: return %[[EXPANDED]] : tensor<1x1x1x4x6x16x32xi8>
+func.func @trivial_insert_slice_unit_prefix(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<1x1x1x4x6x16x32xi8>) -> tensor<1x1x1x4x6x16x32xi8> {
+ %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 4, 6, 16, 32] [1, 1, 1, 1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
+ return %0 : tensor<1x1x1x4x6x16x32xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @empty_insert_slice
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
More information about the Mlir-commits
mailing list