[Mlir-commits] [mlir] [mlir][linalg] Handle unset outer_dims_perm in data layout propagation (PR #98040)
James Newling
llvmlistbot at llvm.org
Mon Jul 8 09:12:26 PDT 2024
https://github.com/newling created https://github.com/llvm/llvm-project/pull/98040
Adds a check that the optional `outer_dims_perm` attribute is set, and adds logic to handle the case where it is not:
When bubbling up a pack or pushing down an unpack, an identity permutation on the outer dimensions remains an identity permutation (with more or fewer dimensions).
Fixes an assertion failure observed when running IREE's `-iree-preprocessing-convert-conv-to-channels-last`, which calls into MLIR's data-layout-propagation.
Also included in this PR: some whitespace fixes, removal of unused headers, and changes suggested by clang-tidy.
From e131982a67d61e8ae1cb5cfd3f92ee749cc0fe25 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 5 Jul 2024 09:26:24 -0700
Subject: [PATCH] improve code, format
---
.../Transforms/DataLayoutPropagation.cpp | 25 +++---
.../Linalg/data-layout-propagation.mlir | 83 +++++++++++++------
2 files changed, 71 insertions(+), 37 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6984bc2dff498..a097b87380763 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,17 +6,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -36,10 +32,9 @@ using namespace mlir::linalg;
namespace {
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
- for (Operation &op : genericOp.getBody()->getOperations())
- if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
- return true;
- return false;
+ return llvm::any_of(genericOp.getBody()->getOperations(), [](Operation &op) {
+ return isa<tensor::ExtractOp, linalg::IndexOp>(op);
+ });
}
// The struct contains the infomation about mapping packing information to
@@ -912,10 +907,16 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
// new permutation after pushing. This is because moving a source dim is
// equivalent to moving the associated expanded dims together.
SmallVector<int64_t> newOuterDimsPerm;
- for (auto outerPos : outerDimsPerm) {
- newOuterDimsPerm.insert(newOuterDimsPerm.end(),
- reassocIndices[outerPos].begin(),
- reassocIndices[outerPos].end());
+ if (!outerDimsPerm.empty()) {
+ for (auto outerPos : outerDimsPerm) {
+ newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+ reassocIndices[outerPos].begin(),
+ reassocIndices[outerPos].end());
+ }
+ } else {
+ // If 'outerDimsPerm' is empty, it denotes the identity permutation. If
+ // 'outerDimsPerm' is the identity permutation, then so is the replacement
+ // permutation: leaving 'newOuterDimsPerm' empty denotes this.
}
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 626dd8b697e59..71f87f6a39117 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -458,23 +458,23 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0_PACK]]
// CHECK-SAME: outs(%[[ARG1_PACK]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// -----
@@ -537,20 +537,20 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-LABEL: func.func @forward_tensor_empty
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
// -----
@@ -571,8 +571,8 @@ func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tens
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
// -----
@@ -614,8 +614,8 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
@@ -713,7 +713,7 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME: into %[[ALLOC]]
// -----
@@ -760,19 +760,19 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
// CHECK-LABEL: func.func @unpack_empty_inner_dims
// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
%arg1: tensor<128x256xi32>) -> tensor<4x16x16x32xi32>{
%elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0 : tensor<128x256x32xi32>)
@@ -810,7 +810,7 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
// -----
-func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>,
%arg2: tensor<128xi32>, %init_reduction: tensor<100x128x256xi32>) -> tensor<4x16x100x16x32xi32>
{
%reduction = linalg.generic {
@@ -867,7 +867,7 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
-func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>,
%filter: tensor<2x2xi32>) -> tensor<16x540x960xi32>{
%init = tensor.empty() : tensor<16x540x960xi32>
%empty = tensor.empty() : tensor<1x16x1080x1920xi32>
@@ -1370,3 +1370,36 @@ func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
+
+// -----
+
+func.func @push_down_unpack_with_no_outer_dims(%1: tensor<1x512x3xf32>) -> tensor<1x3x512xf32> {
+ %0 = tensor.empty() : tensor<3x512xf32>
+ %unpack = tensor.unpack %1 inner_dims_pos = [0] inner_tiles = [3] into %0 : tensor<1x512x3xf32> -> tensor<3x512xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 3, 512] : tensor<3x512xf32> into tensor<1x3x512xf32>
+ func.return %expanded : tensor<1x3x512xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_with_no_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [1, 1, 512, 3] : tensor<1x512x3xf32> into tensor<1x1x512x3xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x3x512xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] inner_dims_pos = [1] inner_tiles = [3] into %[[EMPTY]] : tensor<1x1x512x3xf32> -> tensor<1x3x512xf32>
+// CHECK: return %[[UNPACK]] : tensor<1x3x512xf32>
+
+// -----
+
+func.func @bubble_up_pack_with_no_outer_dims(%arg0: tensor<1x2x3xf32>) -> tensor<1x3x2xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x2x3xf32> into tensor<2x3xf32>
+ %1 = tensor.empty() : tensor<1x3x2xf32>
+ %2 = tensor.pack %0 inner_dims_pos = [0] inner_tiles = [2] into %1 : tensor<2x3xf32> -> tensor<1x3x2xf32>
+ return %2 : tensor<1x3x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_with_no_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x3x2xf32>
+// CHECK: %[[PACKED:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<1x2x3xf32> -> tensor<1x1x3x2xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACKED]] {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME: : tensor<1x1x3x2xf32> into tensor<1x3x2xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<1x3x2xf32>
More information about the Mlir-commits
mailing list