[Mlir-commits] [mlir] eac8604 - [mlir][tensor] Add support for tensor.unpack static shapes inference. (#81702)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 19 16:26:16 PST 2024
Author: Han-Chung Wang
Date: 2024-02-19T16:26:12-08:00
New Revision: eac8604d989cb4220367937bae04937e67b9001b
URL: https://github.com/llvm/llvm-project/commit/eac8604d989cb4220367937bae04937e67b9001b
DIFF: https://github.com/llvm/llvm-project/commit/eac8604d989cb4220367937bae04937e67b9001b.diff
LOG: [mlir][tensor] Add support for tensor.unpack static shapes inference. (#81702)
This revision does not refactor inferStaticShape for the pack and unpack
ops, because they can diverge quickly: more dimensions can be inferred
(i.e., with inner_tile_sizes) if the pack op does not have a padding
value.
This is a follow-up of https://github.com/llvm/llvm-project/pull/80848
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 62bf6b48f18b4a..e6efec14e31a60 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4229,6 +4229,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
metadata.outerDimsPerm);
}
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+ SmallVectorImpl<int64_t> &destShape) {
+ bool changeNeeded = false;
+ srcShape.assign(op.getSourceType().getShape().begin(),
+ op.getSourceType().getShape().end());
+ destShape.assign(op.getDestType().getShape().begin(),
+ op.getDestType().getShape().end());
+ llvm::SmallSetVector<int64_t, 4> innerDims;
+ innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+ auto outerDimsPerm = op.getOuterDimsPerm();
+ int destRank = op.getDestRank();
+ for (auto i : llvm::seq<int64_t>(0, destRank)) {
+ if (innerDims.contains(i))
+ continue;
+ int64_t srcPos = i;
+ int64_t destPos = i;
+ if (!outerDimsPerm.empty())
+ srcPos = outerDimsPerm[destPos];
+ if (ShapedType::isDynamic(srcShape[srcPos]) ==
+ ShapedType::isDynamic(destShape[destPos])) {
+ continue;
+ }
+ int64_t size = srcShape[srcPos];
+ if (ShapedType::isDynamic(size))
+ size = destShape[destPos];
+ srcShape[srcPos] = size;
+ destShape[destPos] = size;
+ changeNeeded = true;
+ }
+ return changeNeeded;
+}
+
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
/// pack(unpack(x)) -> x
@@ -4251,6 +4285,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(unPackOp, srcShape, destShape)) {
+ Location loc = unPackOp.getLoc();
+ Value source = unPackOp.getSource();
+ if (srcShape != unPackOp.getSourceType().getShape()) {
+ auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+ source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+ unPackOp.getSource());
+ }
+ Value dest = unPackOp.getDest();
+ if (destShape != unPackOp.getDestType().getShape()) {
+ auto newDestType = unPackOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::UnPackOp>(
+ loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+ unPackOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ unPackOp, unPackOp.getResult().getType(), newOp);
+ return success();
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..e123c77aabd57c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
// -----
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+ return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK: return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+ return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,19 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
} : tensor<?x8xi32>
return %tensor : tensor<?x8xi32>
}
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
+ %dim1 = arith.constant 40 : index
+ %dim2 = arith.constant 80 : index
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
+ %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
+ %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
+ %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
+ return %packed : tensor<10x20x4x4xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK: return %[[SRC]]
More information about the Mlir-commits
mailing list