[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.unpack static shapes inference. (PR #81702)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 13 20:26:32 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

The revision does not refactor the inferStaticShape for pack and unpack ops because they can diverge quickly: more dimensions can be inferred (i.e., with inner_tile_sizes) if the pack op does not have a padding value.

This is a follow-up of https://github.com/llvm/llvm-project/pull/80848

---
Full diff: https://github.com/llvm/llvm-project/pull/81702.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+59) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+50) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb72cba96ad935..1df15c0372e6e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4212,6 +4212,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                             metadata.outerDimsPerm);
 }
 
+/// Seeds `srcShape`/`destShape` from the types of `op`, then, for every dest dim not listed in
+/// inner_dims_pos, copies a static size onto its dynamic counterpart. Returns true if any size changed.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end()); // dims skipped by the loop below
+  auto outerDimsPerm = op.getOuterDimsPerm();
+  int destRank = op.getDestRank();
+  for (auto i : llvm::seq<int64_t>(0, destRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      srcPos = outerDimsPerm[destPos]; // NOTE(review): indexes the perm by dest position; pack/unpack semantics may need the inverse permutation here - confirm
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue; // both static or both dynamic: nothing to propagate
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos]; // exactly one side is static here; take that one
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
   /// pack(unpack(x)) -> x
@@ -4234,6 +4268,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(unPackOp, srcShape, destShape)) {
+    Location loc = unPackOp.getLoc();
+    Value source = unPackOp.getSource();
+    if (srcShape != unPackOp.getSourceType().getShape()) {
+      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                               unPackOp.getSource());
+    }
+    Value dest = unPackOp.getDest();
+    if (destShape != unPackOp.getDestType().getShape()) {
+      auto newDestType = unPackOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::UnPackOp>(
+        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+        unPackOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        unPackOp, unPackOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..35619d098f008a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
 
 // -----
 
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %unpack = tensor.unpack %src // fully static source unpacked into a fully dynamic dest
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2] // dest dim 2 is the tiled dim; the expectations below keep it dynamic
+    inner_tiles = [16]
+    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+  return %unpack : tensor<?x?x?x?xf32> // expectations below: dest is cast to a more static type, result is cast back
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK:         return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+  %unpack = tensor.unpack %src // source outer dims are all dynamic; dest is static except dim 2
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2] // dest dim 2 is the tiled dim
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+  return %unpack : tensor<30x20x?x10xf32> // expectations below: only the source gets a cast; the result is returned directly
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,18 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
   } : tensor<?x8xi32>
   return %tensor : tensor<?x8xi32>
 }
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
+    %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
+    %tile2: index) -> tensor<10x20x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32> // unpack with dynamic tile sizes
+  %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32> // re-pack with the same inner_dims_pos and tiles
+  return %packed : tensor<10x20x?x?xf32> // expectations below: the unpack/pack pair folds away and %t is returned
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81702


More information about the Mlir-commits mailing list