[Mlir-commits] [mlir] [MLIR][Linalg] Fix DataLayoutPropagation for tensor.unpack + linalg.generic (PR #101755)

Abhishek Varma llvmlistbot at llvm.org
Fri Aug 2 14:35:14 PDT 2024


https://github.com/Abhishek-Varma created https://github.com/llvm/llvm-project/pull/101755

-- While pushing tensor.unpack down through linalg.generic, we should take destination-passing style (DPS) into account: the current implementation unconditionally created a tensor.empty() as the destination of the final tensor.unpack, whereas it should simply reuse the outs operand of the original linalg.generic.
-- This commit fixes that.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
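
To illustrate, a minimal sketch using the shapes from the unpack_on_input test updated below (%res stands for the result of the packed linalg.generic and %init for its original outs operand; both names are illustrative, not from the patch):

  // Before this patch: a fresh destination was always materialized.
  %empty = tensor.empty() : tensor<12x56x56x64xf32>
  %unpack = tensor.unpack %res
      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %empty : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>

  // After this patch: the unpack writes into the original outs
  // operand, preserving destination-passing style.
  %unpack = tensor.unpack %res
      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
      into %init : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>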

>From 4c9b9112c35aa50a1fc3754c2366ca01c2078a96 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 2 Aug 2024 21:21:44 +0000
Subject: [PATCH] [MLIR][Linalg] Fix DataLayoutPropagation for tensor.unpack +
 linalg.generic

-- While pushing tensor.unpack down through linalg.generic, we
   should take destination-passing style (DPS) into account: the
   current implementation unconditionally created a tensor.empty()
   as the destination of the final tensor.unpack, whereas it should
   simply reuse the outs operand of the original linalg.generic.
-- This commit fixes that.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Linalg/Transforms/DataLayoutPropagation.cpp  | 16 ++--------------
 .../Dialect/Linalg/data-layout-propagation.mlir  | 14 +++++++-------
 2 files changed, 9 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6ea6cda74c446..0741e147cdd69 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -1106,23 +1106,11 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   auto innerDimsPos = destPack.getInnerDimsPos();
   auto outerDimsPerm = destPack.getOuterDimsPerm();
 
-  // If the output type for the generic differs from the source
-  // unpack op, we need to create a new destination tensor. In the
-  // dynamic case we always need a new destination.
-  auto loc = genericOp.getLoc();
-  Value unPackDest = producerUnPackOp.getDest();
-  auto genericOutType =
-      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
-  if (producerUnPackOp.getDestType() != genericOutType ||
-      !genericOutType.hasStaticShape()) {
-    unPackDest = tensor::UnPackOp::createDestinationTensor(
-        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
-  }
-
   // Insert an unPackOp right after the packed generic.
   Value unPackOpRes =
       rewriter
-          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
+          .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
+                                    destPack.getSource(), innerDimsPos,
                                     mixedTiles, outerDimsPerm)
           .getResult();
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index d9206432379fb..07708231a6e2f 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -436,7 +436,7 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
 // CHECK-SAME:      outs(%[[PACKED_ARG0]]
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
+// CHECK-SAME:      into %[[UNPACKED_ARG0]]
 
 // -----
 
@@ -475,7 +475,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME:      into %[[ARG1]]
 
 // -----
 
@@ -512,10 +512,9 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0_PACK]]
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
-// CHECK:         %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_NEW_EMPTY_UNPACK]]
+// CHECK-SAME:      into %[[ARG1]]
 
 // -----
 
@@ -536,6 +535,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
 // CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
@@ -551,7 +551,7 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-SAME:      outs(%[[DEST]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME:      into %[[FINAL_RES]]
 
 // -----
 
@@ -913,6 +913,7 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-LABEL: func.func @unpack_different_destination_shape
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[FINAL_RES:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
 // CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
 // CHECK:         %[[PACK_ARG0:.+]] = tensor.pack
@@ -923,10 +924,9 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
 // CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
 // CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
 // CHECK-SAME:      outs(%[[INIT]]
-// CHECK:         %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
 // CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
-// CHECK-SAME:      into %[[UNPACK_NEW_DEST]]
+// CHECK-SAME:      into %[[FINAL_RES]]
 // CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>
 
 // -----


