[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.unpack static shapes inference. (PR #81702)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 13 20:26:32 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Han-Chung Wang (hanhanW)
<details>
<summary>Changes</summary>
The revision does not refactor the inferStaticShape for pack and unpack ops because they can diverge quickly: more dimensions can be inferred (i.e., with inner_tile_sizes) if the pack op does not have a padding value.
This is a follow-up of https://github.com/llvm/llvm-project/pull/80848
---
Full diff: https://github.com/llvm/llvm-project/pull/81702.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+59)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+50)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb72cba96ad935..1df15c0372e6e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4212,6 +4212,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
metadata.outerDimsPerm);
}
+/// Infers static sizes for the dimensions that `op` carries over unchanged
+/// between its source and its destination.
+///
+/// `srcShape` and `destShape` are first populated with the shapes of the
+/// source and destination types of `op`. Then, for every destination dimension
+/// that is not listed in inner_dims_pos, the size is compared with the
+/// corresponding outer dimension of the source; if exactly one of the two is
+/// static, that size is written to the other.
+///
+/// Returns true if `srcShape` or `destShape` is different from the one in
+/// `op`.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+  // Source outer dimension `s` holds destination dimension outerDimsPerm[s],
+  // so mapping a destination position back to its source position requires
+  // the inverse permutation, not the permutation itself.
+  auto outerDimsPerm = op.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm(outerDimsPerm.size());
+  for (int64_t s = 0, e = outerDimsPerm.size(); s < e; ++s)
+    inverseOuterDimsPerm[outerDimsPerm[s]] = s;
+  int destRank = op.getDestRank();
+  for (auto destPos : llvm::seq<int64_t>(0, destRank)) {
+    // Tiled dimensions change size when unpacking; there is nothing to copy.
+    if (innerDims.contains(destPos))
+      continue;
+    int64_t srcPos =
+        inverseOuterDimsPerm.empty() ? destPos : inverseOuterDimsPerm[destPos];
+    // Both static or both dynamic: nothing to propagate.
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos]))
+      continue;
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
/// pack(unpack(x)) -> x
@@ -4234,6 +4268,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(unPackOp, srcShape, destShape)) {
+ Location loc = unPackOp.getLoc();
+ Value source = unPackOp.getSource();
+ if (srcShape != unPackOp.getSourceType().getShape()) {
+ auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+ source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+ unPackOp.getSource());
+ }
+ Value dest = unPackOp.getDest();
+ if (destShape != unPackOp.getDestType().getShape()) {
+ auto newDestType = unPackOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::UnPackOp>(
+ loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+ unPackOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ unPackOp, unPackOp.getResult().getType(), newOp);
+ return success();
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..35619d098f008a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
// -----
+// Source outer dim i holds destination dim outer_dims_perm[i]:
+// 10 -> d2 (tiled, stays dynamic), 20 -> d1, 30 -> d3, 40 -> d0.
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+ return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
+// CHECK: return %[[CAST_UNPACK]]
+
+// -----
+
+// Destination dim outer_dims_perm[i] feeds source outer dim i:
+// d2 (tiled) -> s0 stays dynamic, d1 = 20 -> s1, d3 = 10 -> s2, d0 = 30 -> s3.
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+ return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,18 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
} : tensor<?x8xi32>
return %tensor : tensor<?x8xi32>
}
+
+// -----
+
+// Unpacks %t and packs the result again with the same inner_dims_pos and
+// inner_tiles. The CHECK lines expect the whole unpack/pack chain to fold
+// away so that the function returns its first argument directly.
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
+ %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
+ %tile2: index) -> tensor<10x20x?x?xf32> {
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
+ %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
+ return %packed : tensor<10x20x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK: return %[[SRC]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81702
More information about the Mlir-commits
mailing list