[Mlir-commits] [mlir] 0b33f0d - [mlir][linalg] expose convolution dimension classifier
Alex Zinenko
llvmlistbot at llvm.org
Tue Feb 14 02:12:08 PST 2023
Author: Alex Zinenko
Date: 2023-02-14T10:12:01Z
New Revision: 0b33f0d80b832614af68557011786a83a0eef18a
URL: https://github.com/llvm/llvm-project/commit/0b33f0d80b832614af68557011786a83a0eef18a
DIFF: https://github.com/llvm/llvm-project/commit/0b33f0d80b832614af68557011786a83a0eef18a.diff
LOG: [mlir][linalg] expose convolution dimension classifier
Expose, through functions in the `linalg::detail` namespace, the
classification of Linalg op dimensions into the different kinds of
convolution dimensions (batch, image, channel, etc.). This is useful for
identifying which dimensions to target with transformations.
Reviewed By: chelini
Differential Revision: https://reviews.llvm.org/D143584
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index d8d59b5ce3a81..cb93e8a7bc104 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -42,6 +42,31 @@ bool isaContractionOpInterface(LinalgOp linalgOp);
namespace detail {
+/// Result of matching a Linalg generic against the predicates of it being a
+/// convolution.
+enum class MatchConvolutionResult;
+
+/// Positions of a Linalg op loops that correspond to different kinds of a
+/// convolution dimension.
+struct ConvolutionDimensions {
+ SmallVector<unsigned, 2> batch;
+ SmallVector<unsigned, 2> outputImage;
+ SmallVector<unsigned, 2> outputChannel;
+ SmallVector<unsigned, 2> filterLoop;
+ SmallVector<unsigned, 2> inputChannel;
+ SmallVector<unsigned, 2> depth;
+};
+
+/// Checks whether `op` conforms to ConvolutionOpInterface and populates
+/// `dimensions` with indexes of the different kinds of dimensions when present.
+MatchConvolutionResult
+isConvolutionInterfaceImpl(Operation *op,
+ ConvolutionDimensions *dimensions = nullptr);
+
+/// Returns the error message corresponding to the convolution checking return
+/// code.
+StringRef getMatchConvolutionMessage(MatchConvolutionResult res);
+
/// Verify that `op` conforms to ContractionOpInterface.
LogicalResult verifyContractionInterface(Operation *op);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index e5e0bdd255dc5..a5c6dc627299e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -268,6 +268,7 @@ static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
return preservedDims;
}
+namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
Success = 0,
NotLinalgOp,
@@ -278,8 +279,11 @@ enum class MatchConvolutionResult {
OutputDimsNotParallel,
NonOutputDimNotReduction
};
+} // namespace mlir::linalg::detail
-static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
+mlir::linalg::detail::MatchConvolutionResult
+mlir::linalg::detail::isConvolutionInterfaceImpl(
+ Operation *op, ConvolutionDimensions *dimensions) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return MatchConvolutionResult::NotLinalgOp;
@@ -307,7 +311,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
llvm::SmallDenseSet<unsigned> outputDims =
getPreservedDims(indexingMaps.back());
llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
- // Make sure all loops are charecterized as one of:
+ // Make sure all loops are characterized as one of:
// - Batch loop : present in output, as non-convolved in input, not present in
// filter.
// - Output image dimension : present in output, convolved dims in input, not
@@ -329,6 +333,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->batch.push_back(outputDim);
continue;
}
if (inputExprWalker.convolvedDims.count(outputDim) &&
@@ -337,6 +343,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->outputImage.push_back(outputDim);
continue;
}
if (!inputExprWalker.convolvedDims.count(outputDim) &&
@@ -346,6 +354,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->outputChannel.push_back(outputDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -354,6 +364,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
+ if (dimensions)
+ dimensions->depth.push_back(outputDim);
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
@@ -363,7 +375,10 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (outputDims.count(filterDim) &&
!inputExprWalker.unConvolvedDims.count(filterDim) &&
!inputExprWalker.convolvedDims.count(filterDim)) {
- // Output channel dimension. THis is already seen, continue;
+ // Output channel dimension. This is already seen, continue;
+ assert((!dimensions ||
+ llvm::is_contained(dimensions->outputChannel, filterDim)) &&
+ "expected output channel to have been found from output dims");
continue;
}
if (inputExprWalker.convolvedDims.count(filterDim) &&
@@ -374,6 +389,8 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
+ if (dimensions)
+ dimensions->filterLoop.push_back(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
@@ -384,11 +401,16 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
allLoopDims.insert(filterDim);
+ if (dimensions)
+ dimensions->inputChannel.push_back(filterDim);
continue;
}
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
outputDims.count(filterDim)) {
// Depthwise loop. Already seen.
+ assert(
+ (!dimensions || llvm::is_contained(dimensions->depth, filterDim)) &&
+ "expected depthwise dimension to have been found from output dims");
continue;
}
return MatchConvolutionResult::NonConvolutionLoop;
@@ -397,32 +419,45 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
+ if (dimensions) {
+ assert(dimensions->batch.size() + dimensions->outputImage.size() +
+ dimensions->outputChannel.size() +
+ dimensions->filterLoop.size() +
+ dimensions->inputChannel.size() + dimensions->depth.size() ==
+ linalgOp.getNumLoops() &&
+ "expected all loops to be classified");
+ }
+
return MatchConvolutionResult::Success;
}
-LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
- auto res = isConvolutionInterfaceImpl(op);
- if (res == MatchConvolutionResult::NotLinalgOp)
- return op->emitError("expected a LinalgOp");
- if (res == MatchConvolutionResult::WrongNumOperands)
- return op->emitError("expected op with 2 inputs and 1 output");
- if (res == MatchConvolutionResult::WrongInputIndexingMap)
- return op->emitError("unexpected input index map for convolutions");
- if (res == MatchConvolutionResult::NotProjectedPermutations) {
- return op->emitError(
- "expected output/filter indexing maps to be projected permutations");
- }
- if (res == MatchConvolutionResult::NonConvolutionLoop) {
- return op->emitError("unexpected loop dimension for convolution op");
- }
- if (res == MatchConvolutionResult::OutputDimsNotParallel) {
- return op->emitError(
- "expected all iterators used to access outputs to be parallel");
- }
- if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
- return op->emitError(
- "expected all iterators not used to access outputs to be reduction");
+StringRef
+mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
+ switch (res) {
+ case MatchConvolutionResult::NotLinalgOp:
+ return "expected a LinalgOp";
+ case MatchConvolutionResult::WrongNumOperands:
+ return "expected op with 2 inputs and 1 output";
+ case MatchConvolutionResult::WrongInputIndexingMap:
+ return "unexpected input index map for convolutions";
+ case MatchConvolutionResult::NotProjectedPermutations:
+ return "expected output/filter indexing maps to be projected permutations";
+ case MatchConvolutionResult::NonConvolutionLoop:
+ return "unexpected loop dimension for convolution op";
+ case MatchConvolutionResult::OutputDimsNotParallel:
+ return "expected all iterators used to access outputs to be parallel";
+ case MatchConvolutionResult::NonOutputDimNotReduction:
+ return "expected all iterators not used to access outputs to be reduction";
+ case MatchConvolutionResult::Success:
+ return "";
}
+ llvm_unreachable("unhandled MatchConvolutionResult case");
+}
+
+LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
+ MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
+ if (res != MatchConvolutionResult::Success)
+ return op->emitError(getMatchConvolutionMessage(res));
return success();
}
More information about the Mlir-commits mailing list