[Mlir-commits] [mlir] [mlir][linalg] unfold projected permutation. (PR #114704)

Andrzej WarzyƄski llvmlistbot at llvm.org
Tue Nov 5 08:49:01 PST 2024

@@ -0,0 +1,270 @@
+//===- UnfoldProjectedPermutation.cpp - extract projected permutations ---===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This file implements a pattern that decomposes a GenericOp operand whose
+// affine map folds in a `transpose+broadcast` combination into separate
+// transpose and broadcast ops.
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include <map>
+#include <optional>
+#include <utility>
+#include <vector>
+using namespace mlir;
+using namespace mlir::linalg;
+namespace {
+/// A projected permutation effectively folds a mixture of transpose and
+/// broadcast into the affine map of an operand.
+/// While folding transpose and broadcast into the affine map of a
+/// linalg.generic operand is a very effective optimization, sometimes
+/// we may want to unfold it, for instance when recognizing named ops.
+/// Example:
+/// ```mlir
+/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+/// ...
+///    %res = linalg.generic
+///       { indexing_maps = [#projection, #identity, #identity],
+///       iterator_types = ["parallel", "parallel", "parallel",
+///                         "parallel", "parallel"]}
+///       ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
+///       outs(%z : tensor<5x9x7x8x10xf32>) {
+///         ^bb0(%in: f32, %in_1: f32, %out: f32):
+///              %div = arith.divf %in, %in_1 : f32
+///              linalg.yield %div : f32
+///    } -> tensor<5x9x7x8x10xf32>
+/// ```
+/// In the IR above, the map of operand `%x` is a projected permutation. This
+/// can be unfolded as:
+/// ```mlir
+///   ...
+///   %transposed = linalg.transpose ins(%x : tensor<7x8x9xf32>)
+///                    outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+///   ...
+///   %broadcasted = linalg.broadcast ins(%transposed : tensor<9x7x8xf32>)
+///                    outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+///   %2 = linalg.div
+///           ins(%broadcasted, %y :
+///                  tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
+///           outs(%z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+/// ```
+/// Note that linalg.generic has been 'specialized' to linalg.div.
+/// To unfold, it is more effective to transpose first and then broadcast.
+/// However, if the transpose is done first, the permutation map needs to be
+/// expressed in terms of the reduced dimensions (as the broadcast hasn't
+/// happened yet). Also, the sizes of the broadcast dimensions in a
+/// linalg.generic come from other operands (those not broadcast along that
+/// particular dimension). We work these out by computing the polytope shape of
+/// the linalg.generic from the shapes of all the operands (inputs and outputs).
+struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+};
+/// Calculate the shape (dimensions) of the iteration-space polytope.
+/// This is done by concatenating the indexing maps of all operands of the
+/// generic, inverting the concatenation, concatenating the shapes of all
+/// operands, and then applying the inverted map to the concatenated shapes.
+SmallVector<int64_t> getPolytopeDims(GenericOp op) {
+  assert(op.hasPureTensorSemantics() && "works only on tensors");
+  // Concat indexing maps of all operands and invert the mapping.
+  auto maps = op.getIndexingMapsArray();
+  auto concat = concatAffineMaps(maps);
+  auto inverse = inversePermutation(concat);
banach-space wrote:

Spell out `auto`.
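For readers less familiar with `concatAffineMaps` and `inversePermutation`, the polytope computation reviewed above can be modelled without MLIR. This is a minimal sketch, not the PR's code: it assumes every indexing map is a projected permutation represented as a plain list of loop-dimension indices, and `ProjectedMap` and `getPolytopeDimsSketch` are hypothetical names. Like `inversePermutation`, the inversion keeps the first occurrence of each loop dimension and fails when some loop dimension is never used.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of a projected-permutation indexing map: the loop dims it
// selects, e.g. (d0, d1, d2, d3, d4) -> (d2, d3, d1) is {2, 3, 1}.
using ProjectedMap = std::vector<unsigned>;

// Sketch: concatenate all maps and all operand shapes, invert the concatenated
// map (first occurrence of each loop dim wins), then read each loop dim's size
// off the concatenated shapes. Returns std::nullopt if the map is not invertible.
std::optional<std::vector<int64_t>>
getPolytopeDimsSketch(unsigned numLoops, const std::vector<ProjectedMap> &maps,
                      const std::vector<std::vector<int64_t>> &shapes) {
  assert(maps.size() == shapes.size() && "one map per operand");
  std::vector<unsigned> concatMap;
  std::vector<int64_t> concatShape;
  for (size_t i = 0; i < maps.size(); ++i) {
    assert(maps[i].size() == shapes[i].size() && "map rank must match operand rank");
    concatMap.insert(concatMap.end(), maps[i].begin(), maps[i].end());
    concatShape.insert(concatShape.end(), shapes[i].begin(), shapes[i].end());
  }
  // inverse[d] = position in the concatenation where loop dim d first appears.
  std::vector<std::optional<size_t>> inverse(numLoops);
  for (size_t pos = 0; pos < concatMap.size(); ++pos) {
    unsigned d = concatMap[pos];
    if (d < numLoops && !inverse[d])
      inverse[d] = pos;
  }
  // "Apply map": look up each loop dim's size through the inverted map.
  std::vector<int64_t> dims;
  dims.reserve(numLoops);
  for (unsigned d = 0; d < numLoops; ++d) {
    if (!inverse[d])
      return std::nullopt; // some loop dim is not used by any operand
    dims.push_back(concatShape[*inverse[d]]);
  }
  return dims;
}
```

With the maps and shapes of `%x`, `%y` and `%z` from the doc-comment example, this yields `{5, 9, 7, 8, 10}`, the iteration space of the `linalg.generic`.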


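The transpose-then-broadcast decomposition described in the doc comment can likewise be sketched in plain C++. Assuming the map is a projected permutation with no constant results (given as the list of loop dims it selects), one way to get the transpose permutation in the reduced dimensions is to order the operand's dims by the loop dim each maps to; the broadcast dimensions are then the loop dims the map never mentions. `UnfoldPlan` and `planUnfold` are hypothetical names, and this sort-based derivation is an illustration rather than necessarily how the PR computes it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical result type: attributes for linalg.transpose and linalg.broadcast.
struct UnfoldPlan {
  std::vector<int64_t> permutation;   // transpose permutation (reduced dims)
  std::vector<int64_t> broadcastDims; // loop dims missing from the map
};

// resultDims lists the loop dim selected by each operand dim,
// e.g. (d0, d1, d2, d3, d4) -> (d2, d3, d1) is {2, 3, 1}.
UnfoldPlan planUnfold(unsigned numLoops, const std::vector<unsigned> &resultDims) {
  std::vector<bool> used(numLoops, false);
  for (unsigned d : resultDims) {
    assert(d < numLoops && !used[d] && "expects a projected permutation");
    used[d] = true;
  }
  // permutation[i] = operand dim that ends up at position i once the operand
  // dims are ordered by the loop dim they map to.
  std::vector<int64_t> perm(resultDims.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
    return resultDims[a] < resultDims[b];
  });
  UnfoldPlan plan;
  plan.permutation = perm;
  // Every loop dim not referenced by the map must be broadcast.
  for (unsigned d = 0; d < numLoops; ++d)
    if (!used[d])
      plan.broadcastDims.push_back(static_cast<int64_t>(d));
  return plan;
}
```

For `#projection` this gives `permutation = [2, 0, 1]` and `dimensions = [0, 4]`, matching the `linalg.transpose` and `linalg.broadcast` in the example. The sizes of the broadcast dimensions (5 and 10 here) cannot be read off `%x` alone, which is why the pattern needs the polytope shape.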