[Mlir-commits] [mlir] [mlir][linalg] Handle unset outer_dims_perm in data layout propagation (PR #98040)

James Newling llvmlistbot at llvm.org
Mon Jul 8 09:12:26 PDT 2024


https://github.com/newling created https://github.com/llvm/llvm-project/pull/98040

Adds a check for whether the optional `outer_dims_perm` attribute is set, and adds logic to handle the case where it is not:

When bubbling up a pack or pushing down an unpack, an identity permutation on the outer dimensions remains an identity permutation (with more or fewer dimensions), so the new `outer_dims_perm` can likewise be left unset.

Fixes an assertion failure observed when running IREE's `-iree-preprocessing-convert-conv-to-channels-last` pass, which calls into MLIR's data layout propagation.

Also included in this PR: some whitespace fixes, removal of unused headers, and changes suggested by clang-tidy.

From e131982a67d61e8ae1cb5cfd3f92ee749cc0fe25 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 5 Jul 2024 09:26:24 -0700
Subject: [PATCH] improve code, format

---
 .../Transforms/DataLayoutPropagation.cpp      | 25 +++---
 .../Linalg/data-layout-propagation.mlir       | 83 +++++++++++++------
 2 files changed, 71 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..a097b87380763 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,17 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Linalg/Passes.h"
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -36,10 +32,9 @@ using namespace mlir::linalg;
 namespace {
 
 static bool hasGatherSemantics(linalg::GenericOp genericOp) {
-  for (Operation &op : genericOp.getBody()->getOperations())
-    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
-      return true;
-  return false;
+  return llvm::any_of(genericOp.getBody()->getOperations(), [](Operation &op) {
+    return isa<tensor::ExtractOp, linalg::IndexOp>(op);
+  });
 }
 
 // The struct contains the infomation about mapping packing information to
@@ -912,10 +907,16 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   // new permutation after pushing. This is because moving a source dim is
   // equivalent to moving the associated expanded dims together.
   SmallVector<int64_t> newOuterDimsPerm;
-  for (auto outerPos : outerDimsPerm) {
-    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
-                            reassocIndices[outerPos].begin(),
-                            reassocIndices[outerPos].end());
+  if (!outerDimsPerm.empty()) {
+    for (auto outerPos : outerDimsPerm) {
+      newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                              reassocIndices[outerPos].begin(),
+                              reassocIndices[outerPos].end());
+    }
+  } else {
+    // If 'outerDimsPerm' is empty, it denotes the identity permutation. If
+    // 'outerDimsPerm' is the identity permutation, then so is the replacement
+    // permutation: leaving 'newOuterDimsPerm' empty denotes this.
   }
 
   SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..71f87f6a39117 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 // CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
 // CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[ARG0_PACK]]
 // CHECK-SAME:      outs(%[[ARG1_PACK]]
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-LABEL: func.func @forward_tensor_empty
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 // CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
 // CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:      ins(%[[PACKED_ARG0]]
 // CHECK-SAME:      outs(%[[DEST]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
 
 // -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
 // CHECK:         %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
 // CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
 // CHECK:         %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
 
@@ -713,7 +713,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
 // CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
 // CHECK-NEXT:    %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32] 
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:      into %[[ALLOC]]
 
 // -----
@@ -760,19 +760,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
 
 // CHECK-LABEL: func.func @unpack_empty_inner_dims
 // CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
-// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 // CHECK:         %[[RES:.+]] = linalg.generic
 // CHECK-SAME:      ins(%[[PACKED_ARG0]]
 // CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, 
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
       %arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
   %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +810,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
 
 // -----
 
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, 
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
   %arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
 {
   %reduction = linalg.generic {
@@ -867,7 +867,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>, 
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
     %filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
   %init = tensor.empty() : tensor<16x540x960xi32>
   %empty = tensor.empty() : tensor<1x16x1080x1920xi32>
@@ -1370,3 +1370,36 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
 // CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @push_down_unpack_with_no_outer_dims(%1: tensor<1x512x3xf32>) -> tensor<1x3x512xf32> {
+  %0 = tensor.empty() : tensor<3x512xf32>
+  %unpack = tensor.unpack %1 inner_dims_pos = [0] inner_tiles = [3] into %0 : tensor<1x512x3xf32> -> tensor<3x512xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 3, 512] : tensor<3x512xf32> into tensor<1x3x512xf32>
+  func.return %expanded : tensor<1x3x512xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_with_no_outer_dims
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [1, 1, 512, 3] : tensor<1x512x3xf32> into tensor<1x1x512x3xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x3x512xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] inner_dims_pos = [1] inner_tiles = [3] into %[[EMPTY]] : tensor<1x1x512x3xf32> -> tensor<1x3x512xf32>
+// CHECK:         return %[[UNPACK]] : tensor<1x3x512xf32>
+
+// -----
+
+func.func @bubble_up_pack_with_no_outer_dims(%arg0: tensor<1x2x3xf32>) -> tensor<1x3x2xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x2x3xf32> into tensor<2x3xf32>
+  %1 = tensor.empty() : tensor<1x3x2xf32>
+  %2 = tensor.pack %0 inner_dims_pos = [0] inner_tiles = [2] into %1 : tensor<2x3xf32> -> tensor<1x3x2xf32>
+  return %2 : tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_with_no_outer_dims
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x3x2xf32>
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     inner_dims_pos = [1] inner_tiles = [2] into %[[EMPTY]]
+// CHECK-SAME:     : tensor<1x2x3xf32> -> tensor<1x1x3x2xf32>
+// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACKED]] {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME:     : tensor<1x1x3x2xf32> into tensor<1x3x2xf32>
+// CHECK:         return %[[COLLAPSED]] : tensor<1x3x2xf32>



More information about the Mlir-commits mailing list