[Mlir-commits] [mlir] e444dce - [mlir][Transform] Add classify_convolution_dims op
Quinn Dawkins
llvmlistbot at llvm.org
Mon Jul 31 13:18:06 PDT 2023
Author: Quinn Dawkins
Date: 2023-07-31T16:15:25-04:00
New Revision: e444dcef7cadb29490deca726cb743e38d0412b3
URL: https://github.com/llvm/llvm-project/commit/e444dcef7cadb29490deca726cb743e38d0412b3
DIFF: https://github.com/llvm/llvm-project/commit/e444dcef7cadb29490deca726cb743e38d0412b3.diff
LOG: [mlir][Transform] Add classify_convolution_dims op
Includes `inferConvolutionDims` based on the existing helper for
contractions, `inferContractionDims`. This allows matching and
identifying the relevant dims for a convolution sub-computation of
a linalg operation.
Additionally, adds stride and dilation inference to the captures and the
convolution interface matcher.
Differential Revision: https://reviews.llvm.org/D156080
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index b321a376309976..e735fdc59ec1be 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -67,6 +67,45 @@ FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
bool isaContractionOpInterface(LinalgOp linalgOp);
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// convolution dimension.
+struct ConvolutionDimensions {
+ SmallVector<unsigned, 2> batch;
+ SmallVector<unsigned, 2> outputImage;
+ SmallVector<unsigned, 2> outputChannel;
+ SmallVector<unsigned, 2> filterLoop;
+ SmallVector<unsigned, 2> inputChannel;
+ SmallVector<unsigned, 2> depth;
+ SmallVector<int64_t, 2> strides;
+ SmallVector<int64_t, 2> dilations;
+};
+
+/// Find at least one parallel (output_image) and one reduction (filter_loop)
+/// dimension candidates that form a convolution subcomputation within
+/// `linalgOp`. The LHS is assumed to be the convolution input while the
+/// RHS is assumed to be the filter.
+/// These dimensions are such that:
+/// 1. Optional batch dimensions that appear in the input and output.
+/// 2. The output_image dimension is involved in a cross-correlation along LHS
+/// (i.e. it is a permutation on RES and LHS and has an associated
+/// filter_loop in RHS).
+/// 3. Optional output_channel dimension is involved in an outer-product along
+/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
+/// LHS).
+/// 4. Optional input_channel dimension appears as a permutation on LHS and
+/// RHS.
+/// 5. The filter_loop dimension appears as a permutation on the RHS and
+/// represents the shape of the kernel cross-correlated along a
+/// corresponding output_image dim.
+/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
+/// 7. All dimensions appear only once in any given indexing map.
+/// This allows e.g. detecting that some convolution is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match any classification,
+/// indices are returned in sorted order.
+/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
+FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
+
/// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
bool isaConvolutionOpInterface(LinalgOp linalgOp);
@@ -100,9 +139,6 @@ enum class MatchContractionResult;
/// Checks whether `op` conforms to ContractionOpInterface and populates
/// `dimensions` with indexes of the different kinds of dimensions when
/// present.
-// TODO: Extract a standalone `inferConvolutionDims` that can also detect
-// whether a conv pattern exists within a bigger linalg op (see
-// inferContractionDims).
MatchContractionResult
isContractionInterfaceImpl(Operation *op,
ContractionDimensions *dimensions = nullptr);
@@ -115,17 +151,6 @@ StringRef getMatchContractionMessage(MatchContractionResult res);
/// convolution.
enum class MatchConvolutionResult;
-/// Positions of a Linalg op loops that correspond to different kinds of a
-/// convolution dimension.
-struct ConvolutionDimensions {
- SmallVector<unsigned, 2> batch;
- SmallVector<unsigned, 2> outputImage;
- SmallVector<unsigned, 2> outputChannel;
- SmallVector<unsigned, 2> filterLoop;
- SmallVector<unsigned, 2> inputChannel;
- SmallVector<unsigned, 2> depth;
-};
-
/// Checks whether `op` conforms to ConvolutionOpInterface and populates
/// `dimensions` with indexes of the different kinds of dimensions when
/// present.
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index ad348d0ce89a64..9e108529ec129b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -170,6 +170,57 @@ def MatchStructuredClassifyContractionDimsOp
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
}
+def MatchStructuredClassifyConvolutionDimsOp
+ : Op<Transform_Dialect, "match.structured.classify_convolution_dims", [
+ SingleOpMatcher,
+ StructuredPredicate,
+ MatchOpInterface,
+ MemoryEffectsOpInterface]> {
+ let summary =
+ "Checks if an operation has convolution-like dimensions and returns them";
+ let description = !strconcat([{
+ Checks if the structured payload op has convolution-like dimensions as
+ follows:
+
+ C(batch, depth, oi, oc) += A(batch, depth, oi, ic) * B(fl, depth, ic, oc)
+
+ That is:
+
+ - 'batch' are parallel dimensions used in the input and result;
+ - 'output_image' ('oi') are parallel dimensions used in the input and result;
+ - 'output_channel' ('oc') are parallel dimensions used in the filter and result;
+ - 'filter_loop' ('fl') are reduction dimensions representing the dimensions of the sliding window;
+ - 'input_channel' ('ic') are reduction dimensions present only in the input and filter.
+    - 'depth' are parallel dimensions present in the input, filter, and output.
+
+ Additionally this will match stride and dilation information for the convolution:
+ - 'strides' are the static strides per convolution window dimension;
+ - 'dilations' are the static dilations per convolution window dimension.
+
+ Note that this doesn't check the operation in the body.
+
+ }], StructuredPredicate.extraDescription, [{
+
+ #### Return modes
+
+ Succeeds if the operation has the convolution-like dimensions, produces a
+ silenceable failure otherwise.
+ }]);
+
+ let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+ let results = (outs TransformParamTypeInterface:$batch,
+ TransformParamTypeInterface:$output_image,
+ TransformParamTypeInterface:$output_channel,
+ TransformParamTypeInterface:$filter_loop,
+ TransformParamTypeInterface:$input_channel,
+ TransformParamTypeInterface:$depth,
+ TransformParamTypeInterface:$strides,
+ TransformParamTypeInterface:$dilations);
+ let assemblyFormat =
+ "$operand_handle attr-dict `:` functional-type(operands, results)";
+ let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
class StructuredDimDescription<string kind> {
string description = !strconcat([{
The following }], kind ,[{ specifications are supported:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 08de32818f0047..95a766b3375713 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -161,10 +161,10 @@ static bool isContractionBody(Block &block) {
/// determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
-static DenseSet<int64_t>
+static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
utils::IteratorType iter) {
- DenseSet<int64_t> res;
+ llvm::SmallDenseSet<int64_t> res;
assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
@@ -200,30 +200,30 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
return failure();
- DenseSet<int64_t> a = findPermutationsIndexingOperand(
+ llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(0), par);
- DenseSet<int64_t> b = findPermutationsIndexingOperand(
+ llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), par);
- DenseSet<int64_t> c = findPermutationsIndexingOperand(
+ llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInitOperand(0), par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
- DenseSet<int64_t> ac = a;
+ llvm::SmallDenseSet<int64_t> ac = a;
llvm::set_intersect(ac, c);
llvm::set_subtract(ac, b);
// B & C - A are the iterators involved in an outer-product along B (the RHS).
- DenseSet<int64_t> bc = b;
+ llvm::SmallDenseSet<int64_t> bc = b;
llvm::set_intersect(bc, c);
llvm::set_subtract(bc, a);
// A & B & C are the "batch" dimensions.
- DenseSet<int64_t> batches = a;
+ llvm::SmallDenseSet<int64_t> batches = a;
llvm::set_intersect(batches, b);
llvm::set_intersect(batches, c);
// A & B red are the reduction dimensions.
- DenseSet<int64_t> ra = findPermutationsIndexingOperand(
+ llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(0), red);
- DenseSet<int64_t> rb = findPermutationsIndexingOperand(
+ llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), red);
llvm::set_intersect(ra, rb);
@@ -236,10 +236,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
- std::sort(dimensions.batch.begin(), dimensions.batch.end());
- std::sort(dimensions.m.begin(), dimensions.m.end());
- std::sort(dimensions.n.begin(), dimensions.n.end());
- std::sort(dimensions.k.begin(), dimensions.k.end());
+ llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
+ llvm::sort(dimensions.m.begin(), dimensions.m.end());
+ llvm::sort(dimensions.n.begin(), dimensions.n.end());
+ llvm::sort(dimensions.k.begin(), dimensions.k.end());
return dimensions;
}
@@ -359,8 +359,37 @@ namespace {
/// dimensions and verifies each dimension occurs only once.
struct ConvAccessExprWalker
: public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
- llvm::SmallDenseSet<unsigned> convolvedDims;
- llvm::SmallDenseSet<unsigned> unConvolvedDims;
+ // Stores dimensions used in expressions of the above form.
+ llvm::SmallDenseSet<int64_t> convolvedDims;
+ // Stores the dual mapping between LHS and RHS of convolution exprs.
+ llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
+ // Stores single use dimensions used by an AffineDimExpr.
+ llvm::SmallDenseSet<int64_t> unConvolvedDims;
+ // Stores a mapping from convolved dims to their coefficient.
+ llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
+
+ // Removes dims with multiple uses in the source input map from dimension
+ // sets tracked by this walker.
+ void clearMultiUseDims(AffineMap map) {
+ for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
+ if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
+ return e.isFunctionOfDim(dimPos);
+ }) > 1) {
+ convolvedDims.erase(dimPos);
+ unConvolvedDims.erase(dimPos);
+ // If a duplicate dim is marked as convolved, the pair of the duplicate
+ // dim must be removed from the map as well.
+ if (convolvedDimMapping.contains(dimPos)) {
+ int64_t pairedDim = convolvedDimMapping[dimPos];
+ convolvedDims.erase(pairedDim);
+ unConvolvedDims.erase(pairedDim);
+ strideAndDilationMapping.erase(pairedDim);
+ convolvedDimMapping.erase(dimPos);
+ convolvedDimMapping.erase(pairedDim);
+ }
+ }
+ }
+ }
LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
unsigned position = dimExpr.getPosition();
@@ -379,17 +408,25 @@ struct ConvAccessExprWalker
// In pre-order visit, top level op has to be an add op.
if (binaryExpr.getKind() != AffineExprKind::Add)
return failure();
- return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
- succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
+ auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
+ auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
+ if (failed(lhsDimPos) || failed(rhsDimPos))
+ return failure();
+ convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
+ convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
+ return success();
}
- LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
+ FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
- unsigned dim = dimExpr.getPosition();
+ int64_t dim = dimExpr.getPosition();
if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
return failure();
+ // Stride/dilation for this dim is implicitly 1.
+ strideAndDilationMapping[dim] =
+ getAffineConstantExpr(1, expr.getContext());
convolvedDims.insert(dim);
- return success();
+ return dim;
}
if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
if (symbolMulExpr.getKind() != AffineExprKind::Mul)
@@ -406,26 +443,170 @@ struct ConvAccessExprWalker
auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
if (!mulExpr || !dimExpr)
return failure();
- unsigned dim = dimExpr.getPosition();
+ int64_t dim = dimExpr.getPosition();
if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
return failure();
+ strideAndDilationMapping[dim] = mulExpr;
convolvedDims.insert(dim);
- return success();
+ return dim;
}
return failure();
}
};
} // namespace
-static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
+static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
assert(map.isProjectedPermutation() &&
"expected map to have projected permutations");
- llvm::SmallDenseSet<unsigned> preservedDims;
+ llvm::SmallDenseSet<int64_t> preservedDims;
for (auto expr : map.getResults())
preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
return preservedDims;
}
+static SmallVector<int64_t, 2>
+getConstantsFromExprList(SmallVector<AffineExpr, 2> exprs) {
+ SmallVector<int64_t, 2> vals;
+ for (auto e : exprs) {
+ auto constantExpr = e.dyn_cast<AffineConstantExpr>();
+ assert(constantExpr && "Found non-constant stride/dilation");
+ vals.push_back(constantExpr.getValue());
+ }
+ return vals;
+}
+
+/// Classifies dimensions in the `linalgOp` used by a convolution
+/// subcomputation, as captured by `inputExprWalker`. If
+/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
+/// least one convolved dimension pair (output image + filter loop).
+/// Convolution dimensions are specified in sorted order, and strides match
+/// the order of the output image dimensions, while the dilations match the
+/// order of the filter loop dimensions.
+static FailureOr<ConvolutionDimensions>
+inferConvolutionDimsImpl(LinalgOp linalgOp,
+ ConvAccessExprWalker &inputExprWalker,
+ bool allowEmptyConvolvedDims) {
+ llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInputOperand(1), par);
+ llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
+ linalgOp, linalgOp.getDpsInitOperand(0), par);
+
+ // unConvolvedDims & outputDims - filterDims are the batch iterators.
+ llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
+ llvm::set_intersect(batch, outputDims);
+ llvm::set_subtract(batch, filterDims);
+
+ // convolvedDims & outputDims are the output image iterators.
+ llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
+ llvm::set_intersect(oi, outputDims);
+
+ // filterDims & outputDims - unConvolvedDims are the output channel iterators.
+ llvm::SmallDenseSet<int64_t> oc = filterDims;
+ llvm::set_intersect(oc, outputDims);
+ llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
+
+ // filterDims & outputDims & unConvolvedDims are the depth iterators.
+ llvm::SmallDenseSet<int64_t> depth = filterDims;
+ llvm::set_intersect(depth, outputDims);
+ llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
+
+ llvm::SmallDenseSet<int64_t> filterReducedDims =
+ findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
+ red);
+
+ // convolvedDims & filterReducedDims are the filter loop iterators.
+ llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
+ llvm::set_intersect(fl, filterReducedDims);
+
+ // unConvolvedDims & filterReducedDims are the input channel iterators.
+ llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
+ llvm::set_intersect(ic, filterReducedDims);
+
+ if (oi.empty() && !allowEmptyConvolvedDims)
+ return failure();
+
+ // Return each set in sorted order.
+ ConvolutionDimensions dimensions{
+ SmallVector<unsigned, 2>(batch.begin(), batch.end()),
+ SmallVector<unsigned, 2>(oi.begin(), oi.end()),
+ SmallVector<unsigned, 2>(oc.begin(), oc.end()),
+ SmallVector<unsigned, 2>(fl.begin(), fl.end()),
+ SmallVector<unsigned, 2>(ic.begin(), ic.end()),
+ SmallVector<unsigned, 2>(depth.begin(), depth.end()),
+ /*strides=*/SmallVector<int64_t, 2>{},
+ /*dilations=*/SmallVector<int64_t, 2>{}};
+ llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
+ llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
+ llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
+ llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
+ llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
+ llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
+
+ // Use the op carried strides/dilations attribute if present.
+ auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+ if (!nativeStrides) {
+ SmallVector<AffineExpr, 2> strideExprs;
+ for (unsigned oiDim : dimensions.outputImage)
+ strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
+ dimensions.strides = getConstantsFromExprList(strideExprs);
+ } else {
+ dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
+ }
+ auto nativeDilations =
+ linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+ if (!nativeDilations) {
+ SmallVector<AffineExpr, 2> dilationExprs;
+ for (unsigned flDim : dimensions.filterLoop)
+ dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
+ dimensions.dilations = getConstantsFromExprList(dilationExprs);
+ } else {
+ dimensions.dilations =
+ llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
+ }
+ return dimensions;
+}
+
+/// Find at least one parallel (output_image) and one reduction (filter_loop)
+/// dimension candidates that form a convolution subcomputation within
+/// `linalgOp`. The LHS is assumed to be the convolution input while the
+/// RHS is assumed to be the filter.
+/// These dimensions are such that:
+/// 1. Optional batch dimensions that appear in the input and output.
+/// 2. The output_image dimension is involved in a cross-correlation along LHS
+/// (i.e. it is a permutation on RES and LHS and has an associated
+/// filter_loop in RHS).
+/// 3. Optional output_channel dimension is involved in an outer-product along
+/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
+/// LHS).
+/// 4. Optional input_channel dimension appears as a permutation on LHS and
+/// RHS.
+/// 5. The filter_loop dimension appears as a permutation on the RHS and
+/// represents the shape of the kernel cross-correlated along a
+/// corresponding output_image dim.
+/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
+/// 7. All dimensions appear only once in any given indexing map.
+/// This allows e.g. detecting that some convolution is embedded within
+/// `linalgOp` with some orthogonal heuristic.
+/// When multiple dimension occurrences exist that match any classification,
+/// indices are returned in sorted order.
+/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
+FailureOr<ConvolutionDimensions>
+mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+
+ auto indexingMaps = linalgOp.getIndexingMapsArray();
+
+ // Check the input indexing map has the right form.
+ ConvAccessExprWalker inputExprWalker;
+ for (AffineExpr expr : indexingMaps[0].getResults())
+ (void)inputExprWalker.visit(expr);
+ inputExprWalker.clearMultiUseDims(indexingMaps[0]);
+
+ return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
+ /*allowEmptyConvolvedDims=*/false);
+}
+
namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
Success = 0,
@@ -466,9 +647,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
auto iteratorTypes = linalgOp.getIteratorTypesArray();
- llvm::SmallDenseSet<unsigned> outputDims =
+ llvm::SmallDenseSet<int64_t> outputDims =
getPreservedDims(indexingMaps.back());
- llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
+ llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
// Make sure all loops are characterized as one of:
// - Batch loop : present in output, as non-convolved in input, not present in
// filter.
@@ -482,17 +663,15 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// present in filter.
// - Depth multiplier : unconvolved in input, present in output, present in
// filter.
- llvm::SmallDenseSet<unsigned> allLoopDims;
+ llvm::SmallDenseSet<int64_t> allLoopDims;
for (auto outputExpr : indexingMaps.back().getResults()) {
- unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
+ int64_t outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Batch dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
- if (dimensions)
- dimensions->batch.push_back(outputDim);
continue;
}
if (inputExprWalker.convolvedDims.count(outputDim) &&
@@ -501,8 +680,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
- if (dimensions)
- dimensions->outputImage.push_back(outputDim);
continue;
}
if (!inputExprWalker.convolvedDims.count(outputDim) &&
@@ -512,8 +689,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
- if (dimensions)
- dimensions->outputChannel.push_back(outputDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -522,21 +697,16 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
- if (dimensions)
- dimensions->depth.push_back(outputDim);
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
}
for (auto filterExpr : indexingMaps[1].getResults()) {
- unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
+ int64_t filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
if (outputDims.count(filterDim) &&
!inputExprWalker.unConvolvedDims.count(filterDim) &&
!inputExprWalker.convolvedDims.count(filterDim)) {
// Output channel dimension. This is already seen, continue;
- assert((!dimensions ||
- llvm::is_contained(dimensions->outputChannel, filterDim)) &&
- "expected output channel to have been found from output dims");
continue;
}
if (inputExprWalker.convolvedDims.count(filterDim) &&
@@ -547,8 +717,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
- if (dimensions)
- dimensions->filterLoop.push_back(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
@@ -559,16 +727,11 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
- if (dimensions)
- dimensions->inputChannel.push_back(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
outputDims.count(filterDim)) {
// Depthwise loop. Already seen.
- assert(
- (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
- "expected depthwise dimension to have been found from output dims");
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
@@ -578,12 +741,11 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
return MatchConvolutionResult::NonConvolutionLoop;
if (dimensions) {
- assert(dimensions->batch.size() + dimensions->outputImage.size() +
- dimensions->outputChannel.size() +
- dimensions->filterLoop.size() +
- dimensions->inputChannel.size() + dimensions->depth.size() ==
- linalgOp.getNumLoops() &&
- "expected all loops to be classified");
+ FailureOr<ConvolutionDimensions> res =
+ inferConvolutionDimsImpl(linalgOp, inputExprWalker,
+ /*allowEmptyConvolvedDims=*/true);
+ assert(succeeded(res) && "unexpected failure to infer convolution dims");
+ *dimensions = *res;
}
return MatchConvolutionResult::Success;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 2220930e8f4dd0..7b5b4924fe6926 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -259,6 +259,53 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// MatchStructuredClassifyConvolutionDimsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
+ Operation *current, transform::TransformResults &results,
+ transform::TransformState &state) {
+ FailureOr<linalg::ConvolutionDimensions> convolutionDims =
+ linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
+ if (failed(convolutionDims))
+ return emitSilenceableError() << "could not infer convolution dimensions";
+
+ MLIRContext *context = current->getContext();
+ Builder builder(context);
+ auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
+ return llvm::to_vector(
+ llvm::map_range(values, [&](unsigned value) -> Attribute {
+ return builder.getI64IntegerAttr(value);
+ }));
+ };
+ results.setParams(getBatch().cast<OpResult>(),
+ makeI64Attrs(convolutionDims->batch));
+ results.setParams(getOutputImage().cast<OpResult>(),
+ makeI64Attrs(convolutionDims->outputImage));
+ results.setParams(getOutputChannel().cast<OpResult>(),
+ makeI64Attrs(convolutionDims->outputChannel));
+ results.setParams(getFilterLoop().cast<OpResult>(),
+ makeI64Attrs(convolutionDims->filterLoop));
+ results.setParams(getInputChannel().cast<OpResult>(),
+ makeI64Attrs(convolutionDims->inputChannel));
+ results.setParams(getDepth().cast<OpResult>(),
+ makeI64Attrs(convolutionDims->depth));
+
+ auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+ return llvm::to_vector(
+ llvm::map_range(values, [&](int64_t value) -> Attribute {
+ return builder.getI64IntegerAttr(value);
+ }));
+ };
+ results.setParams(getStrides().cast<OpResult>(),
+ makeI64AttrsFromI64(convolutionDims->strides));
+ results.setParams(getDilations().cast<OpResult>(),
+ makeI64AttrsFromI64(convolutionDims->dilations));
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Utilities for structured match predicates.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 063e1e44fd604c..bad6893eaa99e1 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -934,3 +934,97 @@ module attributes { transform.target_tag = "start_here" } {
return %result : tensor<40x10x50x15xf32>
}
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @match_convolution(%arg0: !transform.any_op {transform.readonly})
+ -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+ %1:8 = transform.match.structured %arg0 : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+ ^bb0(%struct: !transform.any_op):
+ transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
+ %0:8 = transform.match.structured.classify_convolution_dims %struct
+ : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+ transform.match.structured.yield %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7
+ : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+ }
+ transform.yield %arg0, %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+ }
+
+ transform.named_sequence @print_convolution(
+ %op: !transform.any_op {transform.readonly},
+ %batch: !transform.param<i64> {transform.readonly},
+ %oi: !transform.param<i64> {transform.readonly},
+ %oc: !transform.param<i64> {transform.readonly},
+ %fl: !transform.param<i64> {transform.readonly},
+ %ic: !transform.param<i64> {transform.readonly},
+ %depth: !transform.param<i64> {transform.readonly},
+ %strides: !transform.param<i64> {transform.readonly},
+ %dilations: !transform.param<i64> {transform.readonly}) {
+ transform.test_print_remark_at_operand %op, "convolution" : !transform.any_op
+ transform.test_print_param %batch, "batch dims" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %oi, "output image dims" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %oc, "output channel dims" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %fl, "filter loop dims" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %ic, "input channel dims" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %depth, "depth dims" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %strides, "strides" at %op : !transform.param<i64>, !transform.any_op
+ transform.test_print_param %dilations, "dilations" at %op : !transform.param<i64>, !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+ ^bb0(%arg0: !transform.any_op):
+ %3 = transform.foreach_match in %arg0 @match_convolution -> @print_convolution : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+module attributes { transform.target_tag = "start_here" } {
+ func.func @convolution_simple(%input: tensor<10x20x30xf32>, %filter: tensor<3x30x15xf32>) -> tensor<10x18x15xf64> {
+ %cst = arith.constant 0.0 : f64
+ %empty = tensor.empty() : tensor<10x18x15xf64>
+ %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x18x15xf64>) -> tensor<10x18x15xf64>
+ // expected-remark @below {{convolution}}
+ // expected-remark @below {{batch dims 0}}
+ // expected-remark @below {{output image dims 1}}
+ // expected-remark @below {{output channel dims 2}}
+ // expected-remark @below {{filter loop dims 3}}
+ // expected-remark @below {{input channel dims 4}}
+ // expected-remark @below {{depth dims}}
+ // expected-remark @below {{strides 1}}
+ // expected-remark @below {{dilations 1}}
+ %result = linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter: tensor<10x20x30xf32>, tensor<3x30x15xf32>) outs(%fill: tensor<10x18x15xf64>) -> tensor<10x18x15xf64>
+ return %result : tensor<10x18x15xf64>
+ }
+
+ func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
+ %cst = arith.constant 0.0 : f32
+ %empty = tensor.empty() : tensor<8x32x32x16xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x32x32x16xf32>) -> tensor<8x32x32x16xf32>
+ // expected-remark @below {{convolution}}
+ // expected-remark @below {{batch dims}}
+ // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
+ // expected-remark @below {{output channel dims 0 : i64, 3 : i64}}
+ // expected-remark @below {{filter loop dims 5 : i64, 6 : i64}}
+ // expected-remark @below {{input channel dims 4 : i64, 7 : i64}}
+ // expected-remark @below {{depth dims}}
+ // expected-remark @below {{strides 1 : i64, 2 : i64}}
+ // expected-remark @below {{dilations 1 : i64, 1 : i64}}
+ %result = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1 + d5, 2 * d2 + d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d4, d5, d6, d7, d3)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]}
+ ins(%input, %filter : tensor<2x34x68x16xf32>, tensor<8x2x3x5x16x16xf32>) outs(%fill : tensor<8x32x32x16xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %mul = arith.mulf %in, %in_0 : f32
+ %add = arith.addf %mul, %out : f32
+ linalg.yield %add : f32
+ } -> tensor<8x32x32x16xf32>
+ return %result : tensor<8x32x32x16xf32>
+ }
+}
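As a worked reading of the stride/dilation inference exercised by the
`convolution_multi_channel` test above (a sketch of what
`inferConvolutionDimsImpl` computes from the input map):

    affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1 + d5, 2 * d2 + d6, d7)>

pairs the output-image dims (d1, d2) with the filter-loop dims (d5, d6). The
strides are the coefficients on the output-image dims, [1, 2], and the
dilations are the coefficients on the filter-loop dims, [1, 1], which is what
the expected remarks check.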