[Mlir-commits] [mlir] 89675aa - [mlir][Linalg] NFC - Extract inferGemmDims
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Feb 7 10:15:46 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-07T10:05:18-08:00
New Revision: 89675aaba7e3f5452988aca773a017959da97ee9
URL: https://github.com/llvm/llvm-project/commit/89675aaba7e3f5452988aca773a017959da97ee9
DIFF: https://github.com/llvm/llvm-project/commit/89675aaba7e3f5452988aca773a017959da97ee9.diff
LOG: [mlir][Linalg] NFC - Extract inferGemmDims
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 8ac2c2e0ad98c..f3e0f5618bf74 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"
@@ -50,6 +51,33 @@ class DialectRegistry;
namespace transform {
+/// Return the set of `linalgOp` iterator positions for which the indexing map
+/// for `opOperand` is a permutation (i.e. an AffineDimExpr).
+DenseSet<int64_t> findPermutationsIndexingOperand(linalg::LinalgOp linalgOp,
+ OpOperand *opOperand,
+ utils::IteratorType iter);
+
+/// Possible dimension candidates that define a gemm embedded in the indexing
+/// maps of a LinalgOp.
+struct GemmDimsForPacking {
+ DenseSet<int64_t> mPos, nPos, kPos;
+};
+
+/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
+/// a gemm subcomputation within `linalgOp`. These dimensions are such that:
+/// 1. The m dimension is involved in an outer-product along LHS
+/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+/// 2. The n dimension is involved in an outer-product along RHS
+/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+/// 3. The k dimension appears as a permutation on LHS and RHS.
+/// 4. m, n and k appear only once in any given indexing.
+/// This allows detecting that some gemm is embedded within `linalgOp` with some
+/// orthogonal heuristic.
+FailureOr<GemmDimsForPacking> inferGemmDims(linalg::LinalgOp linalgOp);
+
+/// Return true if `linalgOp` contains an embedded gemm subcomputation.
+bool containsMostMinorGemm(linalg::LinalgOp linalgOp);
+
/// Implementation of tiling operations using `scf.foreach_thread`.
DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
@@ -57,6 +85,7 @@ DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+
} // namespace transform
namespace linalg {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 94725fa043406..2ab94bce70d29 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1282,10 +1282,8 @@ auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
-/// Return the set of AffineDimExpr
-static DenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
- utils::IteratorType iter) {
+DenseSet<int64_t> transform::findPermutationsIndexingOperand(
+ LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
DenseSet<int64_t> res;
assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
@@ -1301,24 +1299,7 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
return res;
}
-struct GemmDimsForPacking {
- int64_t mPos, nPos, kPos;
-};
-/// Greedily look for 2 parallel (m and n) and 1 reduction (k) dimension that
-/// form a gemm. Such dimensions are such that:
-/// 1. The m dimension is involved in an outer-product along LHS
-/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
-/// 2. The n dimension is involved in an outer-product along RHS
-/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
-/// 3. The k dimension appears as a permutation on LHS and RHS.
-/// 4. m, n and k appear only once in any given indexing.
-///
-/// This allows detecting that some gemm is embedded within `linalgOp`.
-///
-/// When multiple possibilities for selecting m, n and k appear, we just pick
-/// an arbitrary one (i.e. the first in a DenseSet).
-// TODO: Better heuristic (e.g pick dims based on packing-based metric).
-static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
+FailureOr<GemmDimsForPacking> transform::inferGemmDims(LinalgOp linalgOp) {
assert(linalgOp.getNumDpsInits() == 1 && "wrong number of dps inits");
assert(linalgOp.getNumDpsInputs() == 2 && "wrong number of dps inputs");
@@ -1352,18 +1333,31 @@ static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
// Pick the first one in each set.
// TODO: Better heuristic (e.g pick dims based on packing-based metric).
- return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()};
-}
-
-/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m and
-/// n are proper parallel dimensions and k is a proper reduction dimension.
-/// Packing occurs by rewriting the op as a linalg.generic and calling
-/// linalg::pack by `mnkPackedSizes`.
-/// The order of the packed dimensions is customizable: the `mnkOrder` is a
-/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 8 possible
-/// forms.
-/// The outer dimensions of the operands are not permuted at this time, this is
-/// left for future work.
+ return GemmDimsForPacking{ac, bc, ra};
+}
+
+bool transform::containsMostMinorGemm(LinalgOp linalgOp) {
+ FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
+ if (failed(res))
+ return false;
+ int64_t numLoops = linalgOp.getNumLoops();
+ for (DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
+ if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
+ s.contains(numLoops - 1))
+ continue;
+ return false;
+ }
+ return true;
+}
+
+/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m
+/// and n are proper parallel dimensions and k is a proper reduction
+/// dimension. Packing occurs by rewriting the op as a linalg.generic and
+/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
+/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
+/// to reorder {m, n, k} into one of the 8 possible forms. The outer
+/// dimensions of the operands are not permuted at this time, this is left for
+/// future work.
static FailureOr<PackResult>
packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<OpFoldResult> mnkPackedSizes,
@@ -1388,7 +1382,7 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
// 1. Infer dims that are important for gemm.
- FailureOr<GemmDimsForPacking> res = getGemmDims(linalgOp);
+ FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
if (failed(res)) {
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer gemm iterators");
@@ -1396,8 +1390,9 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
// minor iterators. If we wanted a different normalization order, this is
- // where it would have to start.
- int64_t mPos = res->mPos, nPos = res->nPos, kPos = res->kPos;
+ // where it would have to plug a heuristic.
+ int64_t mPos = *(res->mPos.begin()), nPos = *(res->nPos.begin()),
+ kPos = *(res->kPos.begin());
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "Start packing generic op greedily with (m@" << mPos
<< ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
@@ -1412,9 +1407,9 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
genericOp = *generalizeResult;
}
- // 2.b. Interchange to move the dimensions (k, m, n) as most-minor iterators.
- // Note that this only normalized the iteration order and does not change the
- // indexings of any operand.
+ // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
+ // iterators. Note that this only normalized the iteration order and does
+ // not change the indexings of any operand.
SmallVector<int64_t> permutation =
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
@@ -1446,7 +1441,10 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time.
- return linalg::pack(rewriter, genericOp, adjustedPackedSizes);
+ auto res = linalg::pack(rewriter, genericOp, adjustedPackedSizes);
+ assert(containsMostMinorGemm(res->packedLinalgOp) &&
+ "failed to pack to a most minor gemm");
+ return res;
}
DiagnosedSilenceableFailure
More information about the Mlir-commits
mailing list