[Mlir-commits] [mlir] [mlir][tensor] Add support for tensor.pack static shapes inference. (PR #80848)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 6 07:29:41 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/80848.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+60) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+39) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a571..737f897fd4fd4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
                                       op.getMixedTiles());
 }
 
+// Propagates static sizes between each untiled source dim of `packOp` and its
+// matching dest outer dim, storing the refined shapes in `srcShape` and
+// `destShape`. Returns true if either differs from the shape in `packOp`.
+static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(packOp.getSourceType().getShape().begin(),
+                  packOp.getSourceType().getShape().end());
+  destShape.assign(packOp.getDestType().getShape().begin(),
+                   packOp.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(packOp.getInnerDimsPos().begin(),
+                   packOp.getInnerDimsPos().end());
+  // outer_dims_perm[d] is the source dim placed at dest outer dim d, so going
+  // from a source dim to its dest position requires the inverse permutation.
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  SmallVector<int64_t> srcToDestPos(outerDimsPerm.size());
+  for (auto destPos : llvm::seq<int64_t>(0, outerDimsPerm.size()))
+    srcToDestPos[outerDimsPerm[destPos]] = destPos;
+  for (auto srcPos : llvm::seq<int64_t>(0, packOp.getSourceRank())) {
+    if (innerDims.contains(srcPos))
+      continue;
+    int64_t destPos = srcToDestPos.empty() ? srcPos : srcToDestPos[srcPos];
+    bool srcIsDynamic = ShapedType::isDynamic(srcShape[srcPos]);
+    if (srcIsDynamic == ShapedType::isDynamic(destShape[destPos]))
+      continue;
+    // Exactly one side is static; copy its size over the dynamic one.
+    int64_t size = srcIsDynamic ? destShape[destPos] : srcShape[srcPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   // Fold an unpack(pack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.finalizeOpModification(packOp);
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(packOp, srcShape, destShape)) {
+    Location loc = packOp.getLoc();
+    Value source = packOp.getSource();
+    if (srcShape != packOp.getSourceType().getShape()) {
+      auto newSrcType = packOp.getSourceType().clone(srcShape);
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+    }
+    Value dest = packOp.getDest();
+    if (destShape != packOp.getDestType().getShape()) {
+      auto newDestType = packOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::PackOp>(
+        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        packOp, packOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7192a719ceb13..a8e08241d28c0 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -791,6 +791,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312
 
 // -----
 
+func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
+  return %pack : tensor<10x20x30x40x16xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_pack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
+// CHECK:         return %[[PACK]]
+
+// -----
+
+func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
+  return %pack : tensor<?x?x?x?x16xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_pack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:         return %[[CAST_PACK]]
+
+// -----
+
 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/80848


More information about the Mlir-commits mailing list