[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.pack static shapes inference. (PR #80848)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Feb 13 00:28:53 PST 2024
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/80848
>From 3511d061d71fa9d3edaac143c7ab87dc89220b41 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 6 Feb 2024 23:23:45 +0800
Subject: [PATCH 1/2] [mlir][tensor] Add support for tensor.pack static shapes
inference.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 60 ++++++++++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 39 ++++++++++++++
2 files changed, 99 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a5713..737f897fd4fd41 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
op.getMixedTiles());
}
+// Returns true if the `srcShape` or `destShape` is different from the one in
+// `packOp`.
+static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
+ SmallVectorImpl<int64_t> &destShape) {
+ bool changeNeeded = false;
+ srcShape.assign(packOp.getSourceType().getShape().begin(),
+ packOp.getSourceType().getShape().end());
+ destShape.assign(packOp.getDestType().getShape().begin(),
+ packOp.getDestType().getShape().end());
+ llvm::SmallSetVector<int64_t, 4> innerDims;
+ innerDims.insert(packOp.getInnerDimsPos().begin(),
+ packOp.getInnerDimsPos().end());
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+ int srcRank = packOp.getSourceRank();
+ for (auto i : llvm::seq<int64_t>(0, srcRank)) {
+ if (innerDims.contains(i))
+ continue;
+ int64_t srcPos = i;
+ int64_t destPos = i;
+ if (!outerDimsPerm.empty())
+ destPos = outerDimsPerm[srcPos];
+ if (ShapedType::isDynamic(srcShape[srcPos]) ==
+ ShapedType::isDynamic(destShape[destPos])) {
+ continue;
+ }
+ int64_t size = srcShape[srcPos];
+ if (ShapedType::isDynamic(size))
+ size = destShape[destPos];
+ srcShape[srcPos] = size;
+ destShape[destPos] = size;
+ changeNeeded = true;
+ }
+ return changeNeeded;
+}
+
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Fold an unpack(pack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.finalizeOpModification(packOp);
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(packOp, srcShape, destShape)) {
+ Location loc = packOp.getLoc();
+ Value source = packOp.getSource();
+ if (srcShape != packOp.getSourceType().getShape()) {
+ auto newSrcType = packOp.getSourceType().clone(srcShape);
+ source =
+ rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ }
+ Value dest = packOp.getDest();
+ if (destShape != packOp.getDestType().getShape()) {
+ auto newDestType = packOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::PackOp>(
+ loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+ packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ packOp, packOp.getResult().getType(), newOp);
+ return success();
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7192a719ceb13d..a8e08241d28c06 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -791,6 +791,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312
// -----
+func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %src
+ padding_value(%cst : f32)
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
+ return %pack : tensor<10x20x30x40x16xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_pack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
+// CHECK: return %[[PACK]]
+
+// -----
+
+func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %src
+ padding_value(%cst : f32)
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
+ return %pack : tensor<?x?x?x?x16xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_pack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK: return %[[CAST_PACK]]
+
+// -----
+
func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>
>From c1c91c7b72defc1959bdcd280638d96ec6177007 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 13 Feb 2024 16:28:38 +0800
Subject: [PATCH 2/2] update comments
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 737f897fd4fd41..28d302196de9f5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,8 +3983,8 @@ static bool paddingIsNotNeeded(PackOp op) {
op.getMixedTiles());
}
-// Returns true if the `srcShape` or `destShape` is different from the one in
-// `packOp`.
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
SmallVectorImpl<int64_t> &destShape) {
bool changeNeeded = false;
More information about the Mlir-commits
mailing list