[Mlir-commits] [mlir] f8b8aff - [mlir][Linalg] Add support for inferring batch dimensions
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jun 27 05:30:16 PDT 2023
Author: Nicolas Vasilache
Date: 2023-06-27T12:29:51Z
New Revision: f8b8affc4a6a5df346b1c29b4b7bf4c46e70d3a1
URL: https://github.com/llvm/llvm-project/commit/f8b8affc4a6a5df346b1c29b4b7bf4c46e70d3a1
DIFF: https://github.com/llvm/llvm-project/commit/f8b8affc4a6a5df346b1c29b4b7bf4c46e70d3a1.diff
LOG: [mlir][Linalg] Add support for inferring batch dimensions
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ed8c08d1c640f..ccd650c7d7263 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -35,10 +35,10 @@ namespace linalg {
// Utilities for inferring various semantics properties of Linalg ops.
//===----------------------------------------------------------------------===//
-/// Possible dimension candidates that define a matmul embedded in the indexing
-/// maps of a LinalgOp.
-struct EmbeddedMatmulDimsCandidates {
- DenseSet<int64_t> mPos, nPos, kPos;
+/// Possible dimension candidates that define a contraction embedded in the
+/// indexing maps of a LinalgOp.
+struct EmbeddedContractionDimsCandidates {
+ DenseSet<int64_t> batchPos, mPos, nPos, kPos;
};
/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
@@ -64,10 +64,11 @@ bool containsMostMinorMatmul(linalg::LinalgOp linalgOp);
/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
/// 3. The k dimension appears as a permutation on LHS and RHS.
/// 4. m, n and k appear only once in any given indexing.
-/// This allows detecting that some matmul is embedded within `linalgOp` with
-/// some orthogonal heuristic.
-FailureOr<EmbeddedMatmulDimsCandidates>
-inferMatmulDims(linalg::LinalgOp linalgOp);
+/// 5. Optional batch dimensions that appear in all operands are captured.
+/// This allows e.g. detecting that some contraction is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+FailureOr<EmbeddedContractionDimsCandidates>
+inferContractionDims(linalg::LinalgOp linalgOp);
//===----------------------------------------------------------------------===//
// General utilities
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2148624c2963d..ae0c4d9f9688b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1240,7 +1240,7 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
}
// 1. Infer dims that are important for matmul.
- FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
+ FailureOr<EmbeddedContractionDimsCandidates> res = inferContractionDims(linalgOp);
if (failed(res)) {
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer matmul iterators");
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d5eea24534982..e3a9569f623ff 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -167,7 +167,7 @@ auto red = utils::IteratorType::reduction;
} // namespace
bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
- FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
+ FailureOr<EmbeddedContractionDimsCandidates> res = inferContractionDims(linalgOp);
if (failed(res))
return false;
int64_t numLoops = linalgOp.getNumLoops();
@@ -180,8 +180,8 @@ bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
return true;
}
-FailureOr<EmbeddedMatmulDimsCandidates>
-mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
+FailureOr<EmbeddedContractionDimsCandidates>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
return failure();
@@ -200,8 +200,10 @@ mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
DenseSet<int64_t> bc = b;
llvm::set_intersect(bc, c);
llvm::set_subtract(bc, a);
-
- // Note: if we ever need them, A & B & C would be "batch" dimensions.
+ // A & B & C are the "batch" dimensions.
+ DenseSet<int64_t> batches = a;
+ llvm::set_intersect(batches, b);
+ llvm::set_intersect(batches, c);
// A & B red are the reduction dimensions.
DenseSet<int64_t> ra = findPermutationsIndexingOperand(
@@ -215,7 +217,7 @@ mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
// Pick the first one in each set.
// TODO: Better heuristic (e.g pick dims based on packing-based metric).
- return EmbeddedMatmulDimsCandidates{ac, bc, ra};
+ return EmbeddedContractionDimsCandidates{batches, ac, bc, ra};
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list