[Mlir-commits] [mlir] 14e7846 - [mlir][Tensor] Fold destination-style ops into `tensor.unpack` operation. (#71468)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 21:42:36 PST 2023


Author: MaheshRavishankar
Date: 2023-11-07T21:42:32-08:00
New Revision: 14e7846d6e2c5c99c52ba3882e59bfb021a5f0fa

URL: https://github.com/llvm/llvm-project/commit/14e7846d6e2c5c99c52ba3882e59bfb021a5f0fa
DIFF: https://github.com/llvm/llvm-project/commit/14e7846d6e2c5c99c52ba3882e59bfb021a5f0fa.diff

LOG: [mlir][Tensor] Fold destination-style ops into `tensor.unpack` operation. (#71468)

The destination operand of the `tensor.unpack` operation is only needed
to carry shape information. So if the producer of the destination
operand implements the `DestinationStyleOpInterface`, fold that producer
away by replacing the destination operand of the `tensor.unpack` with the
producer's corresponding init (destination) operand.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f719cfed6b6dd30..6fc45379111fc34 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3922,18 +3922,29 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                             metadata.outerDimsPerm);
 }
 
-/// pack(unpack(x)) -> x
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
-  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(unPackOp, packOp.getSource());
-  return success();
+  /// unpack(pack(x)) -> x
+  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
+    if (packOp.getDestType() != unPackOp.getSourceType())
+      return failure();
+    if (packOp.getPaddingValue() ||
+        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+        !haveSameTiles(packOp, unPackOp))
+      return failure();
+    rewriter.replaceOp(unPackOp, packOp.getSource());
+    return success();
+  }
+  /// unpack(destinationStyleOp(x)) -> unpack(x)
+  if (auto dstStyleOp =
+          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
+    auto destValue = unPackOp.getDest().cast<OpResult>();
+    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
+    rewriter.updateRootInPlace(
+        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
+    return success();
+  }
+  return failure();
 }
 
 bool UnPackOp::isLikeUnPad() {

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index c40c9efeb152ac6..1078ee3b59a4306 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
   %1 = tensor.empty(%0) : tensor<4x5x?xf32>
   return %1 : tensor<4x5x?xf32>
 }
+
+// -----
+
+// Fold DstStyleOp -> tensor.unpack operations.
+func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
+  return %unpack : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x16x64xf32>
+//  CHECK-SAME:     %[[INIT:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+//  CHECK-SAME:       into %[[INIT]]
+//       CHECK:   return %[[UNPACK]]


        


More information about the Mlir-commits mailing list