[Mlir-commits] [mlir] [mlir] Add inferContractionDims util for indexing map inputs (PR #76081)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 2 13:46:20 PST 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/76081

>From 8dd51b80995b79270b89d9d6b05382138b63e2bf Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 20 Dec 2023 12:09:23 -0500
Subject: [PATCH] [mlir] Add inferContractionDims util for indexing map inputs

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  2 +
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 91 +++++++++++++------
 2 files changed, 65 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..f92843a1dcb987 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -62,6 +62,14 @@ struct ContractionDimensions {
 /// `k`, indices are returned in sorted order.
 /// Returns a failure if any of `m`, `n` or `k` is empty.
 FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+/// Same as above, but operates directly on the indexing maps of a
+/// contraction-like op, given in (lhs, rhs, out) order; all maps are expected
+/// to have the same number of dims. Iterator types are inferred from the out
+/// map: dims used by its results are parallel, the rest are reductions.
+/// Returns a failure if there are not exactly 3 maps or if the out map is not
+/// a projected permutation.
+FailureOr<ContractionDimensions>
+inferContractionDims(ArrayRef<AffineMap> indexingMaps);
 
 /// Checks whether `linalgOp` conforms to ContractionOpInterface.
 // TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..291784072896c3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) {
   return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
 }
 
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
+/// Given an `indexingMap` and its corresponding `iterators`, returns
+/// the positions of the iterators of type `iter` that are indexed by
+/// the `indexingMap` as a permutation. This is useful to infer various
+/// subcomputations on a `LinalgOp`. This is performed by looking up
+/// each result in the `indexingMap` and determining whether:
 ///   - It is a single AffineDimExpr.
 ///   - It is the only result involving this AffineDimExpr.
 static llvm::SmallDenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+findPermutationsIndexingOperand(AffineMap indexingMap,
+                                ArrayRef<utils::IteratorType> iterators,
                                 utils::IteratorType iter) {
+  assert(iterators.size() == indexingMap.getNumDims());
   llvm::SmallDenseSet<int64_t> res;
-  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
-  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
   for (AffineExpr e : indexingMap.getResults()) {
     if (auto d = dyn_cast<AffineDimExpr>(e)) {
-      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+      if (iterators[d.getPosition()] == iter &&
           llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
             return e.isFunctionOfDim(d.getPosition());
           }) == 1)
@@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel;
 auto red = utils::IteratorType::reduction;
 } // namespace
 
+/// Derive iterator types from the output (init) indexing map: a dim that
+/// appears among the map's results is parallel, any other dim is a reduction.
+/// Fails when `map` is not a projected permutation.
+static FailureOr<SmallVector<utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+  if (!map.isProjectedPermutation())
+    return failure();
+  // Default every dim to reduction; promote dims the output map uses.
+  SmallVector<utils::IteratorType> iterTypes(map.getNumDims(), red);
+  for (AffineExpr result : map.getResults())
+    if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
+      iterTypes[dimExpr.getPosition()] = par;
+  return iterTypes;
+}
+
 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
 ///   1. The m dimension is involved in an outer-product along LHS
@@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction;
 ///   5. Optional batch dimensions that appear in all operands are captured.
 /// This allows e.g. detecting that some contraction is embedded within
 /// `linalgOp` with some orthogonal heuristic.
-FailureOr<ContractionDimensions>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
-  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
-    return failure();
-
-  llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), par);
-  llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), par);
-  llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInitOperand(0), par);
+static FailureOr<ContractionDimensions>
+inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
+                         ArrayRef<utils::IteratorType> iterators) {
+  llvm::SmallDenseSet<int64_t> a =
+      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
+  llvm::SmallDenseSet<int64_t> b =
+      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
+  llvm::SmallDenseSet<int64_t> c =
+      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
 
   // A & C - B are the iterators involved in an outer-product along A (the LHS).
   llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
   llvm::set_intersect(batches, c);
 
   // A & B red are the reduction dimensions.
-  llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), red);
-  llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), red);
+  llvm::SmallDenseSet<int64_t> ra =
+      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
+  llvm::SmallDenseSet<int64_t> rb =
+      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
   llvm::set_intersect(ra, rb);
 
   // Return each set in sorted order.
@@ -262,6 +275,36 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
   return dimensions;
 }
 
+/// Entry point for a LinalgOp: only ops with exactly two inputs and one init
+/// are considered, and the op's own indexing maps and iterator types are used.
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+    return failure();
+  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
+                                  linalgOp.getIteratorTypesArray());
+}
+
+/// Entry point for bare (lhs, rhs, out) indexing maps. Iterator types are
+/// inferred from the out map by `inferIteratorsFromOutMap`. All three maps
+/// must have the same number of dims because the permutation analysis indexes
+/// one shared iterator array by dim position; a mismatch is rejected here
+/// instead of tripping the assertion in `findPermutationsIndexingOperand`.
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return failure();
+  AffineMap outMap = indexingMaps[2];
+  if (llvm::any_of(indexingMaps, [&](AffineMap map) {
+        return map.getNumDims() != outMap.getNumDims();
+      }))
+    return failure();
+  auto iterators = inferIteratorsFromOutMap(outMap);
+  if (failed(iterators))
+    return failure();
+  return inferContractionDimsImpl(indexingMaps, iterators.value());
+}
+
 namespace mlir::linalg::detail {
 enum class MatchContractionResult {
   Success = 0,
@@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions>
 inferConvolutionDimsImpl(LinalgOp linalgOp,
                          ConvAccessExprWalker &inputExprWalker,
                          bool allowEmptyConvolvedDims) {
+  auto filterMap =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
+  auto outputMap =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
   llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), par);
+      filterMap, linalgOp.getIteratorTypesArray(), par);
   llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInitOperand(0), par);
+      outputMap, linalgOp.getIteratorTypesArray(), par);
 
   // unConvolvedDims & outputDims - filterDims are the batch iterators.
   llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
@@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
   llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
 
   llvm::SmallDenseSet<int64_t> filterReducedDims =
-      findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
-                                      red);
+      findPermutationsIndexingOperand(filterMap,
+                                      linalgOp.getIteratorTypesArray(), red);
 
   // convolvedDims & filterReducedDims are the filter loop iterators.
   llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;



More information about the Mlir-commits mailing list