[Mlir-commits] [mlir] 352d6fe - [mlir][Linalg] NFC - Move transform utilities related to subcomputation inference to Linalg/Utils
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Mar 27 03:48:58 PDT 2023
Author: Nicolas Vasilache
Date: 2023-03-27T03:48:51-07:00
New Revision: 352d6fe1eb2214cae974c36ee0b1bbc2cc0f91e3
URL: https://github.com/llvm/llvm-project/commit/352d6fe1eb2214cae974c36ee0b1bbc2cc0f91e3
DIFF: https://github.com/llvm/llvm-project/commit/352d6fe1eb2214cae974c36ee0b1bbc2cc0f91e3.diff
LOG: [mlir][Linalg] NFC - Move transform utilities related to subcomputation inference to Linalg/Utils
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index a15c3a3b01a3c..6d7c802c3e533 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -53,33 +53,6 @@ class DialectRegistry;
namespace transform {
-/// Return the set of `linalgOp` iterator positions for which the indexing map
-/// for `opOperand` is a permutation (i.e. an AffineDimExpr).
-DenseSet<int64_t> findPermutationsIndexingOperand(linalg::LinalgOp linalgOp,
- OpOperand *opOperand,
- utils::IteratorType iter);
-
-/// Possible dimension candidates that define a gemm embedded in the indexing
-/// maps of a LinalgOp.
-struct GemmDimsForPacking {
- DenseSet<int64_t> mPos, nPos, kPos;
-};
-
-/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
-/// a gemm subcomputation within `linalgOp`. These dimensions are such that:
-/// 1. The m dimension is involved in an outer-product along LHS
-/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
-/// 2. The n dimension is involved in an outer-product along RHS
-/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
-/// 3. The k dimension appears as a permutation on LHS and RHS.
-/// 4. m, n and k appear only once in any given indexing.
-/// This allows detecting that some gemm is embedded within `linalgOp` with some
-/// orthogonal heuristic.
-FailureOr<GemmDimsForPacking> inferGemmDims(linalg::LinalgOp linalgOp);
-
-/// Return true if `linalgOp` contains an embedded gemm subcomputation.
-bool containsMostMinorGemm(linalg::LinalgOp linalgOp);
-
/// Implementation of tiling operations using `scf.forall`.
DiagnosedSilenceableFailure tileToForallOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c58e955cb7951..e107911af8b98 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -590,14 +590,14 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
Different packing strategies are applied in order, when one applies
successfully, the transform returns:
- 1. Gemm packing: Try to infer a gemm operation embedded in the target op.
+ 1. Matmul packing: Try to infer a matmul operation embedded in the target op.
Specifically, this looks for 2 parallel dimensions that participate in
an outer-product and 1 reduction dimension.
- These dimensions are referred as (m, n, k) to match canonical gemm
+ These dimensions are referred to as (m, n, k) to match canonical matmul
terminology.
- The packed sizes for (m, n, k) are specified by `gemm_packed_sizes`.
+ The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`.
The ordering of the packed dimensions (mm, nn, kk) is specified by the
- `gemm_inner_dims_order` attribute.
+ `matmul_inner_dims_order` attribute.
Packing occurs as follows:
1. Find the dimensions to pack according to the strategy.
@@ -624,25 +624,25 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
// TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<PDL_Operation>:$gemm_packed_sizes,
+ Variadic<PDL_Operation>:$matmul_packed_sizes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">
- :$static_gemm_packed_sizes,
+ :$static_matmul_packed_sizes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">
- :$gemm_inner_dims_order);
+ :$matmul_inner_dims_order);
let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
let builders = [
OpBuilder<(ins "Value":$target,
- "ArrayRef<OpFoldResult>":$mixedGemmPackedSizes,
- CArg<"ArrayRef<int64_t>", "{}">:$gemmDimsInnerDimsOrder)>
+ "ArrayRef<OpFoldResult>":$mixedMatmulPackedSizes,
+ CArg<"ArrayRef<int64_t>", "{}">:$matmulDimsInnerDimsOrder)>
];
let assemblyFormat = [{
$target
oilist(
- `gemm_packed_sizes` `=` custom<DynamicIndexList>($gemm_packed_sizes,
- $static_gemm_packed_sizes)
- `gemm_inner_dims_order` `=` $gemm_inner_dims_order
+ `matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
+ $static_matmul_packed_sizes)
+ `matmul_inner_dims_order` `=` $matmul_inner_dims_order
)
attr-dict
`:` functional-type($target, results)
@@ -652,7 +652,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
let extraClassDeclaration = [{
/// Returns the list of tile sizes, which may be static (Attribute) or
/// dynamic (Value).
- SmallVector<OpFoldResult> getMixedGemmPackedSizes();
+ SmallVector<OpFoldResult> getMixedMatmulPackedSizes();
}];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 3c3fa70e161f7..4c23ceb82c4dd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -27,6 +27,44 @@ class ExtractSliceOp;
namespace linalg {
+//===----------------------------------------------------------------------===//
+// Utilities for inferring various semantic properties of Linalg ops.
+//===----------------------------------------------------------------------===//
+
+/// Possible dimension candidates that define a matmul embedded in the indexing
+/// maps of a LinalgOp.
+struct EmbeddedMatmulDimsCandidates {
+ DenseSet<int64_t> mPos, nPos, kPos;
+};
+
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+/// - It is a single AffineDimExpr.
+/// - It is the only result involving this AffineDimExpr.
+DenseSet<int64_t> findPermutationsIndexingOperand(LinalgOp linalgOp,
+ OpOperand *opOperand,
+ utils::IteratorType iter);
+
+/// Return true if `linalgOp` contains an embedded matmul subcomputation in its
+/// most minor dimensions.
+bool containsMostMinorMatmul(linalg::LinalgOp linalgOp);
+
+/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
+/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
+/// 1. The m dimension is involved in an outer-product along LHS
+/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+/// 2. The n dimension is involved in an outer-product along RHS
+/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+/// 3. The k dimension appears as a permutation on LHS and RHS.
+/// 4. m, n and k appear only once in any given indexing.
+/// This allows detecting that some matmul is embedded within `linalgOp` with
+/// some orthogonal heuristic.
+FailureOr<EmbeddedMatmulDimsCandidates>
+inferMatmulDims(linalg::LinalgOp linalgOp);
+
//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
@@ -96,10 +134,10 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
/// Create a tensor::PadOp that pads `source` to the size of the statically
/// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the target
-/// size is met. If `source` is defined by one or more LinalgOps that have been
-/// padded with the same value and sizes, return their padded result instead of
-/// creating a tensor::PadOp.
+/// `source` size. The padding introduces trailing `pad` values until the
+/// target size is met. If `source` is defined by one or more LinalgOps that
+/// have been padded with the same value and sizes, return their padded result
+/// instead of creating a tensor::PadOp.
///
/// Example:
/// ```
@@ -116,8 +154,8 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold);
-/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
-/// `transposeVector` to permute the `inputTensor` dimensions.
+/// Returns a GenericOp that transposes `inputTensor` into `outputTensor`
+/// using `transposeVector` to permute the `inputTensor` dimensions.
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
Value outputTensor,
ArrayRef<int64_t> transposeVector);
@@ -127,12 +165,12 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
/// or vectorize.
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
-/// Get the reassociation maps to fold the result of a extract_slice (or source
-/// of a insert_slice) operation with given offsets, and sizes to its
+/// Get the reassociation maps to fold the result of an extract_slice (or
+/// source of an insert_slice) operation with given offsets, and sizes to its
/// rank-reduced version. This is only done for the cases where the size is 1
-/// and offset is 0. Strictly speaking the offset 0 is not required in general,
-/// but non-zero offsets are not handled by SPIR-V backend at this point (and
-/// potentially cannot be handled).
+/// and offset is 0. Strictly speaking the offset 0 is not required in
+/// general, but non-zero offsets are not handled by SPIR-V backend at this
+/// point (and potentially cannot be handled).
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
@@ -151,8 +189,9 @@ enum class LinalgTilingLoopType {
ParallelLoops = 2
};
-/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
-/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
+/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case
+/// a tile size is zero (i.e., no tiling), the corresponding offset is also
+/// zero.
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> ivs,
ArrayRef<OpFoldResult> tileSizes);
@@ -166,15 +205,16 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> sizeBounds);
/// Returns the list of tensor output types produced when the given structured
-/// operation `op` is applied to the given `operands`. Note that `operands` are
-/// not necessarily the actual operands of `op`.
+/// operation `op` is applied to the given `operands`. Note that `operands`
+/// are not necessarily the actual operands of `op`.
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);
/// Creates `insert_slice` ops that insert `results` back into larger tensors
-/// they were originally extracted from with `extract_slice` before being passed
-/// as `operands` to the given structured operation `op` or its clone. Note that
-/// `operands` are not necessarily the actual operands of `op`, the operation
-/// serves only as metadata container for operand types and positions.
+/// they were originally extracted from with `extract_slice` before being
+/// passed as `operands` to the given structured operation `op` or its clone.
+/// Note that `operands` are not necessarily the actual operands of `op`, the
+/// operation serves only as metadata container for operand types and
+/// positions.
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
LinalgOp op, ValueRange operands,
ValueRange results);
@@ -187,8 +227,8 @@ struct SliceParameters {
};
/// Computes SliceParameters for a single `valueToTile` assuming that its user
-/// is being tiled with the given loop bounds `lbs` and `ubs` and the tile sizes
-/// `tileSizes`.
+/// is being tiled with the given loop bounds `lbs` and `ubs` and the tile
+/// sizes `tileSizes`.
///
/// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
/// condition check in cases where we statically know that it is unnecessary.
@@ -219,8 +259,8 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
/// Creates an extract_slice/subview op for a single `valueToTile` with
/// `builder`. This new operation extracts a tile of `valueToTile`, starting
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
-/// controls whether to omit the partial/boundary tile condition check in cases
-/// where we statically know that it is unnecessary.
+/// controls whether to omit the partial/boundary tile condition check in
+/// cases where we statically know that it is unnecessary.
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
@@ -232,8 +272,8 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
/// nest for tiling with the given induction variables `ivs` and tile sizes
/// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
/// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
-/// omit the partial/boundary tile condition check in cases where we statically
-/// know that it is unnecessary.
+/// omit the partial/boundary tile condition check in cases where we
+/// statically know that it is unnecessary.
///
/// Note that a constant zero in `tileSizes` means no tiling at that implicit
/// loop. The number of non-zero values in `tileSizes` should be equal to the
@@ -254,8 +294,9 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
ArrayRef<OpFoldResult> offests);
/// A struct containing the Linalg producer before and after fusion.
-/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
-/// before the consumer Linalg op, until enough canonicalizations have applied.
+/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast`
+/// op before the consumer Linalg op, until enough canonicalizations have
+/// applied.
struct FusionInfo {
LinalgOp originalProducer;
LinalgOp fusedProducer;
@@ -285,19 +326,23 @@ FailureOr<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
/// Scheme used to distribute loops to processors.
enum class DistributionMethod {
/// Cyclic distribution where no assumption is made about the dynamic
- /// relationship between number of processors and number of iterations of the
+ /// relationship between number of processors and number of iterations of
+ /// the
/// distributed loop. Distributes the following loop
///
/// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
///
/// to
///
- /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step * %nprocs)
+ /// scf.parallel(%iv)= (%lb + %procId * %step) to (%ub) step (%step *
+ /// %nprocs)
Cyclic = 0,
/// Cyclic distribution where the number of processors can be assumed to be
- /// more than or equal to the number of iterations of the distributed loop. In
- /// such cases, a simple in-bounds check is enough (instead of materializing a
+ /// more than or equal to the number of iterations of the distributed loop.
+ /// In
+ /// such cases, a simple in-bounds check is enough (instead of materializing
+ /// a
/// loop). Distributes the following loop
///
/// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
@@ -312,7 +357,8 @@ enum class DistributionMethod {
CyclicNumProcsGeNumIters = 1,
/// Cyclic distribution where the number of processors can be assumed to be
- /// equal to the number of iterations of the distributed loop. In such cases,
+ /// equal to the number of iterations of the distributed loop. In such
+ /// cases,
/// no bounds check is needed. Distributes the following loop
///
/// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
@@ -339,16 +385,17 @@ using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo>(
/// Options that allow distribution of loops generated in Linalg transforms to
/// processors while generating the loops.
struct LinalgLoopDistributionOptions {
- /// Callback function that returns the Values for processor ID (`procId`), and
- /// number of processors (`nprocs`) used to execute the parallel loops. The
- /// number of `{procId, nprocs}` pairs returned must be equal to the number of
- /// `parallelLoopRanges` passed into the callback. The `parallelLoopRanges`
- /// are ranges of the outer parallel loops of the operation that
- /// do have non-zero tile sizes specified.
+ /// Callback function that returns the Values for processor ID (`procId`),
+ /// and number of processors (`nprocs`) used to execute the parallel loops.
+ /// The number of `{procId, nprocs}` pairs returned must be equal to the
+ /// number of `parallelLoopRanges` passed into the callback. The
+ /// `parallelLoopRanges` are ranges of the outer parallel loops of the
+ /// operation that do have non-zero tile sizes specified.
ProcInfoCallBackFn procInfo;
};
-/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
+/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and
+/// `step`.
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
Value procId, Value nprocs, Value &lb,
Value &ub, Value &step);
@@ -362,15 +409,15 @@ class TileLoopNest {
public:
TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
- /// Tile the root operation using the given `tileSizes` and `tileInterchange`,
- /// and `tileDistribution`.
+ /// Tile the root operation using the given `tileSizes` and
+ /// `tileInterchange`, and `tileDistribution`.
LogicalResult
tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
std::optional<LinalgLoopDistributionOptions> tileDistribution);
- /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
- /// the fused producer or fails if fusion is not possible.
+ /// Fuse the producer of `consumerOpOperand` into the tile loop nest.
+ /// Returns the fused producer or fails if fusion is not possible.
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
/// Returns the replacement results for the original untiled root operation.
@@ -426,8 +473,8 @@ struct RegionMatcher {
IAdd,
};
- /// Matches the given linalg op if its body is performing binary operation on
- /// int or float scalar values and returns the binary op kind.
+ /// Matches the given linalg op if its body is performing binary operation
+ /// on int or float scalar values and returns the binary op kind.
///
/// The linalg op's region is expected to be
/// ```
@@ -445,9 +492,10 @@ struct RegionMatcher {
//===----------------------------------------------------------------------===//
/// Utility class used to generate nested loops with ranges described by
-/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
-/// is used to generate the body of the innermost loop. It is passed a range
-/// of loop induction variables and a range of operand values to use.
+/// `loopRanges` and loop type described by the `iteratorTypes`.
+/// `bodyBuilderFn` is used to generate the body of the innermost loop. It is
+/// passed a range of loop induction variables and a range of operand values
+/// to use.
template <typename LoopTy>
struct GenerateLoopNest {
static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e3c1429ade54a..6ee0f13049977 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -27,10 +27,6 @@
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
@@ -38,9 +34,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
-#include "llvm/ADT/SetOperations.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
@@ -1300,94 +1293,21 @@ void transform::PackOp::getEffects(
//===---------------------------------------------------------------------===//
LogicalResult transform::PackGreedilyOp::verify() {
- if (!isPermutationVector(getGemmInnerDimsOrder())) {
- return emitOpError() << getGemmInnerDimsOrderAttrName()
+ if (!isPermutationVector(getMatmulInnerDimsOrder())) {
+ return emitOpError() << getMatmulInnerDimsOrderAttrName()
<< " is not a valid permutation";
}
- // TODO: relax to allow empty once we have another strategy than just gemm.
- if (getGemmInnerDimsOrder().size() != 3 ||
- getMixedGemmPackedSizes().size() != 3) {
- return emitOpError() << " needs 3 entries for gemm_packed_sizes and "
- << getGemmInnerDimsOrderAttrName()
- << " order for the gemm strategy";
+ // TODO: relax to allow empty once we have another strategy than just matmul.
+ if (getMatmulInnerDimsOrder().size() != 3 ||
+ getMixedMatmulPackedSizes().size() != 3) {
+ return emitOpError() << " needs 3 entries for matmul_packed_sizes and "
+ << getMatmulInnerDimsOrderAttrName()
+ << " order for the matmul strategy";
}
return success();
}
-namespace {
-auto par = utils::IteratorType::parallel;
-auto red = utils::IteratorType::reduction;
-} // namespace
-
-DenseSet<int64_t> transform::findPermutationsIndexingOperand(
- LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
- DenseSet<int64_t> res;
- assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
- for (AffineExpr e : indexingMap.getResults()) {
- if (auto d = e.dyn_cast<AffineDimExpr>()) {
- if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
- llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
- return e.isFunctionOfDim(d.getPosition());
- }) == 1)
- res.insert(d.getPosition());
- }
- }
- return res;
-}
-
-FailureOr<GemmDimsForPacking> transform::inferGemmDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
- return failure();
-
- DenseSet<int64_t> a = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), par);
- DenseSet<int64_t> b = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
- DenseSet<int64_t> c = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
-
- // A & C - B are the iterators involved in an outer-product along A (the LHS).
- DenseSet<int64_t> ac = a;
- llvm::set_intersect(ac, c);
- llvm::set_subtract(ac, b);
- // B & C - A are the iterators involved in an outer-product along B (the RHS).
- DenseSet<int64_t> bc = b;
- llvm::set_intersect(bc, c);
- llvm::set_subtract(bc, a);
-
- // Note: if we ever need them, A & B & C would be "batch" dimensions.
-
- // A & B red are the reduction dimensions.
- DenseSet<int64_t> ra = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), red);
- DenseSet<int64_t> rb = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), red);
- llvm::set_intersect(ra, rb);
-
- if (ac.empty() || bc.empty() || ra.empty())
- return failure();
-
- // Pick the first one in each set.
- // TODO: Better heuristic (e.g pick dims based on packing-based metric).
- return GemmDimsForPacking{ac, bc, ra};
-}
-
-bool transform::containsMostMinorGemm(LinalgOp linalgOp) {
- FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
- if (failed(res))
- return false;
- int64_t numLoops = linalgOp.getNumLoops();
- for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
- if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
- s.contains(numLoops - 1))
- continue;
- return false;
- }
- return true;
-}
-
-/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m
+/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
@@ -1396,17 +1316,17 @@ bool transform::containsMostMinorGemm(LinalgOp linalgOp) {
/// dimensions of the operands are not permuted at this time, this is left for
/// future work.
static FailureOr<PackResult>
-packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
- ArrayRef<OpFoldResult> mnkPackedSizes,
- ArrayRef<int64_t> mnkOrder) {
+packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
+ ArrayRef<OpFoldResult> mnkPackedSizes,
+ ArrayRef<int64_t> mnkOrder) {
assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
assert(isPermutationVector(mnkOrder) && "expected a permutation");
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
- return rewriter.notifyMatchFailure(linalgOp,
- "need 3+ loops to find a gemm to pack");
+ return rewriter.notifyMatchFailure(
+ linalgOp, "need 3+ loops to find a matmul to pack");
}
// Locally adjust the desired iterator position of mnk and packing sizes.
@@ -1418,11 +1338,11 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
- // 1. Infer dims that are important for gemm.
- FailureOr<GemmDimsForPacking> res = inferGemmDims(linalgOp);
+ // 1. Infer dims that are important for matmul.
+ FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
if (failed(res)) {
return rewriter.notifyMatchFailure(linalgOp,
- "couldn't infer gemm iterators");
+ "couldn't infer matmul iterators");
}
// 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
@@ -1479,8 +1399,8 @@ packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time.
auto packingRes = linalg::pack(rewriter, genericOp, adjustedPackedSizes);
- assert(containsMostMinorGemm(packingRes->packedLinalgOp) &&
- "failed to pack to a most minor gemm");
+ assert(containsMostMinorMatmul(packingRes->packedLinalgOp) &&
+ "failed to pack to a most minor matmul");
return packingRes;
}
@@ -1500,11 +1420,11 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
rewriter.setInsertionPointAfter(linalgOp);
// Failing to pack greedily is perfectly fine.
// In the future we will want to order packings according to some metric.
- FailureOr<PackResult> packResult = packGemmGreedily(
+ FailureOr<PackResult> packResult = packMatmulGreedily(
/*rewriter=*/rewriter,
/*linalgOp=*/linalgOp,
- /*mnkPackedSizes=*/getMixedGemmPackedSizes(),
- /*mnkOrder=*/getGemmInnerDimsOrder());
+ /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
+ /*mnkOrder=*/getMatmulInnerDimsOrder());
if (succeeded(packResult)) {
results.push_back(packResult->packedLinalgOp);
continue;
@@ -1515,15 +1435,16 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
return DiagnosedSilenceableFailure::success();
}
-SmallVector<OpFoldResult> PackGreedilyOp::getMixedGemmPackedSizes() {
+SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
Builder b(getContext());
- return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b);
+ return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
+ b);
}
void transform::PackGreedilyOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getTarget(), effects);
- transform::onlyReadsHandle(getGemmPackedSizes(), effects);
+ transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
transform::producesHandle(getPackedOp(), effects);
transform::modifiesPayload(effects);
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 75f818b1b275d..572b7e40d4c95 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -33,6 +33,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -138,6 +139,88 @@ static void unpackRanges(OpBuilder &builder, Location loc,
}
}
+//===----------------------------------------------------------------------===//
+// Utilities for inferring various semantic properties of Linalg ops.
+//===----------------------------------------------------------------------===//
+
+DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
+ LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
+ DenseSet<int64_t> res;
+ assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+ for (AffineExpr e : indexingMap.getResults()) {
+ if (auto d = e.dyn_cast<AffineDimExpr>()) {
+ if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+ llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
+ return e.isFunctionOfDim(d.getPosition());
+ }) == 1)
+ res.insert(d.getPosition());
+ }
+ }
+ return res;
+}
+
+namespace {
+auto par = utils::IteratorType::parallel;
+auto red = utils::IteratorType::reduction;
+} // namespace
+
+bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
+ FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
+ if (failed(res))
+ return false;
+ int64_t numLoops = linalgOp.getNumLoops();
+ for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
+ if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
+ s.contains(numLoops - 1))
+ continue;
+ return false;
+ }
+ return true;
+}
+
+FailureOr<EmbeddedMatmulDimsCandidates>
+mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+
+ DenseSet<int64_t> a = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(0), par);
+ DenseSet<int64_t> b = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(1), par);
+ DenseSet<int64_t> c = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+ // A & C - B are the iterators involved in an outer-product along A (the LHS).
+ DenseSet<int64_t> ac = a;
+ llvm::set_intersect(ac, c);
+ llvm::set_subtract(ac, b);
+ // B & C - A are the iterators involved in an outer-product along B (the RHS).
+ DenseSet<int64_t> bc = b;
+ llvm::set_intersect(bc, c);
+ llvm::set_subtract(bc, a);
+
+ // Note: if we ever need them, A & B & C would be "batch" dimensions.
+
+ // A & B red are the reduction dimensions.
+ DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(0), red);
+ DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(1), red);
+ llvm::set_intersect(ra, rb);
+
+ if (ac.empty() || bc.empty() || ra.empty())
+ return failure();
+
+ // Pick the first one in each set.
+ // TODO: Better heuristic (e.g., pick dims based on packing-based metric).
+ return EmbeddedMatmulDimsCandidates{ac, bc, ra};
+}
+
+//===----------------------------------------------------------------------===//
+// General utilities
+//===----------------------------------------------------------------------===//
+
namespace mlir {
namespace linalg {
diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
index 42f3a6cd0fba0..544f4391eb39a 100644
--- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
+++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
@@ -25,7 +25,7 @@ transform.sequence failures(propagate) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
: (!pdl.operation) -> !transform.op<"linalg.matmul">
transform.structured.pack_greedily %matmul
- gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
}
@@ -70,7 +70,7 @@ transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
- gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
@@ -115,7 +115,7 @@ transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
- gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
@@ -160,7 +160,7 @@ transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
- gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
@@ -195,7 +195,7 @@ transform.sequence failures(propagate) {
%conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op
: (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw">
transform.structured.pack_greedily %conv
- gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic">
}
@@ -223,6 +223,6 @@ transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
- gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
More information about the Mlir-commits
mailing list