[Mlir-commits] [mlir] [mlir] Add inferContractionDims util for indexing map inputs (PR #76081)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 21 07:18:59 PST 2023
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/76081
>From 139ade6f5d17f50907229bbc0f08b9d03ab47bba Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 20 Dec 2023 12:09:23 -0500
Subject: [PATCH 1/2] [mlir] Add inferContractionDims util for indexing map
inputs
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 2 +
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 95 ++++++++++++++-----
2 files changed, 71 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..f92843a1dcb987 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -62,6 +62,14 @@ struct ContractionDimensions {
/// `k`, indices are returned in sorted order.
/// Returns a failure if any of `m`, `n` or `k` is empty.
FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+/// Variant of the above that works directly on the indexing maps of a
+/// contraction, ordered as (lhs, rhs, init). Instead of reading iterator types
+/// from an op, dims that appear in the results of the init map are treated as
+/// parallel and all remaining dims as reduction. Returns a failure if
+/// `indexingMaps` does not hold exactly 3 maps or if the init map is not a
+/// projected permutation.
+FailureOr<ContractionDimensions>
+inferContractionDims(ArrayRef<AffineMap> indexingMaps);
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..78a13017ae5c3e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -176,22 +176,14 @@ static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
-/// - It is a single AffineDimExpr.
-/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
- utils::IteratorType iter) {
+findPermutationsIndexingOperandImpl(AffineMap indexingMap,
+ ArrayRef<utils::IteratorType> iterators,
+ utils::IteratorType iter) {
llvm::SmallDenseSet<int64_t> res;
- assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = dyn_cast<AffineDimExpr>(e)) {
- if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+ if (iterators[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
@@ -201,11 +193,48 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
return res;
}
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+/// - It is a single AffineDimExpr.
+/// - It is the only result involving this AffineDimExpr.
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+ utils::IteratorType iter) {
+ assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+ return findPermutationsIndexingOperandImpl(
+ linalgOp.getMatchingIndexingMap(opOperand),
+ linalgOp.getIteratorTypesArray(), iter);
+}
+
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(AffineMap indexingMap,
+ ArrayRef<utils::IteratorType> iterators,
+ utils::IteratorType iter) {
+ return findPermutationsIndexingOperandImpl(indexingMap, iterators, iter);
+}
+
namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
+/// Infer the iterator types from the init affine map. This looks at which dims
+/// are present in the map results, and returns an iterator types array with
+/// parallel types for dims that are present, and reduction types for dims that
+/// are not present.
+static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) {
+ SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
+ for (auto expr : map.getResults()) {
+ if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+ iterators[dim.getPosition()] = par;
+ }
+ }
+ return iterators;
+}
+
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
@@ -218,16 +247,14 @@ auto red = utils::IteratorType::reduction;
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
FailureOr<ContractionDimensions>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
- return failure();
-
- llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), par);
- llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
- llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
+inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<utils::IteratorType> iterators) {
+ llvm::SmallDenseSet<int64_t> a =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
+ llvm::SmallDenseSet<int64_t> b =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
+ llvm::SmallDenseSet<int64_t> c =
+ findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +270,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
llvm::set_intersect(batches, c);
// A & B red are the reduction dimensions.
- llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), red);
- llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), red);
+ llvm::SmallDenseSet<int64_t> ra =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
+ llvm::SmallDenseSet<int64_t> rb =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);
// Return each set in sorted order.
@@ -262,6 +289,22 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
return dimensions;
}
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+ return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
+ linalgOp.getIteratorTypesArray());
+}
+
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return failure();
+ return inferContractionDimsImpl(indexingMaps,
+ inferIteratorsFromOutMap(indexingMaps[2]));
+}
+
namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
>From 1a6b0b51ef804d271180769f119ed552a08c599a Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 21 Dec 2023 10:17:39 -0500
Subject: [PATCH 2/2] address comments
---
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 34 ++++++++++++-------
1 file changed, 22 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 78a13017ae5c3e..b8d9d5fe567a73 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,9 +15,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
@@ -176,6 +178,13 @@ static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+/// - It is a single AffineDimExpr.
+/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperandImpl(AffineMap indexingMap,
ArrayRef<utils::IteratorType> iterators,
@@ -193,13 +202,6 @@ findPermutationsIndexingOperandImpl(AffineMap indexingMap,
return res;
}
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
-/// - It is a single AffineDimExpr.
-/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
utils::IteratorType iter) {
@@ -213,6 +215,7 @@ static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
ArrayRef<utils::IteratorType> iterators,
utils::IteratorType iter) {
+ assert(iterators.size() == indexingMap.getNumDims());
return findPermutationsIndexingOperandImpl(indexingMap, iterators, iter);
}
@@ -225,14 +228,19 @@ auto red = utils::IteratorType::reduction;
/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
-static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) {
+/// Returns a failure if `map` is not a projected permutation. The result is
+/// returned by value since the iterator types are built in a local vector.
+static FailureOr<SmallVector<utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+  if (!map.isProjectedPermutation())
+    return failure();
SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
for (auto expr : map.getResults()) {
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
iterators[dim.getPosition()] = par;
}
}
return iterators;
}
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
@@ -246,7 +254,7 @@ static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) {
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
-FailureOr<ContractionDimensions>
+static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iterators) {
llvm::SmallDenseSet<int64_t> a =
@@ -301,8 +309,16 @@ FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
if (indexingMaps.size() != 3)
return failure();
-  return inferContractionDimsImpl(indexingMaps,
-                                  inferIteratorsFromOutMap(indexingMaps[2]));
+  // Iterator types are inferred from the init map alone, so every map must
+  // range over the same dims for the input maps' dims to index them safely.
+  unsigned numDims = indexingMaps[2].getNumDims();
+  if (indexingMaps[0].getNumDims() != numDims ||
+      indexingMaps[1].getNumDims() != numDims)
+    return failure();
+  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
+  if (failed(iterators))
+    return failure();
+  return inferContractionDimsImpl(indexingMaps, iterators.value());
}
namespace mlir::linalg::detail {
More information about the Mlir-commits
mailing list