[Mlir-commits] [mlir] 89675aa - [mlir][Linalg] NFC - Extract inferGemmDims

Nicolas Vasilache llvmlistbot at llvm.org
Tue Feb 7 10:15:46 PST 2023


Author: Nicolas Vasilache
Date: 2023-02-07T10:05:18-08:00
New Revision: 89675aaba7e3f5452988aca773a017959da97ee9

URL: https://github.com/llvm/llvm-project/commit/89675aaba7e3f5452988aca773a017959da97ee9
DIFF: https://github.com/llvm/llvm-project/commit/89675aaba7e3f5452988aca773a017959da97ee9.diff

LOG: [mlir][Linalg] NFC - Extract inferGemmDims

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 8ac2c2e0ad98c..f3e0f5618bf74 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/RegionKindInterface.h"
 
@@ -50,6 +51,33 @@ class DialectRegistry;
 
 namespace transform {
 
+/// Return the set of `linalgOp` iterator positions for which the indexing map
+/// for `opOperand` is a permutation (i.e. an AffineDimExpr).
+DenseSet<int64_t> findPermutationsIndexingOperand(linalg::LinalgOp linalgOp,
+                                                  OpOperand *opOperand,
+                                                  utils::IteratorType iter);
+
+/// Possible dimension candidates that define a gemm embedded in the indexing
+/// maps of a LinalgOp.
+struct GemmDimsForPacking {
+  DenseSet<int64_t> mPos, nPos, kPos;
+};
+
+/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
+/// a gemm subcomputation within `linalgOp`. These dimensions are such that:
+///   1. The m dimension is involved in an outer-product along LHS
+///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+///   2. The n dimension is involved in an outer-product along RHS
+///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+///   3. The k dimension appears as a permutation on LHS and RHS.
+///   4. m, n and k appear only once in any given indexing.
+/// This allows detecting that some gemm is embedded within `linalgOp` with some
+/// orthogonal heuristic.
+FailureOr<GemmDimsForPacking> inferGemmDims(linalg::LinalgOp linalgOp);
+
+/// Return true if `linalgOp` contains an embedded gemm subcomputation.
+bool containsMostMinorGemm(linalg::LinalgOp linalgOp);
+
 /// Implementation of tiling operations using `scf.foreach_thread`.
 DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
@@ -57,6 +85,7 @@ DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+
 } // namespace transform
 
 namespace linalg {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 94725fa043406..2ab94bce70d29 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1282,10 +1282,8 @@ auto par = utils::IteratorType::parallel;
 auto red = utils::IteratorType::reduction;
 } // namespace
 
-/// Return the set of AffineDimExpr
-static DenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
-                                utils::IteratorType iter) {
+DenseSet<int64_t> transform::findPermutationsIndexingOperand(
+    LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
   DenseSet<int64_t> res;
   assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
   AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
@@ -1301,24 +1299,7 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
   return res;
 }
 
-struct GemmDimsForPacking {
-  int64_t mPos, nPos, kPos;
-};
-/// Greedily look for 2 parallel (m and n) and 1 reduction (k) dimension that
-/// form a gemm. Such dimensions are such that:
-///   1. The m dimension is involved in an outer-product along LHS
-///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
-///   2. The n dimension is involved in an outer-product along RHS
-///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
-///   3. The k dimension appears as a permutation on LHS and RHS.
-///   4. m, n and k appear only once in any given indexing.
-///
-/// This allows detecting that some gemm is embedded within `linalgOp`.
-///
-/// When multiple possibilities for selecting m, n and k appear, we just pick
-/// an arbitrary one (i.e. the first in a DenseSet).
-// TODO: Better heuristic (e.g pick dims based on packing-based metric).
-static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
+FailureOr<GemmDimsForPacking> transform::inferGemmDims(LinalgOp linalgOp) {
   assert(linalgOp.getNumDpsInits() == 1 && "wrong number of dps inits");
   assert(linalgOp.getNumDpsInputs() == 2 && "wrong number of dps inputs");
 
@@ -1352,18 +1333,31 @@ static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
 
   // Pick the first one in each set.
   // TODO: Better heuristic (e.g pick dims based on packing-based metric).
-  return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()};
-}
-
-/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m and
-/// n are proper parallel dimensions and k is a proper reduction dimension.
-/// Packing occurs by rewriting the op as a linalg.generic and calling
-/// linalg::pack by `mnkPackedSizes`.
-/// The order of the packed dimensions is customizable: the `mnkOrder` is a
-/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 8 possible
-/// forms.
-/// The outer dimensions of the operands are not permuted at this time, this is
-/// left for future work.
+  return GemmDimsForPacking{ac, bc, ra};
+}
+
+bool transform::containsMostMinorGemm(LinalgOp linalgOp) {
+  FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
+  if (failed(res))
+    return false;
+  int64_t numLoops = linalgOp.getNumLoops();
+  for (DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
+    if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
+        s.contains(numLoops - 1))
+      continue;
+    return false;
+  }
+  return true;
+}
+
+/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m
+/// and n are proper parallel dimensions and k is a proper reduction
+/// dimension. Packing occurs by rewriting the op as a linalg.generic and
+/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
+/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
+/// to reorder {m, n, k} into one of the 8 possible forms. The outer
+/// dimensions of the operands are not permuted at this time, this is left for
+/// future work.
 static FailureOr<PackResult>
 packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                  ArrayRef<OpFoldResult> mnkPackedSizes,
@@ -1388,7 +1382,7 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
     packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
 
   // 1. Infer dims that are important for gemm.
-  FailureOr<GemmDimsForPacking> res = getGemmDims(linalgOp);
+  FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
   if (failed(res)) {
     return rewriter.notifyMatchFailure(linalgOp,
                                        "couldn't infer gemm iterators");
@@ -1396,8 +1390,9 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
 
   // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
   // minor iterators. If we wanted a different normalization order, this is
-  // where it would have to start.
-  int64_t mPos = res->mPos, nPos = res->nPos, kPos = res->kPos;
+  // where it would have to plug a heuristic.
+  int64_t mPos = *(res->mPos.begin()), nPos = *(res->nPos.begin()),
+          kPos = *(res->kPos.begin());
   LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
              DBGS() << "Start packing generic op greedily with (m@" << mPos
                     << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
@@ -1412,9 +1407,9 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
     genericOp = *generalizeResult;
   }
 
-  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor iterators.
-  // Note that this only normalized the iteration order and does not change the
-  // indexings of any operand.
+  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
+  // iterators. Note that this only normalized the iteration order and does
+  // not change the indexings of any operand.
   SmallVector<int64_t> permutation =
       computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
   LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
@@ -1446,7 +1441,10 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
 
   // TODO: If we wanted to give the genericOp a name after packing, after
   // calling `pack` would be a good time.
-  return linalg::pack(rewriter, genericOp, adjustedPackedSizes);
+  auto res = linalg::pack(rewriter, genericOp, adjustedPackedSizes);
+  assert(containsMostMinorGemm(res->packedLinalgOp) &&
+         "failed to pack to a most minor gemm");
+  return res;
 }
 
 DiagnosedSilenceableFailure


        


More information about the Mlir-commits mailing list