[Mlir-commits] [mlir] 98d6ab9 - [mlir][Linalg] Refactor isaContractionOpInterface and surrounding utils
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jun 28 03:19:06 PDT 2023
Author: Nicolas Vasilache
Date: 2023-06-28T10:19:00Z
New Revision: 98d6ab9d6a4918e3887334672716278a1b632c12
URL: https://github.com/llvm/llvm-project/commit/98d6ab9d6a4918e3887334672716278a1b632c12
DIFF: https://github.com/llvm/llvm-project/commit/98d6ab9d6a4918e3887334672716278a1b632c12.diff
LOG: [mlir][Linalg] Refactor isaContractionOpInterface and surrounding utils
This is almost NFC, except that:
- when multiple candidates are available, we now return them in sorted order, whereas the order was previously undetermined
- the result type of the transform is relaxed, and a test is added for the case where the transform does not apply
Differential Revision: https://reviews.llvm.org/D153941
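For orientation, the promoted entry point is `inferContractionDims`, which now lives in LinalgInterfaces and returns its candidate dimensions in sorted order. A minimal sketch of a caller, assuming a `LinalgOp` already matched elsewhere (the helper name `inspectContraction` is illustrative, not part of this commit):

    #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

    using namespace mlir;
    using namespace mlir::linalg;

    // Hypothetical helper: query the contraction structure of a linalg op.
    static LogicalResult inspectContraction(LinalgOp linalgOp) {
      FailureOr<ContractionDimensions> dims = inferContractionDims(linalgOp);
      if (failed(dims))
        return failure(); // No (batched) matmul subcomputation was found.
      // Each vector is sorted; when several candidates exist,
      // packMatmulGreedily in this commit biases towards the most minor
      // one by taking .back().
      unsigned mPos = dims->m.back();
      unsigned nPos = dims->n.back();
      unsigned kPos = dims->k.back();
      (void)mPos; (void)nPos; (void)kPos;
      return success();
    }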
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index cb93e8a7bc104..0562f3779e08b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -36,12 +36,60 @@ bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp,
ArrayRef<OpOperand *> droppedOperands);
} // namespace detail
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// contraction dimension.
+struct ContractionDimensions {
+ SmallVector<unsigned, 2> batch;
+ SmallVector<unsigned, 2> m;
+ SmallVector<unsigned, 2> n;
+ SmallVector<unsigned, 2> k;
+};
+
+/// Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates
+/// that form a matmul subcomputation within `linalgOp`.
+/// These dimensions are such that:
+/// 1. The m dimension is involved in an outer-product along LHS
+/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+/// 2. The n dimension is involved in an outer-product along RHS
+/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+/// 3. The k dimension appears as a permutation on LHS and RHS.
+/// 4. m, n and k appear only once in any given indexing.
+/// 5. Optional batch dimensions that appear in all operands are captured.
+/// This allows e.g. detecting that some contraction is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match `batch`, `m`, `n`, or
+/// `k`, indices are returned in sorted order.
+/// Returns a failure if any of `m`, `n` or `k` is empty.
+FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
bool isaContractionOpInterface(LinalgOp linalgOp);
+/// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
+// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
+bool isaConvolutionOpInterface(LinalgOp linalgOp);
+
namespace detail {
+/// Result of matching a Linalg generic against the predicates of it being a
+/// contraction.
+enum class MatchContractionResult;
+
+/// Checks whether `op` conforms to ContractionOpInterface and populates
+/// `dimensions` with indexes of the different kinds of dimensions when
+/// present.
+// TODO: Extract a standalone `inferConvolutionDims` that can also detect
+// whether a conv pattern exists within a bigger linalg op (see
+// inferContractionDims).
+MatchContractionResult
+isContractionInterfaceImpl(Operation *op,
+ ContractionDimensions *dimensions = nullptr);
+
+/// Returns the error message corresponding to the contraction checking return
+/// code.
+StringRef getMatchContractionMessage(MatchContractionResult res);
+
/// Result of matching a Linalg generic against the predicates of it being a
/// convolution.
enum class MatchConvolutionResult;
@@ -58,7 +106,8 @@ struct ConvolutionDimensions {
};
/// Checks whether `op` conforms to ConvolutionOpInterface and populates
-/// `dimensions` with indexes of the different kinds of dimensions when present.
+/// `dimensions` with indexes of the different kinds of dimensions when
+/// present.
MatchConvolutionResult
isConvolutionInterfaceImpl(Operation *op,
ConvolutionDimensions *dimensions = nullptr);
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 3943e627d5737..bccdeaa9d5ef7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -722,7 +722,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
:$matmul_padded_sizes_next_multiple_of,
ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
[DenseArrayCount<3>]>:$matmul_inner_dims_order);
- let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
+ let results = (outs TransformHandleTypeInterface:$packed_op);
let builders = [
OpBuilder<(ins "Value":$target,
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ccd650c7d7263..a4c8baeb1d451 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -35,41 +35,6 @@ namespace linalg {
// Utilities for inferring various semantics properties of Linalg ops.
//===----------------------------------------------------------------------===//
-/// Possible dimension candidates that define a contraction embedded in the
-/// indexing maps of a LinalgOp.
-struct EmbeddedContractionDimsCandidates {
- DenseSet<int64_t> batchPos, mPos, nPos, kPos;
-};
-
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
-/// - It is a single AffineDimExpr.
-/// - It is the only result involving this AffineDimExpr.
-DenseSet<int64_t> findPermutationsIndexingOperand(LinalgOp linalgOp,
- OpOperand *opOperand,
- utils::IteratorType iter);
-
-/// Return true if `linalgOp` contains an embedded matmul subcomputation in its
-/// most minor dimensions.
-bool containsMostMinorMatmul(linalg::LinalgOp linalgOp);
-
-/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
-/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
-/// 1. The m dimension is involved in an outer-product along LHS
-/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
-/// 2. The n dimension is involved in an outer-product along RHS
-/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
-/// 3. The k dimension appears as a permutation on LHS and RHS.
-/// 4. m, n and k appear only once in any given indexing.
-/// 5. Optional batch dimensions that appear in all operands are captured.
-/// This allows e.g. detecting that some contraction is embedded within
-/// `linalgOp` with some orthogonal heuristic.
-FailureOr<EmbeddedContractionDimsCandidates>
-inferContractionDims(linalg::LinalgOp linalgOp);
-
//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 137204343ef06..e928cc72a26a8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -17,7 +17,10 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::linalg;
@@ -112,6 +115,96 @@ static bool isAddMul(Block &block) {
return success;
}
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+/// - It is a single AffineDimExpr.
+/// - It is the only result involving this AffineDimExpr.
+static DenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+ utils::IteratorType iter) {
+ DenseSet<int64_t> res;
+ assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+ for (AffineExpr e : indexingMap.getResults()) {
+ if (auto d = e.dyn_cast<AffineDimExpr>()) {
+ if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+ llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
+ return e.isFunctionOfDim(d.getPosition());
+ }) == 1)
+ res.insert(d.getPosition());
+ }
+ }
+ return res;
+}
+
+namespace {
+auto par = utils::IteratorType::parallel;
+auto red = utils::IteratorType::reduction;
+} // namespace
+
+/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
+/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
+/// 1. The m dimension is involved in an outer-product along LHS
+/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
+/// 2. The n dimension is involved in an outer-product along RHS
+/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
+/// 3. The k dimension appears as a permutation on LHS and RHS.
+/// 4. m, n and k appear only once in any given indexing.
+/// 5. Optional batch dimensions that appear in all operands are captured.
+/// This allows e.g. detecting that some contraction is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+
+ DenseSet<int64_t> a = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(0), par);
+ DenseSet<int64_t> b = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(1), par);
+ DenseSet<int64_t> c = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+ // A & C - B are the iterators involved in an outer-product along A (the LHS).
+ DenseSet<int64_t> ac = a;
+ llvm::set_intersect(ac, c);
+ llvm::set_subtract(ac, b);
+ // B & C - A are the iterators involved in an outer-product along B (the RHS).
+ DenseSet<int64_t> bc = b;
+ llvm::set_intersect(bc, c);
+ llvm::set_subtract(bc, a);
+ // A & B & C are the "batch" dimensions.
+ DenseSet<int64_t> batches = a;
+ llvm::set_intersect(batches, b);
+ llvm::set_intersect(batches, c);
+
+ // A & B red are the reduction dimensions.
+ DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(0), red);
+ DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(1), red);
+ llvm::set_intersect(ra, rb);
+
+ if (ac.empty() || bc.empty() || ra.empty())
+ return failure();
+
+ // Return each set in sorted order.
+ ContractionDimensions dimensions{
+ SmallVector<unsigned, 2>(batches.begin(), batches.end()),
+ SmallVector<unsigned, 2>(ac.begin(), ac.end()),
+ SmallVector<unsigned, 2>(bc.begin(), bc.end()),
+ SmallVector<unsigned, 2>(ra.begin(), ra.end())};
+ std::sort(dimensions.batch.begin(), dimensions.batch.end());
+ std::sort(dimensions.m.begin(), dimensions.m.end());
+ std::sort(dimensions.n.begin(), dimensions.n.end());
+ std::sort(dimensions.k.begin(), dimensions.k.end());
+ return dimensions;
+}
+
+namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
NotLinalgOp,
@@ -120,7 +213,11 @@ enum class MatchContractionResult {
NotProjectedPermutations,
NotAddMul
};
-static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
+} // namespace mlir::linalg::detail
+
+mlir::linalg::detail::MatchContractionResult
+mlir::linalg::detail::isContractionInterfaceImpl(
+ Operation *op, mlir::linalg::ContractionDimensions *dimensions) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchContractionResult::NotLinalgOp;
@@ -139,15 +236,41 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
linalgOp->getRegion(0).front()) &&
!isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
return MatchContractionResult::NotAddMul;
+
+ if (dimensions) {
+ FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
+ assert(succeeded(res) && "unexpected failure to infer contraction dims");
+ *dimensions = *res;
+ }
return MatchContractionResult::Success;
}
+StringRef
+mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
+ switch (res) {
+ case MatchContractionResult::NotLinalgOp:
+ return "expected a LinalgOp";
+ case MatchContractionResult::WrongNumOperands:
+ return "expected op with 2 inputs and 1 output";
+ case MatchContractionResult::NoReduction:
+ return "expected at least 1 reduction";
+ case MatchContractionResult::NotProjectedPermutations:
+ return "expected indexing maps to be projected permutations";
+ case MatchContractionResult::NotAddMul:
+ return "expected add/mul op in the body";
+ case MatchContractionResult::Success:
+ return "";
+ }
+ llvm_unreachable("unhandled MatchContractionResult case");
+}
+
bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
if (!linalgOp)
return false;
Operation *op = linalgOp.getOperation();
return isa<ContractionOpInterface>(op) ||
- (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
+ (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
+ mlir::linalg::detail::MatchContractionResult::Success);
}
/// Verify that a LinalgOp `op` is a contraction.
@@ -165,16 +288,8 @@ bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
/// constant operations that do not involve the reduction dimension(s).
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
auto res = isContractionInterfaceImpl(op);
- if (res == MatchContractionResult::NotLinalgOp)
- return op->emitError("expected a LinalgOp");
- if (res == MatchContractionResult::WrongNumOperands)
- return op->emitError("expected op with 2 inputs and 1 outputs");
- if (res == MatchContractionResult::NoReduction)
- return op->emitError("expected at least a reduction loop");
- if (res == MatchContractionResult::NotProjectedPermutations)
- return op->emitError("expected all indexings to be projected permutations");
- if (res == MatchContractionResult::NotAddMul)
- return op->emitError("(add, mul) operations not found");
+ if (res != MatchContractionResult::Success)
+ return op->emitError(getMatchContractionMessage(res));
return success();
}
@@ -454,6 +569,11 @@ mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
llvm_unreachable("unhandled MatchConvolutionResult case");
}
+bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
+ return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
+ linalg::detail::MatchConvolutionResult::Success;
+}
+
LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
if (res != MatchConvolutionResult::Success)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 49ed54673ba37..d702e6d7e7c6c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -49,6 +49,7 @@ using namespace mlir::transform;
#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
@@ -1227,6 +1228,8 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
+ LDBG("need 3+ loops to find a matmul to pack, got "
+ << numLoops << "\nin: " << linalgOp << "\n");
return rewriter.notifyMatchFailure(
linalgOp, "need 3+ loops to find a matmul to pack");
}
@@ -1247,17 +1250,21 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
}
// 1. Infer dims that are important for matmul.
- FailureOr<EmbeddedContractionDimsCandidates> res = inferContractionDims(linalgOp);
- if (failed(res)) {
+ FailureOr<ContractionDimensions> maybeDimensions =
+ inferContractionDims(linalgOp);
+ if (failed(maybeDimensions)) {
+ LDBG("couldn't infer matmul iterators in: " << linalgOp << "\n");
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer matmul iterators");
}
// 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
- // minor iterators. If we wanted a different normalization order, this is
- // where it would have to plug a heuristic.
- int64_t mPos = *(res->mPos.begin()), nPos = *(res->nPos.begin()),
- kPos = *(res->kPos.begin());
+ // minor iterators. In cases with multiple options for m, n, k bias towards
+ // the most minor embedding.
+ // If we wanted a different normalization order, this is where it would have
+ // to plug a heuristic.
+ int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
+ kPos = maybeDimensions->k.back();
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "Start packing generic op greedily with (m@" << mPos
<< ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
@@ -2655,71 +2662,71 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
linalg::ForallTilingResult &tilingResult) {
// Transform all targets one by one.
- auto tileableOp = dyn_cast<TilingInterface>(target);
- if (!tileableOp) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "only TilingInterface ops are supported";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
- rewriter.setInsertionPoint(tileableOp);
- FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
- if (!mixedNumThreads.empty()) {
- maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp,
- mixedNumThreads, mapping);
- } else {
- maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
- rewriter, tileableOp, mixedTileSizes, mapping);
- }
+ auto tileableOp = dyn_cast<TilingInterface>(target);
+ if (!tileableOp) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "only TilingInterface ops are supported";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ rewriter.setInsertionPoint(tileableOp);
+ FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+ if (!mixedNumThreads.empty()) {
+ maybeTilingResult =
+ linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+ } else {
+ maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
+ rewriter, tileableOp, mixedTileSizes, mapping);
+ }
- if (failed(maybeTilingResult))
- return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+ if (failed(maybeTilingResult))
+ return transformOp.emitDefaultSilenceableFailure(tileableOp);
+ rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
- tilingResult = *maybeTilingResult;
- return DiagnosedSilenceableFailure::success();
+ tilingResult = *maybeTilingResult;
+ return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure
transform::TileToForallOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
- auto transformOp = cast<TransformOpInterface>(getOperation());
-
- // Result payload ops.
- SmallVector<Operation *> tileOps;
- SmallVector<Operation *> tiledOps;
-
- // Unpack handles.
- SmallVector<OpFoldResult> mixedNumThreads;
- DiagnosedSilenceableFailure status =
- getPackedNumThreads()
- ? unpackSingleIndexResultPayloadOperations(
- state, transformOp, mixedNumThreads, getPackedNumThreads())
- : unpackSingleIndexResultPayloadOperations(
- state, transformOp, mixedNumThreads, getMixedNumThreads());
- if (!status.succeeded())
- return status;
- SmallVector<OpFoldResult> mixedTileSizes;
- status = getPackedTileSizes()
- ? unpackSingleIndexResultPayloadOperations(
- state, transformOp, mixedTileSizes, getPackedTileSizes())
- : unpackSingleIndexResultPayloadOperations(
- state, transformOp, mixedTileSizes, getMixedTileSizes());
- if (!status.succeeded())
- return status;
-
- for (Operation *target : state.getPayloadOps(getTarget())) {
- linalg::ForallTilingResult tilingResult;
- DiagnosedSilenceableFailure diag = tileToForallOpImpl(
- rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
- getMapping(), tilingResult);
- if (!diag.succeeded())
+ auto transformOp = cast<TransformOpInterface>(getOperation());
+
+ // Result payload ops.
+ SmallVector<Operation *> tileOps;
+ SmallVector<Operation *> tiledOps;
+
+ // Unpack handles.
+ SmallVector<OpFoldResult> mixedNumThreads;
+ DiagnosedSilenceableFailure status =
+ getPackedNumThreads()
+ ? unpackSingleIndexResultPayloadOperations(
+ state, transformOp, mixedNumThreads, getPackedNumThreads())
+ : unpackSingleIndexResultPayloadOperations(
+ state, transformOp, mixedNumThreads, getMixedNumThreads());
+ if (!status.succeeded())
+ return status;
+ SmallVector<OpFoldResult> mixedTileSizes;
+ status = getPackedTileSizes()
+ ? unpackSingleIndexResultPayloadOperations(
+ state, transformOp, mixedTileSizes, getPackedTileSizes())
+ : unpackSingleIndexResultPayloadOperations(
+ state, transformOp, mixedTileSizes, getMixedTileSizes());
+ if (!status.succeeded())
+ return status;
+
+ for (Operation *target : state.getPayloadOps(getTarget())) {
+ linalg::ForallTilingResult tilingResult;
+ DiagnosedSilenceableFailure diag = tileToForallOpImpl(
+ rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
+ getMapping(), tilingResult);
+ if (!diag.succeeded())
return diag;
tileOps.push_back(tilingResult.tileOp);
tiledOps.push_back(tilingResult.tiledOp);
- }
+ }
transformResults.set(cast<OpResult>(getForallOp()), tileOps);
transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e3a9569f623ff..55da7096b25c2 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -33,7 +33,6 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
-#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -140,86 +139,6 @@ static void unpackRanges(OpBuilder &builder, Location loc,
}
}
-//===----------------------------------------------------------------------===//
-// Utilities for inferring various semantics properties of Linalg ops.
-//===----------------------------------------------------------------------===//
-
-DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
- LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
- DenseSet<int64_t> res;
- assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
- for (AffineExpr e : indexingMap.getResults()) {
- if (auto d = e.dyn_cast<AffineDimExpr>()) {
- if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
- llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
- return e.isFunctionOfDim(d.getPosition());
- }) == 1)
- res.insert(d.getPosition());
- }
- }
- return res;
-}
-
-namespace {
-auto par = utils::IteratorType::parallel;
-auto red = utils::IteratorType::reduction;
-} // namespace
-
-bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
- FailureOr<EmbeddedContractionDimsCandidates> res = inferContractionDims(linalgOp);
- if (failed(res))
- return false;
- int64_t numLoops = linalgOp.getNumLoops();
- for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
- if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
- s.contains(numLoops - 1))
- continue;
- return false;
- }
- return true;
-}
-
-FailureOr<EmbeddedContractionDimsCandidates>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
- return failure();
-
- DenseSet<int64_t> a = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), par);
- DenseSet<int64_t> b = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
- DenseSet<int64_t> c = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
-
- // A & C - B are the iterators involved in an outer-product along A (the LHS).
- DenseSet<int64_t> ac = a;
- llvm::set_intersect(ac, c);
- llvm::set_subtract(ac, b);
- // B & C - A are the iterators involved in an outer-product along B (the RHS).
- DenseSet<int64_t> bc = b;
- llvm::set_intersect(bc, c);
- llvm::set_subtract(bc, a);
- // A & B & C are the "batch" dimensions.
- DenseSet<int64_t> batches = a;
- llvm::set_intersect(batches, b);
- llvm::set_intersect(batches, c);
-
- // A & B red are the reduction dimensions.
- DenseSet<int64_t> ra = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), red);
- DenseSet<int64_t> rb = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), red);
- llvm::set_intersect(ra, rb);
-
- if (ac.empty() || bc.empty() || ra.empty())
- return failure();
-
- // Pick the first one in each set.
- // TODO: Better heuristic (e.g pick dims based on packing-based metric).
- return EmbeddedContractionDimsCandidates{batches, ac, bc, ra};
-}
-
//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index d67016efd0971..8fb6ed1e90994 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -59,7 +59,9 @@ transform::OperationType::checkPayload(Location loc,
for (Operation *op : payload) {
if (opName != op->getName()) {
DiagnosedSilenceableFailure diag =
- emitSilenceableError(loc) << "incompatible payload operation name";
+ emitSilenceableError(loc)
+ << "incompatible payload operation name expected " << opName << " vs "
+ << op->getName() << " -> " << *op;
diag.attachNote(op->getLoc()) << "payload operation";
return diag;
}
diff --git a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
index 374c1d280297f..68f07067b5350 100644
--- a/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
+++ b/mlir/test/Dialect/Linalg/transform-pack-greedily.mlir
@@ -326,3 +326,25 @@ transform.sequence failures(propagate) {
matmul_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
+
+// -----
+
+!A = tensor<1023x255xf32>
+!X = tensor<255xf32>
+!Y = tensor<1023xf32>
+
+// CHECK-LABEL: @matvec_fail(
+func.func @matvec_fail(%A : !A, %x : !X, %y : !Y) -> !Y {
+ // CHECK: linalg.matvec
+ %0 = linalg.matvec ins(%A, %x : !A, !X) outs(%y : !Y) -> !Y
+ return %0 : !Y
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %matmul = transform.structured.match ops{["linalg.matvec"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"linalg.matvec">
+ transform.structured.pack_greedily %matmul
+ matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
+ : (!transform.op<"linalg.matvec">) -> !transform.any_op
+}
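As a concrete walk-through of the set algebra in inferContractionDims, consider a plain linalg.matmul with loops (m, n, k) = (0, 1, 2) and iterator types (parallel, parallel, reduction). The parallel projections are A = {0} (LHS), B = {1} (RHS) and C = {0, 1} (RES), so ac = A & C - B = {0}, bc = B & C - A = {1} and batches = A & B & C = {}; intersecting the reduction projections of LHS and RHS yields {2}. A hedged, assertion-style sketch of the expected result (the `matmulOp` variable is assumed to hold such an op):

    // Expected inference for a plain linalg.matmul, loops (m, n, k) = (0, 1, 2).
    FailureOr<ContractionDimensions> dims = inferContractionDims(matmulOp);
    assert(succeeded(dims) && "matmul should infer contraction dims");
    assert(dims->batch.empty() && "no batch dimension in a plain matmul");
    assert(dims->m.size() == 1 && dims->m.front() == 0);
    assert(dims->n.size() == 1 && dims->n.front() == 1);
    assert(dims->k.size() == 1 && dims->k.front() == 2);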