[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.unpack static shapes inference. (PR #81702)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Feb 13 20:25:58 PST 2024
https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/81702
The revision does not refactor the inferStaticShape for pack and unpack ops because they can diverge quickly: more dimensions can be inferred (i.e., with inner_tile_sizes) if the pack op does not have a padding value.
This is a follow-up of https://github.com/llvm/llvm-project/pull/80848
>From 9f47c24b4b5b7512376114d419edbf9ad0ee8b21 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 14 Feb 2024 12:20:36 +0800
Subject: [PATCH] [mlir][tensor] Add support for tensor.unpack static shapes
inference.
The revision does not refactor the inferStaticShape for pack and unpack
ops because they can diverge quickly: more dimensions can be inferred
(i.e., with inner_tile_sizes) if the pack op does not have a padding
value.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 59 ++++++++++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 50 ++++++++++++++++++
2 files changed, 109 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb72cba96ad935..1df15c0372e6e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4212,6 +4212,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
metadata.outerDimsPerm);
}
+/// Populates `srcShape`/`destShape` from `op`, then for each dest dim not in inner_dims_pos copies a static size across when exactly one side is dynamic.
+/// Returns true if any size in `srcShape` or `destShape` was changed from the one in `op`.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+ SmallVectorImpl<int64_t> &destShape) {
+ bool changeNeeded = false;
+ srcShape.assign(op.getSourceType().getShape().begin(),
+ op.getSourceType().getShape().end());
+ destShape.assign(op.getDestType().getShape().begin(),
+ op.getDestType().getShape().end());
+ llvm::SmallSetVector<int64_t, 4> innerDims;
+ innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+ auto outerDimsPerm = op.getOuterDimsPerm();
+ int destRank = op.getDestRank();
+ for (auto i : llvm::seq<int64_t>(0, destRank)) {
+ if (innerDims.contains(i))
+ continue; // dims listed in inner_dims_pos are not inferred here
+ int64_t srcPos = i;
+ int64_t destPos = i;
+ if (!outerDimsPerm.empty())
+ srcPos = outerDimsPerm[destPos]; // NOTE(review): indexes outer_dims_perm by the dest position; confirm the inverse permutation is not what is needed here
+ if (ShapedType::isDynamic(srcShape[srcPos]) ==
+ ShapedType::isDynamic(destShape[destPos])) {
+ continue; // both static or both dynamic: nothing to propagate
+ }
+ int64_t size = srcShape[srcPos];
+ if (ShapedType::isDynamic(size))
+ size = destShape[destPos]; // source size is dynamic, so take the static dest size
+ srcShape[srcPos] = size;
+ destShape[destPos] = size;
+ changeNeeded = true;
+ }
+ return changeNeeded;
+}
+
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
/// pack(unpack(x)) -> x
@@ -4234,6 +4268,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(unPackOp, srcShape, destShape)) {
+ Location loc = unPackOp.getLoc();
+ Value source = unPackOp.getSource();
+ if (srcShape != unPackOp.getSourceType().getShape()) {
+ auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+ source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+ unPackOp.getSource());
+ }
+ Value dest = unPackOp.getDest();
+ if (destShape != unPackOp.getDestType().getShape()) {
+ auto newDestType = unPackOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::UnPackOp>(
+ loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+ unPackOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ unPackOp, unPackOp.getResult().getType(), newOp);
+ return success();
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..35619d098f008a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
// -----
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+ return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK: return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+ return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,18 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
} : tensor<?x8xi32>
return %tensor : tensor<?x8xi32>
}
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
+ %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
+ %tile2: index) -> tensor<10x20x?x?xf32> {
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
+ %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
+ return %packed : tensor<10x20x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK: return %[[SRC]]
More information about the Mlir-commits
mailing list