[Mlir-commits] [mlir] bc08cc2 - [mlir][tensor] Add support for tensor.pack static shapes inference. (#80848)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 13 20:20:27 PST 2024
Author: Han-Chung Wang
Date: 2024-02-13T20:20:24-08:00
New Revision: bc08cc2ac8b0fc0898d191e36db08d136d659f7d
URL: https://github.com/llvm/llvm-project/commit/bc08cc2ac8b0fc0898d191e36db08d136d659f7d
DIFF: https://github.com/llvm/llvm-project/commit/bc08cc2ac8b0fc0898d191e36db08d136d659f7d.diff
LOG: [mlir][tensor] Add support for tensor.pack static shapes inference. (#80848)
Fixes https://github.com/openxla/iree/issues/16317
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8298cf102e28a3..bb72cba96ad935 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,6 +3983,43 @@ static bool paddingIsNotNeeded(PackOp op) {
op.getMixedTiles());
}
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `packOp` and populates each with the inferred static shape.
+static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
+ SmallVectorImpl<int64_t> &destShape) {
+ bool changeNeeded = false;
+ srcShape.assign(packOp.getSourceType().getShape().begin(),
+ packOp.getSourceType().getShape().end());
+ destShape.assign(packOp.getDestType().getShape().begin(),
+ packOp.getDestType().getShape().end());
+ llvm::SmallSetVector<int64_t, 4> innerDims;
+ innerDims.insert(packOp.getInnerDimsPos().begin(),
+ packOp.getInnerDimsPos().end());
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+ int srcRank = packOp.getSourceRank();
+ for (auto i : llvm::seq<int64_t>(0, srcRank)) {
+ if (innerDims.contains(i))
+ continue;
+ int64_t srcPos = i;
+ int64_t destPos = i;
+ // Dest outer dim `j` holds source dim `outerDimsPerm[j]`, so the dest
+ // position of `srcPos` is its index in the permutation (inverse mapping).
+ if (!outerDimsPerm.empty())
+ destPos = llvm::find(outerDimsPerm, srcPos) - outerDimsPerm.begin();
+ if (ShapedType::isDynamic(srcShape[srcPos]) ==
+ ShapedType::isDynamic(destShape[destPos])) {
+ continue;
+ }
+ int64_t size = srcShape[srcPos];
+ if (ShapedType::isDynamic(size))
+ size = destShape[destPos];
+ srcShape[srcPos] = size;
+ destShape[destPos] = size;
+ changeNeeded = true;
+ }
+ return changeNeeded;
+}
+
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Fold an unpack(pack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4040,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.finalizeOpModification(packOp);
return success();
}
+
+ // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(packOp, srcShape, destShape)) {
+ Location loc = packOp.getLoc();
+ Value source = packOp.getSource();
+ if (srcShape != packOp.getSourceType().getShape()) {
+ auto newSrcType = packOp.getSourceType().clone(srcShape);
+ source =
+ rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+ }
+ Value dest = packOp.getDest();
+ if (destShape != packOp.getDestType().getShape()) {
+ auto newDestType = packOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::PackOp>(
+ loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+ packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ packOp, packOp.getResult().getType(), newOp);
+ return success();
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90c715bf2eb2da..3b6cd799a6f348 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -809,6 +809,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312
// -----
+func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %src
+ padding_value(%cst : f32)
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
+ return %pack : tensor<10x20x30x40x16xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_pack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
+// CHECK: return %[[PACK]]
+
+// -----
+
+func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %src
+ padding_value(%cst : f32)
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
+ return %pack : tensor<?x?x?x?x16xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_pack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK: return %[[CAST_PACK]]
+
+// -----
+
func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>
More information about the Mlir-commits
mailing list