[Mlir-commits] [mlir] f8b8aff - [mlir][Linalg] Add support for inferring batch dimensions

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 27 05:30:16 PDT 2023


Author: Nicolas Vasilache
Date: 2023-06-27T12:29:51Z
New Revision: f8b8affc4a6a5df346b1c29b4b7bf4c46e70d3a1

URL: https://github.com/llvm/llvm-project/commit/f8b8affc4a6a5df346b1c29b4b7bf4c46e70d3a1
DIFF: https://github.com/llvm/llvm-project/commit/f8b8affc4a6a5df346b1c29b4b7bf4c46e70d3a1.diff

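LOG: [mlir][Linalg] Add support for inferring batch dimensions

This also renames `inferMatmulDims` to `inferContractionDims` and
`EmbeddedMatmulDimsCandidates` to `EmbeddedContractionDimsCandidates`, which
now carries a `batchPos` set alongside `mPos`, `nPos` and `kPos`.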

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ed8c08d1c640f..ccd650c7d7263 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -35,10 +35,10 @@ namespace linalg {
 // Utilities for inferring various semantics properties of Linalg ops.
 //===----------------------------------------------------------------------===//
 
-/// Possible dimension candidates that define a matmul embedded in the indexing
-/// maps of a LinalgOp.
-struct EmbeddedMatmulDimsCandidates {
-  DenseSet<int64_t> mPos, nPos, kPos;
+/// Positions of the candidate batch, m, n and k dimensions that define a
+/// contraction embedded in the indexing maps of a LinalgOp.
+struct EmbeddedContractionDimsCandidates {
+  DenseSet<int64_t> batchPos, mPos, nPos, kPos;
 };
 
 /// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
@@ -64,10 +64,11 @@ bool containsMostMinorMatmul(linalg::LinalgOp linalgOp);
 ///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
 ///   3. The k dimension appears as a permutation on LHS and RHS.
 ///   4. m, n and k appear only once in any given indexing.
-/// This allows detecting that some matmul is embedded within `linalgOp` with
-/// some orthogonal heuristic.
-FailureOr<EmbeddedMatmulDimsCandidates>
-inferMatmulDims(linalg::LinalgOp linalgOp);
+///   5. Batch dimensions (appearing in LHS, RHS and RES) are captured, if any.
+/// This allows e.g. detecting that some contraction is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+FailureOr<EmbeddedContractionDimsCandidates>
+inferContractionDims(linalg::LinalgOp linalgOp);
 
 //===----------------------------------------------------------------------===//
 // General utilities

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2148624c2963d..ae0c4d9f9688b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1240,7 +1240,8 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
   }
 
   // 1. Infer dims that are important for matmul.
-  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
+  FailureOr<EmbeddedContractionDimsCandidates> res =
+      inferContractionDims(linalgOp);
   if (failed(res)) {
     return rewriter.notifyMatchFailure(linalgOp,
                                        "couldn't infer matmul iterators");

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d5eea24534982..e3a9569f623ff 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -167,7 +167,8 @@ auto red = utils::IteratorType::reduction;
 } // namespace
 
 bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
-  FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
+  FailureOr<EmbeddedContractionDimsCandidates> res =
+      inferContractionDims(linalgOp);
   if (failed(res))
     return false;
   int64_t numLoops = linalgOp.getNumLoops();
@@ -180,8 +181,8 @@ bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
   return true;
 }
 
-FailureOr<EmbeddedMatmulDimsCandidates>
-mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
+FailureOr<EmbeddedContractionDimsCandidates>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
   if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
     return failure();
 
@@ -200,8 +201,10 @@ mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
   DenseSet<int64_t> bc = b;
   llvm::set_intersect(bc, c);
   llvm::set_subtract(bc, a);
-
-  // Note: if we ever need them, A & B & C would be "batch" dimensions.
+  // A & B & C are the "batch" dimensions.
+  DenseSet<int64_t> batches = a;
+  llvm::set_intersect(batches, b);
+  llvm::set_intersect(batches, c);
 
   // A & B red are the reduction dimensions.
   DenseSet<int64_t> ra = findPermutationsIndexingOperand(
@@ -215,7 +218,7 @@ mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
 
   // Pick the first one in each set.
   // TODO: Better heuristic (e.g pick dims based on packing-based metric).
-  return EmbeddedMatmulDimsCandidates{ac, bc, ra};
+  return EmbeddedContractionDimsCandidates{batches, ac, bc, ra};
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list