[Mlir-commits] [mlir] eac8604 - [mlir][tensor] Add support for tensor.unpack static shapes inference. (#81702)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 19 16:26:16 PST 2024

Author: Han-Chung Wang
Date: 2024-02-19T16:26:12-08:00
New Revision: eac8604d989cb4220367937bae04937e67b9001b

URL: https://github.com/llvm/llvm-project/commit/eac8604d989cb4220367937bae04937e67b9001b
DIFF: https://github.com/llvm/llvm-project/commit/eac8604d989cb4220367937bae04937e67b9001b.diff

LOG: [mlir][tensor] Add support for tensor.unpack static shapes inference. (#81702)

The revision does not refactor inferStaticShape into a helper shared by the
pack and unpack ops because the two can diverge quickly: for a pack op without
a padding value, more dimensions can be inferred (e.g., from inner_tile_sizes).

This is a follow-up of https://github.com/llvm/llvm-project/pull/80848




diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 62bf6b48f18b4a..e6efec14e31a60 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4229,6 +4229,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape.
+///
+/// Only untiled (outer) dimensions are propagated. Packed outer dimension `i`
+/// of the source corresponds to destination dimension `outer_dims_perm[i]`, so
+/// a destination dimension `d` is found in the source at position
+/// `inverse(outer_dims_perm)[d]` (or at `d` when there is no permutation).
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+  // outer_dims_perm maps packed position -> dest dimension. Invert it so each
+  // dest dimension can be looked up in the packed source; indexing with the
+  // permutation directly is only correct for involutive permutations.
+  auto outerDimsPerm = op.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm(outerDimsPerm.size());
+  for (auto [packedPos, destDim] : llvm::enumerate(outerDimsPerm))
+    inverseOuterDimsPerm[destDim] = packedPos;
+  int destRank = op.getDestRank();
+  for (auto i : llvm::seq<int64_t>(0, destRank)) {
+    // Tiled dimensions relate to the source through the tile sizes; skip them.
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!inverseOuterDimsPerm.empty())
+      srcPos = inverseOuterDimsPerm[destPos];
+    // Nothing to infer when both sides are static or both are dynamic.
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
   /// pack(unpack(x)) -> x
@@ -4251,6 +4285,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(unPackOp, srcShape, destShape)) {
+    Location loc = unPackOp.getLoc();
+    Value source = unPackOp.getSource();
+    if (srcShape != unPackOp.getSourceType().getShape()) {
+      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                               unPackOp.getSource());
+    }
+    Value dest = unPackOp.getDest();
+    if (destShape != unPackOp.getDestType().getShape()) {
+      auto newDestType = unPackOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::UnPackOp>(
+        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+        unPackOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        unPackOp, unPackOp.getResult().getType(), newOp);
+    return success();
+  }
   return failure();

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..e123c77aabd57c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
 // -----
+// Packed outer dim i maps to dest dim outer_dims_perm[i]: the static sizes
+// 20, 30 and 40 flow into dest dims 1, 3 and 0; dest dim 2 is tiled and stays
+// dynamic.
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+  return %unpack : tensor<?x?x?x?xf32>
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
+// CHECK:         return %[[CAST_UNPACK]]
+// -----
+// Reverse direction: static dest dims 0, 1 and 3 (30, 20, 10) determine packed
+// outer dims 3, 1 and 2; packed dim 0 feeds the tiled dest dim and stays
+// dynamic.
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+  return %unpack : tensor<30x20x?x10xf32>
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK:         return %[[UNPACK]]
+// -----
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,19 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
   } : tensor<?x8xi32>
   return %tensor : tensor<?x8xi32>
+// -----
+// The unpack writes into a tensor.empty built from constant sizes (40, 80),
+// and the subsequent pack uses the same inner_dims_pos/inner_tiles. After
+// canonicalization the function is expected to return %t directly.
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
+  %dim1 = arith.constant 40 : index
+  %dim2 = arith.constant 80 : index
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
+  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
+  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
+  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
+  return %packed : tensor<10x20x4x4xf32>
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]


More information about the Mlir-commits mailing list