[Mlir-commits] [mlir] [mlir][linalg] unfold projected permutation. (PR #114704)

Javed Absar llvmlistbot at llvm.org
Sun Nov 3 04:18:59 PST 2024


https://github.com/javedabsar1 created https://github.com/llvm/llvm-project/pull/114704

Identify folded 'projected permutations' (i.e. a mixture of transpose and/or broadcast) in linalg generic operands.
The 'projected permutations' are unfolded into separate linalg.transpose and linalg.broadcast ops so that the generic
operates on identity maps only, which is necessary to replace the generic with a named op.
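
For example, given an operand `%x : tensor<7x8x9xf32>` accessed through the projected permutation
`affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`, the rewrite makes the transpose and broadcast
explicit (this worked example is taken from the pattern documentation and tests in this patch):

```mlir
%transposed  = linalg.transpose ins(%x : tensor<7x8x9xf32>)
                 outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
%broadcasted = linalg.broadcast ins(%transposed : tensor<9x7x8xf32>)
                 outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
```

The generic then operates on identity maps only and can be replaced by a named op such as linalg.div.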



>From cdf865ca4c8b4e67b03de744b3c2540c34cd8082 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Wed, 16 Oct 2024 13:49:59 -0400
Subject: [PATCH] [mlir][linalg] unfold projected permutation.

Identify folded 'projected permutations' (i.e. a mixture of
transpose and/or broadcast) in linalg generic operands.
The 'projected permutations' are unfolded into separate
linalg.transpose and linalg.broadcast ops so that the generic
operates on identity maps only, which is necessary to
replace the generic with a named op.
---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |   2 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Dialect/Linalg/Transforms/Specialize.cpp  |   1 +
 .../Transforms/UnfoldProjectedPermutation.cpp | 270 ++++++++++++++++++
 .../Linalg/unfold_projected_permutation.mlir  |  71 +++++
 6 files changed, 349 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
 create mode 100644 mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 0a404194569c22..b81a4c9c8760cf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -252,7 +252,7 @@ def LinalgStructuredInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumParallelLoops() ==  getNumParallelLoops();
+        return getNumParallelLoops() ==  getNumLoops();
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0693e31b4f70af..a110eb88e9f699 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1786,6 +1786,11 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 /// linalg.fill(%cst, tensor.extract_slice(%init)).
 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
+
+/// Add patterns to make explicit the broadcasts and transposes folded
+/// into the input operands of a genericOp.
+void populateUnfoldProjectedPermutationPatterns(RewritePatternSet &patterns);
+
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d7c63cdd8198d7..dfe6d7a54c8f14 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TilingInterfaceImpl.cpp
   Transforms.cpp
   TransposeConv2D.cpp
+  UnfoldProjectedPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index dfafffce9d9b60..a911286d5d44b2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
 void LinalgSpecializeGenericOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateLinalgGenericOpsSpecializationPatterns(patterns);
+  populateUnfoldProjectedPermutationPatterns(patterns);
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
new file mode 100644
index 00000000000000..56d6bd23b2343a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
@@ -0,0 +1,270 @@
+//===- UnfoldProjectedPermutation.cpp - unfold projected permutations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pattern to decompose an operand of a GenericOp
+// whose affine map folds in a `transpose+broadcast` into separate
+// transpose and broadcast ops.
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include <utility>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include <map>
+#include <optional>
+#include <vector>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// A projected permutation is effectively the folding of a mixture of
+/// transpose and broadcast into the affine map of an operand. While
+/// folding transpose and broadcast into the affine map of a
+/// linalg.generic operand is a very effective optimization, sometimes
+/// we may want to unfold it, for instance when recognizing named ops.
+///
+/// Example:
+///
+/// ```mlir
+///
+/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+/// ...
+///    %res = linalg.generic
+///       { indexing_maps = [#projection, #identity, #identity],
+///       iterator_types = ["parallel", "parallel", "parallel",
+///                         "parallel", "parallel"]}
+///       ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
+///       outs(%z : tensor<5x9x7x8x10xf32>) {
+///         ^bb0(%in: f32, %in_1: f32, %out: f32):
+///              %div = arith.divf %in, %in_1 : f32
+///              linalg.yield %div : f32
+///    } -> tensor<5x9x7x8x10xf32>
+/// ```
+///
+/// In the above IR, the map of operand `%x` is a projected permutation.
+/// This can be unfolded as:
+///
+/// ```mlir
+///   ...
+///   %transposed = linalg.transpose ins(%x : tensor<7x8x9xf32>)
+///                    outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+///   ...
+///   %broadcasted = linalg.broadcast ins(%transposed : tensor<9x7x8xf32>)
+///                    outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+///   %2 = linalg.div
+///           ins(%broadcasted, %y :
+///                  tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
+///           outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+/// ```
+/// Note that the linalg.generic has been 'specialized' to linalg.div.
+/// When unfolding, it is more effective to transpose first and then
+/// broadcast. However, if the transpose is done first, the permutation
+/// map needs to be expressed in terms of the reduced dimensions (as the
+/// broadcast has not happened yet). Also, the broadcast dimensions of a
+/// linalg.generic come from the other operands (those not broadcast
+/// along that dimension). We work this out by computing the polytope
+/// shape of the linalg.generic from the shapes of all operands.
+
+struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Calculate the shape (dimensions) of the iteration-space polytope.
+/// This is computed by concatenating the indexing maps of all operands
+/// of the generic, inverting the concatenation, concatenating the shapes
+/// of all operands, and then applying the inverted map to those shapes.
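+///
+/// For the generic in the example above, the concatenated map results are
+/// (d2, d3, d1, d0, d1, d2, d3, d4, d0, d1, d2, d3, d4); keeping the first
+/// occurrence of each dimension, the inverse sends d0 -> position 3,
+/// d1 -> 2, d2 -> 0, d3 -> 1, d4 -> 7. Applied to the concatenated shapes
+/// [7, 8, 9, 5, 9, 7, 8, 10, 5, 9, 7, 8, 10] this yields [5, 9, 7, 8, 10].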
+SmallVector<int64_t> getPolytopeDims(GenericOp op) {
+  assert(op.hasPureTensorSemantics() && "works only on tensors");
+
+  // Concat the indexing maps of all operands and invert the mapping.
+  auto maps = op.getIndexingMapsArray();
+  auto concat = concatAffineMaps(maps);
+  auto inverse = inversePermutation(concat);
+
+  // Concat the dimension sizes of all operands.
+  SmallVector<int64_t> dims;
+  for (auto &operand : op->getOpOperands()) {
+    auto rankedType = cast<RankedTensorType>(operand.get().getType());
+    for (auto size : rankedType.getShape())
+      dims.push_back(size);
+  }
+
+  // Match the inverse map with dims to get the polytope dimensions.
+  // Note that some may be ShapedType::kDynamic.
+  return applyPermutationMap<int64_t>(inverse, dims);
+}
+
+/// For the given `map`, determine which dimensions are transposed and
+/// which dimensions are broadcast.
+/// Returns:
+///   `hasTranspose, hasBroadcast,
+///    transpose-permutation, broadcast-dimensions`
+///
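+/// E.g. for `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>` the minor
+/// result is [2, 3, 1] and its sorted form is [1, 2, 3], so a transpose
+/// is present; the broadcast dimensions are [0, 4]; the re-based result
+/// is [1, 2, 0]; and the returned transpose permutation is [2, 0, 1].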
+std::tuple<bool, bool, SmallVector<int64_t>, SmallVector<int64_t>>
+computeTransposeBroadcast(AffineMap &map) {
+  assert(map.isProjectedPermutation(false) && "not a projected permutation");
+
+  // Dimensions that don't appear in the result are broadcast.
+  int64_t minorSize = map.getNumResults();
+
+  // Convert affine expr to int64_t.
+  SmallVector<int64_t> minorResult;
+  for (int64_t i = 0; i < minorSize; ++i) {
+    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
+    minorResult.push_back(expr.getPosition());
+  }
+
+  // If the dims are not monotonically increasing then a transpose is present.
+  SmallVector<int64_t> sorted(minorResult);
+  std::sort(sorted.begin(), sorted.end());
+  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
+                                  sorted.begin(), sorted.end());
+
+  // Walk the sorted map result to determine which dimensions are broadcast.
+  SmallVector<int64_t> broadcast;
+  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
+    if (j < minorSize && sorted[j] == i) {
+      j++;
+      continue;
+    }
+    broadcast.push_back(i);
+  }
+  bool hasBroadcast = !broadcast.empty();
+
+  // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
+  // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
+  // `x`'s access is both transposed and broadcast. But when specifying
+  // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
+  // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
+  // referring to d3, d4. Therefore, re-base the transpose dimensions so
+  // that they start from d0.
+  std::map<int64_t, int64_t> minorMap;
+  for (int64_t i = 0; i < minorSize; ++i)
+    minorMap.insert({sorted[i], i});
+
+  // Re-map the dimensions.
+  SmallVector<int64_t> remappedResult(minorSize);
+  for (int64_t i = 0; i < minorSize; ++i)
+    remappedResult[i] = minorMap[minorResult[i]];
+
+  // Calculate the permutation for the transpose.
+  SmallVector<int64_t> permutation(minorSize);
+  for (int64_t i = 0; i < minorSize; ++i) {
+    permutation[remappedResult[i]] = i;
+  }
+
+  return {hasTranspose, hasBroadcast, permutation, broadcast};
+}
+
+LogicalResult
+UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
+                                            PatternRewriter &rewriter) const {
+  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
+      op.isSingleYieldOp() || !op.isAllParallelLoops())
+    return failure();
+
+  // All maps need to be projected permutations.
+  for (auto &opOperand : op->getOpOperands()) {
+    auto map = op.getMatchingIndexingMap(&opOperand);
+    if (!map.isProjectedPermutation(false))
+      return failure();
+  }
+
+  // Currently we handle only static shapes.
+  for (auto &operand : op->getOpOperands()) {
+    auto rankedType = cast<RankedTensorType>(operand.get().getType());
+    for (auto size : rankedType.getShape())
+      if (size == ShapedType::kDynamic)
+        return failure();
+  }
+
+  // Calculate the polytope bounds from the affine maps and operand shapes.
+  auto polytope = getPolytopeDims(op);
+
+  auto loc = op.getLoc();
+  bool isChanged = false;
+  SmallVector<Value> newInitValues = op.getDpsInputs();
+  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();
+
+  // Walk over each input operand and unfold it if its affine map encodes
+  // a transpose, a broadcast, or a mix of the two.
+  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
+    auto &map = newMap[i];
+    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
+    auto elType = inputRTType.getElementType();
+
+    // Nothing to do if the map is already the identity.
+    if (map.isIdentity())
+      continue;
+
+    auto [hasTranspose, hasBroadcast, permutation, broadcastedDims] =
+        computeTransposeBroadcast(map);
+
+    if (hasTranspose) {
+      // linalg.transpose permutes the dimensions of the input using the
+      // rule: dim(result, i) = dim(input, permutation[i]).
+      SmallVector<int64_t> transposedShape(map.getNumResults());
+      for (int64_t j = 0; j < map.getNumResults(); ++j)
+        transposedShape[j] = inputRTType.getShape()[permutation[j]];
+
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
+
+      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
+                                                      emptyTensor, permutation);
+      newInitValues[i] = transposeOp->getResult(0);
+      isChanged = true;
+    }
+
+    // Does it require a broadcast?
+    if (hasBroadcast) {
+      assert(!broadcastedDims.empty() && "should have broadcast dimensions");
+      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+          loc, polytope, inputRTType.getElementType());
+
+      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
+          loc, newInitValues[i], emptyTensor, broadcastedDims);
+
+      newInitValues[i] = broadcastOp->getResult(0);
+      isChanged = true;
+    }
+    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
+  }
+
+  if (isChanged) {
+    SmallVector<Value> operands = op->getOperands();
+    ValueRange operandsRef(operands);
+
+    auto newOp = rewriter.create<linalg::GenericOp>(
+        /*location=*/op.getLoc(),
+        /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/newInitValues,
+        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
+        /*indexingMaps=*/newMap,
+        /*iteratorTypes=*/op.getIteratorTypesArray());
+
+    newOp.getRegion().takeBody(op->getRegion(0));
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+  // Nothing changed; signal match failure so the driver does not loop.
+  return failure();
+}
+
+} // namespace
+
+void mlir::linalg::populateUnfoldProjectedPermutationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<UnfoldProjectedPermutation>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir b/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
new file mode 100644
index 00000000000000..4efa07b2de12e3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/unfold_projected_permutation.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
+#identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
+func.func @test_mixed(%x : tensor<7x8x9xf32>, %y : tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+  %res = linalg.generic
+     { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
+     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
+     ^bb0(%in: f32, %in_1: f32, %out: f32):
+       %div = arith.divf %in, %in_1 : f32
+       linalg.yield %div : f32
+  } -> tensor<5x9x7x8x10xf32>
+  return %res : tensor<5x9x7x8x10xf32>
+}
+
+// CHECK-LABEL: test_mixed
+// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
+// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
+// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Transposed]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
+// CHECK: {{.*}} = linalg.div ins(%[[Broadcasted]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+
+func.func @test_transposed(%x : tensor<32x2x16xf32>, %y : tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %res = linalg.generic
+     { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+     ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
+     outs(%z : tensor<2x16x32xf32>) {
+     ^bb0(%in: f32, %in_1: f32, %out: f32):
+       %div = arith.divf %in, %in_1 : f32
+       linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: test_transposed
+// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[Transposed:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
+// CHECK: {{.*}} = linalg.div ins(%[[Transposed]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @test_broadcast(%x : tensor<2x16x32xf32>, %y : tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %res = linalg.generic
+     { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
+     ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
+     outs(%z : tensor<2x16x32xf32>) {
+     ^bb0(%in: f32, %in_1: f32, %out: f32):
+       %div = arith.divf %in, %in_1 : f32
+       linalg.yield %div : f32
+  } -> tensor<2x16x32xf32>
+  return %res : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: test_broadcast
+// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[Broadcasted:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
+// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[Broadcasted]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK-NOT: linalg.generic
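
These patterns are exercised through `mlir-opt --linalg-specialize-generic-ops`, as in the RUN line above. For downstream use outside that pass, a minimal sketch of wiring in the new entry point might look like the following (the helper name `runUnfoldProjectedPermutation` is hypothetical; only `populateUnfoldProjectedPermutationPatterns` comes from this patch):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <utility>

using namespace mlir;

// Hypothetical helper: greedily applies the unfold patterns under `root`,
// mirroring what LinalgSpecializeGenericOpsPass does in this patch.
static LogicalResult runUnfoldProjectedPermutation(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateUnfoldProjectedPermutationPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```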


