[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse ops with C… (PR #70653)

Amir Bishara llvmlistbot at llvm.org
Mon Oct 30 04:49:00 PDT 2023


https://github.com/amirBish created https://github.com/llvm/llvm-project/pull/70653

…anonicalized Identity maps

Support collapsing linalg ops whose OpOperands
have canonicalized identity indexing maps.

A canonicalized identity is an identity affine map
in which some results may be the constant 0, where each
such zero corresponds to a dimension of size 1 in the
operand's shape.

A common use case for this support is running
CollapseLinalgDimensions after TOSA-to-Linalg, since the latter generates linalg.generic ops with canonicalized identity maps (on which the rewrite pattern currently fails to match, because it supports only projected-permutation indexing maps).

>From 816eacade09b34ed261a0a075328cc681420c294 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Mon, 30 Oct 2023 13:35:36 +0200
Subject: [PATCH] [mlir][linalg] Enable CollapseLinalgDimensions to collapse
 ops with Canonicalized Identity maps

Support collapsing linalg ops whose OpOperands
have canonicalized identity indexing maps.

A canonicalized identity is an identity affine map
in which some results may be the constant 0, where each
such zero corresponds to a dimension of size 1 in the
operand's shape.

A common use case for this support is running
CollapseLinalgDimensions after TOSA-to-Linalg, since
the latter generates linalg.generic ops with
canonicalized identity maps (on which the rewrite
pattern currently fails to match, because it supports
only projected-permutation indexing maps).
---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 29 ++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  5 +-
 mlir/include/mlir/IR/AffineMap.h              | 11 ++++
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 58 +++++++++++++------
 mlir/lib/IR/AffineMap.cpp                     | 17 ++++++
 mlir/test/Dialect/Linalg/collapse-dim.mlir    | 32 ++++++++++
 6 files changed, 133 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 69ca888a8acdbe0..31efa35540b25e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -668,6 +668,35 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if the indexing map of `opOperand` is a canonicalized
+        identity with respect to the shape of `opOperand`.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isCanonicalizedIdentityMap",
+      /*args=*/(ins "OpOperand*": $opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+          auto indexingMap = $_op.getMatchingIndexingMap(opOperand);
+          return indexingMap.isCanonicalizedIdentity($_op.getShape(opOperand));
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if the indexing maps of all operands of this linalg
+        operation are canonicalized identities.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasOnlyCanonicalizedIdentityMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+          return llvm::all_of($_op->getOpOperands(), [&](OpOperand &opOperand) {
+            return $_op.isCanonicalizedIdentityMap(&opOperand);
+          });
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..b7c769ed3560ee8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1043,8 +1043,9 @@ splitReductionByScaling(RewriterBase &b, LinalgOp op,
 /// range of the specified indexing map.
 bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 /// Return `true` if all sequences of dimensions specified in `dimSequences` are
-/// contiguous in all the ranges of the `maps`.
-bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
+/// contiguous in the ranges of the indexing maps of all operands of `op`.
+template <typename LinalgType>
+bool areDimSequencesPreserved(LinalgType op,
                               ArrayRef<ReassociationIndices> dimSequences);
 
 /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..d446e1500845406 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -131,6 +131,17 @@ class AffineMap {
   /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
   bool isMinorIdentity() const;
 
+  /// Returns true if this affine map is a canonicalized identity with respect
+  /// to `shape`, and false otherwise.
+  /// A canonicalized identity is an identity map on the dimensional
+  /// identifiers in which some results may be replaced by the constant 0.
+  /// Each such zero constant must sit at a position where `shape` is 1, and
+  /// the number of dims, the number of results and `shape.size()` must all
+  /// be equal.
+  /// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d1, d2, d3, d4)>
+  /// is considered a canonicalized identity if `shape[0] == 1`.
+  bool isCanonicalizedIdentity(ArrayRef<int64_t> shape) const;
+
   /// Returns true if this affine map is a minor identity up to broadcasted
   /// dimensions which are indicated by value 0 in the result. If
   /// `broadcastedDims` is not null, it will be populated with the indices of
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 32d38a21e4e00f4..e2bdbebb831e5c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1054,12 +1054,33 @@ bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
   // 3. No element of sequence found. Return true.
   return true;
 }
 
+template <typename LinalgType>
 bool mlir::linalg::areDimSequencesPreserved(
-    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
-  return llvm::all_of(maps, [&](AffineMap map) {
+    LinalgType op, ArrayRef<ReassociationIndices> dimSequences) {
+  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+  return llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+    AffineMap map = op.getMatchingIndexingMap(&opOperand);
+    // A canonicalized identity only behaves like the identity when each of
+    // its zero-constant results sits on a loop of static extent 1. Otherwise
+    // the operand is broadcast along that loop, and its unit dim cannot be
+    // collapsed as if the map were the identity.
+    auto isUnitCanonicalizedIdentity = [&]() {
+      if (!op.isCanonicalizedIdentityMap(&opOperand))
+        return false;
+      for (auto [pos, expr] : llvm::enumerate(map.getResults()))
+        if (expr.isa<AffineConstantExpr>() && loopRanges[pos] != 1)
+          return false;
+      return true;
+    };
+    if (isUnitCanonicalizedIdentity())
+      return true;
+    // The collapsing helpers assert that every other map is a projected
+    // permutation; reject such maps here instead of tripping those asserts.
+    if (!map.isProjectedPermutation())
+      return false;
     return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
       return isDimSequencePreserved(map, dimSequence);
     });
   });
 }
@@ -1320,17 +1341,31 @@ getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
 
 /// Compute the indexing map in the collapsed op that corresponds to the given
-/// `indexingMap` of the original operation.
+/// indexing map of `opOperand` in the original operation `op`.
+template <typename LinalgType>
 static AffineMap
-getCollapsedOpIndexingMap(AffineMap indexingMap,
+getCollapsedOpIndexingMap(LinalgType op, OpOperand &opOperand,
                           const CollapsingInfo &collapsingInfo) {
+  auto indexingMap = op.getMatchingIndexingMap(&opOperand);
   MLIRContext *context = indexingMap.getContext();
-  assert(indexingMap.isProjectedPermutation() &&
-         "expected indexing map to be projected permutation");
+  assert((op.isCanonicalizedIdentityMap(&opOperand) ||
+          indexingMap.isProjectedPermutation()) &&
+         "expected indexing map to be projected permutation or canonicalized "
+         "identity");
   SmallVector<AffineExpr> resultExprs;
   auto origOpToCollapsedOpMapping =
       collapsingInfo.getOrigOpToCollapsedOpMapping();
-  for (auto expr : indexingMap.getResults()) {
-    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+  for (auto pair : llvm::enumerate(indexingMap.getResults())) {
+    AffineExpr expr = pair.value();
+    // Zero constants only occur in canonicalized identities, where result
+    // `i` stands for iteration dim `i`; other results are dim expressions.
+    unsigned dim = pair.index();
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      assert(!constExpr.getValue() &&
+             "expected zero constants in canonicalized identity");
+    } else {
+      dim = expr.cast<AffineDimExpr>().getPosition();
+    }
+
     // If the dim is not the first of the collapsed dim, do nothing.
     if (origOpToCollapsedOpMapping[dim].second != 0)
       continue;
@@ -1354,9 +1389,17 @@ getOperandReassociation(AffineMap indexingMap,
       collapsingInfo.getOrigOpToCollapsedOpMapping();
   auto collapsedOpToOrigOpMapping =
       collapsingInfo.getCollapsedOpToOrigOpMapping();
   while (counter < indexingMap.getNumResults()) {
-    unsigned dim =
-        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
+    AffineExpr expr = indexingMap.getResult(counter);
+    // Zero constants only occur in canonicalized identities, where result
+    // `counter` stands for iteration dim `counter`.
+    unsigned dim = counter;
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      assert(!constExpr.getValue() &&
+             "expected zero constants in canonicalized identity");
+    } else {
+      dim = expr.cast<AffineDimExpr>().getPosition();
+    }
     // This is the start of a collapsed dimensions of the iteration that
     // is gauranteed to be preserved in the indexing map. The number of folded
     // dims is obtained from the collapsed op to original op mapping.
@@ -1480,10 +1523,11 @@ Operation *createCollapsedOp(LinalgType op,
       getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
-  auto indexingMaps =
-      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
-        return getCollapsedOpIndexingMap(map, collapsingInfo);
-      });
+  auto indexingMaps = llvm::to_vector(
+      llvm::map_range(op->getOpOperands(), [&](OpOperand &opOperand) {
+        return getCollapsedOpIndexingMap<LinalgType>(op, opOperand,
+                                                     collapsingInfo);
+      }));
 
   Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
       loc, resultTypes, inputOperands, outputOperands, indexingMaps,
@@ -1659,8 +1703,7 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
       return failure();
 
     // Check if the specified list of dimensions to collapse is a valid list.
-    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
-                                  collapsableIterationDims)) {
+    if (!areDimSequencesPreserved<LinalgType>(op, collapsableIterationDims)) {
       return rewriter.notifyMatchFailure(
           op, "specified dimensions cannot be collapsed");
     }
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..a10ffb7bdd2b3b0 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -133,6 +133,23 @@ bool AffineMap::isMinorIdentity() const {
              getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
+bool AffineMap::isCanonicalizedIdentity(ArrayRef<int64_t> shape) const {
+  // Dims, results and shape entries must be in one-to-one correspondence.
+  if (getNumDims() != getNumResults() || getNumDims() != shape.size())
+    return false;
+  for (auto [pos, expr] : llvm::enumerate(getResults())) {
+    // A zero constant may stand in for a dimension of size 1.
+    if (auto cst = expr.dyn_cast<AffineConstantExpr>();
+        cst && cst.getValue() == 0 && shape[pos] == 1)
+      continue;
+    // Any other result must be exactly the dim at its own position.
+    auto dimExpr = expr.dyn_cast<AffineDimExpr>();
+    if (!dimExpr || dimExpr.getPosition() != pos)
+      return false;
+  }
+  return true;
+}
+
 /// Returns true if this affine map is a minor identity up to broadcasted
 /// dimensions which are indicated by value 0 in the result.
 bool AffineMap::isMinorIdentityWithBroadcasting(
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f53387477..ed375ce703b41ff 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -153,3 +153,35 @@ func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: me
   linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
   return
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @collapse_canonicalized_identity(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<2x2x1x4096xf32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<2x2x1x4096xf32>) -> tensor<2x2x1x4096xf32> {
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x1x4096xf32> into tensor<2x2x4096xf32>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x1x4096xf32> into tensor<2x2x4096xf32>
+// CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_2]] : tensor<2x2x4096xf32>) outs(%[[VAL_3]] : tensor<2x2x4096xf32>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           } -> tensor<2x2x4096xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_4]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x4096xf32> into tensor<2x2x1x4096xf32>
+// CHECK:           return %[[VAL_8]] : tensor<2x2x1x4096xf32>
+// CHECK:         }
+
+
+func.func @collapse_canonicalized_identity(
+    %arg0: tensor<2x2x1x4096xf32>, %arg1: tensor<2x2x1x4096xf32>) -> tensor<2x2x1x4096xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, d3)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  ins(%arg0 : tensor<2x2x1x4096xf32>) outs(%arg1 : tensor<2x2x1x4096xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %1 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> tensor<2x2x1x4096xf32>
+  return %0 : tensor<2x2x1x4096xf32>
+}



More information about the Mlir-commits mailing list