[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.unpack static shapes inference. (PR #81702)

Han-Chung Wang llvmlistbot at llvm.org
Mon Feb 19 14:22:09 PST 2024


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/81702

>From 9f47c24b4b5b7512376114d419edbf9ad0ee8b21 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 14 Feb 2024 12:20:36 +0800
Subject: [PATCH 1/2] [mlir][tensor] Add support for tensor.unpack static
 shapes inference.

The revision does not refactor the inferStaticShape for pack and unpack
ops because they can diverge quickly: more dimensions can be inferred
(e.g., with inner_tile_sizes) if the pack op does not have a padding
value.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 59 ++++++++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 50 ++++++++++++++++++
 2 files changed, 109 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb72cba96ad935..1df15c0372e6e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4212,6 +4212,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                             metadata.outerDimsPerm);
 }
 
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape. Only non-tiled dims
+/// are inferred; packed (src) outer dim j holds dest dim outer_dims_perm[j].
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+  SmallVector<int64_t> inverseOuterDimsPerm =
+      invertPermutationVector(op.getOuterDimsPerm());
+  int destRank = op.getDestRank();
+  for (auto i : llvm::seq<int64_t>(0, destRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t destPos = i;
+    int64_t srcPos =
+        inverseOuterDimsPerm.empty() ? destPos : inverseOuterDimsPerm[destPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos]))
+      continue;
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
   /// pack(unpack(x)) -> x
@@ -4234,6 +4268,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(unPackOp, srcShape, destShape)) {
+    Location loc = unPackOp.getLoc();
+    Value source = unPackOp.getSource();
+    if (srcShape != unPackOp.getSourceType().getShape()) {
+      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                               unPackOp.getSource());
+    }
+    Value dest = unPackOp.getDest();
+    if (destShape != unPackOp.getDestType().getShape()) {
+      auto newDestType = unPackOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::UnPackOp>(
+        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+        unPackOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        unPackOp, unPackOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..35619d098f008a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128

 // -----

+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+  return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
+// CHECK:         return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+  return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,18 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
   } : tensor<?x8xi32>
   return %tensor : tensor<?x8xi32>
 }
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
+    %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
+    %tile2: index) -> tensor<10x20x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
+  return %packed : tensor<10x20x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]

>From d852e6fdc1a33b40853850f39fe7cc0d397b3bfe Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 19 Feb 2024 14:21:36 -0800
Subject: [PATCH 2/2] update tests

---
 mlir/test/Dialect/Tensor/canonicalize.mlir | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 35619d098f008a..e123c77aabd57c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2214,14 +2214,15 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
 
 // -----
 
-func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
-    %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
-    %tile2: index) -> tensor<10x20x?x?xf32> {
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
+  %dim1 = arith.constant 40 : index
+  %dim2 = arith.constant 80 : index
   %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
-  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
-  %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
-  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
-  return %packed : tensor<10x20x?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
+  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
+  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
+  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
+  return %packed : tensor<10x20x4x4xf32>
 }
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]



More information about the Mlir-commits mailing list