[Mlir-commits] [mlir] 0ac4821 - [mlir][linalg] unfold projected permutation. (#114704)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 9 15:35:01 PST 2024


Author: Javed Absar
Date: 2024-11-09T23:34:57Z
New Revision: 0ac4821b718dd14e80d3856efa532d52df6878bb

URL: https://github.com/llvm/llvm-project/commit/0ac4821b718dd14e80d3856efa532d52df6878bb
DIFF: https://github.com/llvm/llvm-project/commit/0ac4821b718dd14e80d3856efa532d52df6878bb.diff

LOG: [mlir][linalg] unfold projected permutation. (#114704)

Patterns to decompose the input operand(s) of a linalg.generic that has
a projected permutation`  affine-map -- i.e. effectively a folded `transpose`, 
`broadcast`, or a mixture of two --  into explicit transpose and broadcast.
This is useful for instance when trying to recognize named ops.
email: quic_mabsar at quicinc.com

Added: 
    mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
    mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a8662a3d6f63be..adfaea5a9d074d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1820,6 +1820,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 /// linalg.fill(%cst, tensor.extract_slice(%init)).
 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to make explicit broadcasts and transforms in the
+/// input operands of a genericOp.
+void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);
+
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d7c63cdd8198d7..3594b084138124 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TilingInterfaceImpl.cpp
   Transforms.cpp
   TransposeConv2D.cpp
+  DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
new file mode 100644
index 00000000000000..83c4b5bdf10976
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -0,0 +1,249 @@
+//===- DecomposeGenericByUnfoldingPermutation.cpp                   -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include <map>
+#include <optional>
+#include <utility>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// This pattern decomposes the input operand(s) of a linalg.generic that has
+/// a `transpose`, `broadcast`, or a mixture of two, into explicit transpose
+/// and broadcast. Having them folded into the linalg.generic is a good
+/// optimization but sometimes we may want to unwrap, i.e., `unfold` them as
+/// explicit transpose and broadcast. This rewrite pattern helps do it for
+/// each input operand. This is useful for instance when trying to recognize
+/// named ops.
+///
+/// The transpose, broadcast, or mixture of both, are expressed in the affine
+/// map of the operand. Technically it is essentially `projected permutation`.
+///
+///  Example
+///
+/// ```mlir
+///
+/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+/// ...
+///    %res = linalg.generic
+///       { indexing_maps = [#projection, #identity, #identity],
+///       iterator_types = ["parallel", "parallel", "parallel",
+///                         "parallel", "parallel"]}
+///       ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
+///       outs(%z : tensor<5x9x7x8x10xf32>) {
+///         ^bb0(%in: f32, %in_1: f32, %out: f32):
+///              %div = arith.divf %in, %in_1 : f32
+///              linalg.yield %div : f32
+///    } -> tensor<5x9x7x8x10xf32>
+/// ```
+///
+/// In the above IR operand `%x` map is a projected-permutation. This can be
+/// unfolded as:
+///
+/// ```mlir
+///   ...
+///   %x_trans = linalg.transpose
+///                   ins(%x : tensor<7x8x9xf32>)
+///                   outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+///   ...
+///   %x_trans_bc = linalg.broadcast
+///                   ins(%x_trans : tensor<9x7x8xf32>)
+///                   outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+///   %2 = linalg.div
+///           ins(%x_trans_bc, %y :
+///                  tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
+///           outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+///
+/// Note that linalg.generic has been 'specialized' to linalg.div.
+///
+/// To unfold it, it is more optimal to transpose first and then do the
+/// broadcast. However, if transpose is done first, the permutation map needs
+/// to be expressed in terms of reduced dimension as broadcast hasn't happened
+/// yet. Also, the broadcast dimensions in a linalg.generic come from other
+/// operands (those not broadcasted along that particular dimension). We work
+/// this out by computing the convex-polyhedron shape of the linalg.generic
+/// iteration space from shapes of all the operands, both inputs and outputs.
+///
+struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// For the given `map`, determine what dimensions are transposed and what
+/// dimensions are broadcasted.
+/// Returns :
+///   transpose-permutation, broadcast-dimensions` (empty if not needed)
+///
+std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
+computeTransposeBroadcast(AffineMap &map) {
+  assert(map.isProjectedPermutation(false) && "not a projection");
+
+  // As the map is a projection it likely operates on a smaller set of
+  // dimensions as far as the transpose is concerned (rest are broadcast).
+  int64_t minorSize = map.getNumResults();
+
+  SmallVector<int64_t> minorResult;
+  for (int64_t i = 0; i < minorSize; ++i) {
+    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
+    minorResult.push_back(expr.getPosition());
+  }
+
+  // If dims are not monotonically increasing then transpose is present.
+  SmallVector<int64_t> sortedResMap(minorResult);
+  std::sort(sortedResMap.begin(), sortedResMap.end());
+  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
+                                  sortedResMap.begin(), sortedResMap.end());
+
+  // Walk the sorted map result to determine which dimensions are broadcasted.
+  SmallVector<int64_t> broadcast;
+  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
+    if (j < minorSize && sortedResMap[j] == i) {
+      j++;
+      continue;
+    }
+    broadcast.push_back(i);
+  }
+
+  SmallVector<int64_t> permutation;
+  if (hasTranspose) {
+    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
+    // `x`s access is both transposed and broadcast. But when specifying
+    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
+    // specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
+    // refering to d3, d4. Therefore, re-base the transpose dimensions so
+    // that they start from d0.
+    permutation.resize(minorSize);
+    std::map<int64_t, int64_t> minorMap;
+    for (int64_t i = 0; i < minorSize; ++i)
+      minorMap.insert({sortedResMap[i], i});
+
+    // Re-map the dimensions.
+    SmallVector<int64_t> remappedResult(minorSize);
+    for (int64_t i = 0; i < minorSize; ++i)
+      remappedResult[i] = minorMap[minorResult[i]];
+
+    /// Calculate the permutation for the transpose.
+    for (unsigned i = 0; i < minorSize; ++i) {
+      permutation[remappedResult[i]] = i;
+    }
+  }
+  return {permutation, broadcast};
+}
+
+LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
+    GenericOp op, PatternRewriter &rewriter) const {
+  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
+      op.isSingleYieldOp() || !op.isAllParallelLoops())
+    return failure();
+
+  // If the map of an operand is not a `projected permutation` then
+  // it cannot be decomposed to mere transpose and broadcast.
+  // The requirement that all maps be `projected permutation` may be
+  // over-restrictive but since we need to determine shape of the
+  // iteration space as well, reject if any map violates assumption.
+  for (auto &opOperand : op->getOpOperands()) {
+    auto map = op.getMatchingIndexingMap(&opOperand);
+    if (!map.isProjectedPermutation(false))
+      return failure();
+  }
+
+  // Decomposing linalg.generic involves creating `tensor.empty`
+  // which can have dynamic shapes but then we would have to work
+  // out which operand can supply that runtime-value (tensor.dim).
+  // Leaving it as a future TODO.
+  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
+        auto opType = cast<RankedTensorType>(oper.get().getType());
+        return ShapedType::isDynamicShape(opType.getShape());
+      }))
+    return failure();
+
+  auto outputShape = op.getStaticLoopRanges();
+
+  auto loc = op.getLoc();
+  bool isChanged = false;
+  SmallVector<Value> newInitValues = op.getDpsInputs();
+  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
+
+  // Walk over each input operand and unfold if it is transposed, broadcast
+  // or mix of two via operand's affine-map.
+  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
+    auto &map = newMap[i];
+    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
+    auto elType = inputRTType.getElementType();
+
+    /// Nothing to do if map is already an identity.
+    if (map.isIdentity())
+      continue;
+
+    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
+
+    // Does it need transpose?
+    if (!permutation.empty()) {
+      /// linalg.transpose permutes the dimensions of input using
+      /// rule: dim(result, i) = dim(input, permutation[i])
+      SmallVector<int64_t> transposedShape(map.getNumResults());
+      for (int64_t i = 0; i < map.getNumResults(); ++i)
+        transposedShape[i] = inputRTType.getShape()[permutation[i]];
+
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
+
+      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
+                                                      emptyTensor, permutation);
+      newInitValues[i] = transposeOp->getResult(0);
+      isChanged = true;
+    }
+
+    // Does it require broadcast?
+    if (!broadcastedDims.empty()) {
+      assert(broadcastedDims.size() && "should have non size broadcast");
+      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+          loc, outputShape, inputRTType.getElementType());
+
+      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
+          loc, newInitValues[i], emptyTensor, broadcastedDims);
+
+      newInitValues[i] = broadcastOp->getResult(0);
+      isChanged = true;
+    }
+    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
+  }
+
+  if (isChanged) {
+    SmallVector<Value> operands = op->getOperands();
+    ValueRange operandsRef(operands);
+
+    auto newOp = rewriter.create<linalg::GenericOp>(
+        /*location=*/op.getLoc(),
+        /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/newInitValues,
+        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+        /*indexingMaps=*/newMap,
+        /*iteratorTypes=*/op.getIteratorTypesArray());
+
+    newOp.getRegion().takeBody(op->getRegion(0));
+    rewriter.replaceOp(op, newOp->getResults());
+  }
+  return success();
+}
+
+} // namespace
+
+void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index dfafffce9d9b60..748e2a1377930d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
 void LinalgSpecializeGenericOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateLinalgGenericOpsSpecializationPatterns(patterns);
+  populateDecomposeProjectedPermutationPatterns(patterns);
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();

diff  --git a/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
new file mode 100644
index 00000000000000..38e406a13ec087
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+#identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
+func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y:  tensor<5x9x7x8x10xf32>, %z :  tensor<5x9x7x8x10xf32>) ->  tensor<5x9x7x8x10xf32> {
+  %res = linalg.generic
+     { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
+     ^bb0(%in: f32, %in_1: f32, %out: f32):
+       %div = arith.divf %in, %in_1 : f32
+       linalg.yield %div : f32
+  } -> tensor<5x9x7x8x10xf32>
+  return %res : tensor<5x9x7x8x10xf32>
+}
+
+// CHECK-LABEL: transpose_and_broadcast
+// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
+// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
+// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+
+func.func @transpose_only(%x : tensor<32x2x16xf32>, %y:  tensor<2x16x32xf32>, %z :  tensor<2x16x32xf32>) ->  tensor<2x16x32xf32> {
+  %res = linalg.generic
+     { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+     ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
+     outs(%z : tensor<2x16x32xf32>) {
+     ^bb0(%in: f32, %in_1: f32, %out: f32):
+       %div = arith.divf %in, %in_1 : f32
+       linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: transpose_only
+// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
+// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y:  tensor<2x32xf32>, %z :  tensor<2x16x32xf32>) ->  tensor<2x16x32xf32> {
+  %res = linalg.generic
+     { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+     ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
+     outs(%z : tensor<2x16x32xf32>) {
+     ^bb0(%in: f32, %in_1: f32, %out: f32):
+       %div = arith.divf %in, %in_1 : f32
+       linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: broadcast_only
+// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
+// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic


        


More information about the Mlir-commits mailing list