[Mlir-commits] [mlir] 0b33890 - [mlir][Linalg] Add ConvolutionOpInterface.
llvmlistbot at llvm.org
Mon Sep 20 10:41:22 PDT 2021
Author: MaheshRavishankar
Date: 2021-09-20T10:41:10-07:00
New Revision: 0b33890f4553c9255c0f44cee04a0d98843d6a5a
URL: https://github.com/llvm/llvm-project/commit/0b33890f4553c9255c0f44cee04a0d98843d6a5a
DIFF: https://github.com/llvm/llvm-project/commit/0b33890f4553c9255c0f44cee04a0d98843d6a5a.diff
LOG: [mlir][Linalg] Add ConvolutionOpInterface.
Add an interface that allows grouping together all convolution and
pooling ops within Linalg named ops. The interface currently verifies that:
- the indexing map used for input/image access is valid,
- the filter and output are accessed using projected permutations,
- all loops are characterizable as one iterating over:
  - batch dimension,
  - output image dimensions,
  - filter convolved dimensions,
  - output channel dimensions,
  - input channel dimensions,
  - depth multiplier (for depthwise convolutions).
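As an illustration (a hedged sketch, not part of this patch), a named op
that satisfies these constraints; %input, %filter, and %init are placeholder
values:

  %0 = linalg.conv_2d_nhwc_hwcf
         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
         ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
         outs(%init : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>

The interface's image() and filter() accessors default to operands 0 and 1,
so transformations can query any convolution or pooling named op uniformly.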
Differential Revision: https://reviews.llvm.org/D109793
Added:
mlir/test/Dialect/Linalg/conv-interface-invalid.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestOps.td
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 4837fc8923472..8de117ca04253 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -44,6 +44,9 @@ namespace detail {
/// Verify that `op` conforms to ContractionOpInterface.
LogicalResult verifyContractionInterface(Operation *op);
+/// Verify that `op` conforms to the ConvolutionOpInterface.
+LogicalResult verifyConvolutionInterface(Operation *op);
+
/// Verify that `op` conforms to the invariants of StructuredOpInterface
LogicalResult verifyStructuredOpInterface(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fe8ddaffbc169..6aa2fb7596163 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -87,6 +87,51 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
];
}
+def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
+ let description = [{
+ A convolution is defined in general terms:
+ 1. Has an `image` and a `filter` operand.
+ 2. Has one `output` operand.
+ 3. The indexing maps of the input have expressions that satisfy
+ ```
+ AffineExpr ::== AffineDimExpr | ConvolvedExpr
+ ConvolvedExpr ::== MulExpr (`+` MulExpr)+
+ MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
+ ```
+ 4. The filter and the output have projected permutation maps.
+ 5. Each of the loops can be qualified as one of:
+ - Loop over batch dimension,
+ - Loop over output image dimensions,
+ - Loop over output channel dimensions,
+ - Loop over convolved filter dimensions,
+ - Loop over input channel dimension.
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let verify = [{ return detail::verifyConvolutionInterface($_op); }];
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/"Return the image operand.",
+ /*retTy=*/"Value",
+ /*methodName=*/"image",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getOperation()->getOperand(0);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/"Return the filter operand.",
+ /*retTy=*/"Value",
+ /*methodName=*/"filter",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getOperation()->getOperand(1);
+ }]
+ >,
+ ];
+}
+
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let cppNamespace = "::mlir::linalg";
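To make rule 3 of the new interface description concrete, here is an
illustrative image indexing map for a strided, dilated 1-D convolution (a
hedged example, not text from the patch; the constants 2 and 3 stand in for
a stride and a dilation):

  affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 2 + d3 * 3, d4)>

The second result is a ConvolvedExpr, so d1 and d3 are classified as
convolved dimensions; the bare d0 and d4 results are AffineDimExprs and are
classified as unconvolved. A dimension that repeats across results, or a
symbol or constant appearing outside a multiplication, does not satisfy the
grammar.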
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 11bd84b0fcf47..1a6c5cdee2c8f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -636,6 +636,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -695,6 +697,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -756,6 +760,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -820,6 +826,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -898,6 +906,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -985,6 +995,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1103,6 +1115,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1185,6 +1199,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1272,6 +1288,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Multiplier is set to 1
which is a special case for most depthwise convolutions.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1350,6 +1368,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1460,6 +1480,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1542,6 +1564,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1656,6 +1680,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1724,6 +1750,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1792,6 +1820,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1860,6 +1890,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -1928,6 +1960,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2002,6 +2036,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -2076,6 +2112,8 @@ metadata: !LinalgOpMetadata
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1d4e6d546067e..4c17987c580f1 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -25,8 +25,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// first operands. These may be optionally followed by non-view operands
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
- : Op<Linalg_Dialect, mnemonic, !listconcat(props, [
- LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface])> {
+ : Op<Linalg_Dialect, mnemonic, !listconcat([
+ LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface], props)> {
code structuredOpsBaseDecls = [{
// Return whether the op accesses the iteration indices.
bool hasIndexSemantics() {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index fb1d1cca29d6b..f3570e93cdbbd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -89,7 +89,7 @@ static bool isAddMul(Block &block) {
return success;
}
-enum MatchContractionResult {
+enum class MatchContractionResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
@@ -152,6 +152,260 @@ LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// ConvolutionOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+/// Of the given two expressions, returns the one that is of type T (`lhs`
+/// gets preference over `rhs`).
+template <typename T>
+static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
+ return lhs.isa<T>() ? lhs.cast<T>()
+ : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
+}
+
+namespace {
+/// Walk the indexing expressions for the input of a convolution operation to
+/// verify it is of the right form, either
+/// - AffineDimExpr
+/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
+/// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
+///
+/// Classifies each AffineDimExpr as a convolved or an unconvolved dimension
+/// and verifies that each dimension occurs only once.
+struct ConvAccessExprWalker
+ : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
+ llvm::SmallDenseSet<unsigned> convolvedDims;
+ llvm::SmallDenseSet<unsigned> unConvolvedDims;
+
+ LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
+ unsigned position = dimExpr.getPosition();
+ if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
+ return failure();
+ }
+ unConvolvedDims.insert(position);
+ return success();
+ }
+
+ LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
+
+ LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
+
+ LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
+ // In pre-order visit, top level op has to be an add op.
+ if (binaryExpr.getKind() != AffineExprKind::Add)
+ return failure();
+ return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
+ succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
+ }
+
+ LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
+ if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ unsigned dim = dimExpr.getPosition();
+ if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
+ return failure();
+ convolvedDims.insert(dim);
+ return success();
+ }
+ if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+ if (symbolMulExpr.getKind() != AffineExprKind::Mul)
+ return failure();
+ auto lhsExpr = symbolMulExpr.getLHS();
+ auto rhsExpr = symbolMulExpr.getRHS();
+ // Check for symbol expression.
+ AffineExpr mulExpr =
+ getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
+ // If there was no symbol expr, check for constant expression.
+ if (!mulExpr) {
+ mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
+ }
+ auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
+ if (!mulExpr || !dimExpr)
+ return failure();
+ unsigned dim = dimExpr.getPosition();
+ if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
+ return failure();
+ convolvedDims.insert(dim);
+ return success();
+ }
+ return failure();
+ }
+};
+} // namespace
+
+static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
+ assert(map.isProjectedPermutation() &&
+ "expected map to have projected permutations");
+ llvm::SmallDenseSet<unsigned> preservedDims;
+ for (auto expr : map.getResults())
+ preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
+ return preservedDims;
+}
+
+enum class MatchConvolutionResult {
+ Success = 0,
+ NotLinalgOp,
+ WrongNumOperands,
+ WrongInputIndexingMap,
+ NotProjectedPermutations,
+ NonConvolutionLoop,
+ OutputDimsNotParallel,
+ NonOutputDimNotReduction
+};
+
+static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+ if (!linalgOp)
+ return MatchConvolutionResult::NotLinalgOp;
+ if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1)
+ return MatchConvolutionResult::WrongNumOperands;
+
+ auto indexingMaps = linalgOp.getIndexingMaps();
+
+ // Check the input indexing map has the right form.
+ ConvAccessExprWalker inputExprWalker;
+ if (llvm::any_of(indexingMaps[0].getResults(),
+ [&inputExprWalker](AffineExpr expr) {
+ return failed(inputExprWalker.visit(expr));
+ })) {
+ return MatchConvolutionResult::WrongInputIndexingMap;
+ }
+
+ // Filter and output maps must be projected permutations.
+ if (!indexingMaps[1].isProjectedPermutation() ||
+ !indexingMaps.back().isProjectedPermutation())
+ return MatchConvolutionResult::NotProjectedPermutations;
+
+ auto iteratorTypesRange =
+ linalgOp.iterator_types().getAsValueRange<StringAttr>();
+
+ llvm::SmallDenseSet<unsigned> outputDims =
+ getPreservedDims(indexingMaps.back());
+ llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
+ // Make sure all loops are characterized as one of:
+ // - Batch loop : present in output, as non-convolved in input, not present in
+ // filter.
+ // - Output image dimension : present in output, convolved dims in input, not
+ // present in filter.
+ // - Output channel dimension : present in output, not present in input,
+ // present in filter.
+ // - Filter loop dimension : present in filter, convolved in input, not
+ // present in output.
+ // - Input channel dimension : unconvolved in input, not present in output,
+ // present in filter.
+ // - Depth multiplier : unconvolved in input, present in output, present in
+ // filter.
+ llvm::SmallDenseSet<unsigned> allLoopDims;
+ for (auto outputExpr : indexingMaps.back().getResults()) {
+ unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
+ if (inputExprWalker.unConvolvedDims.count(outputDim) &&
+ !filterDims.count(outputDim)) {
+ // Batch dimension.
+ if (*std::next(iteratorTypesRange.begin(), outputDim) !=
+ getParallelIteratorTypeName())
+ return MatchConvolutionResult::OutputDimsNotParallel;
+ allLoopDims.insert(outputDim);
+ continue;
+ }
+ if (inputExprWalker.convolvedDims.count(outputDim) &&
+ !filterDims.count(outputDim)) {
+ // Output image loop dimension.
+ if (*std::next(iteratorTypesRange.begin(), outputDim) !=
+ getParallelIteratorTypeName())
+ return MatchConvolutionResult::OutputDimsNotParallel;
+ allLoopDims.insert(outputDim);
+ continue;
+ }
+ if (!inputExprWalker.convolvedDims.count(outputDim) &&
+ !inputExprWalker.unConvolvedDims.count(outputDim) &&
+ filterDims.count(outputDim)) {
+ // Output channel dimension.
+ if (*std::next(iteratorTypesRange.begin(), outputDim) !=
+ getParallelIteratorTypeName())
+ return MatchConvolutionResult::OutputDimsNotParallel;
+ allLoopDims.insert(outputDim);
+ continue;
+ }
+ if (inputExprWalker.unConvolvedDims.count(outputDim) &&
+ filterDims.count(outputDim)) {
+ // Depth multiplier.
+ if (*std::next(iteratorTypesRange.begin(), outputDim) !=
+ getParallelIteratorTypeName())
+ return MatchConvolutionResult::OutputDimsNotParallel;
+ allLoopDims.insert(outputDim);
+ continue;
+ }
+ return MatchConvolutionResult::NonConvolutionLoop;
+ }
+ for (auto filterExpr : indexingMaps[1].getResults()) {
+ unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
+ if (outputDims.count(filterDim) &&
+ !inputExprWalker.unConvolvedDims.count(filterDim) &&
+ !inputExprWalker.convolvedDims.count(filterDim)) {
+ // Output channel dimension. This was already seen; continue.
+ continue;
+ }
+ if (inputExprWalker.convolvedDims.count(filterDim) &&
+ !outputDims.count(filterDim)) {
+ // Filter loop dimension.
+ if (*std::next(iteratorTypesRange.begin(), filterDim) !=
+ getReductionIteratorTypeName())
+ return MatchConvolutionResult::NonOutputDimNotReduction;
+ if (allLoopDims.count(filterDim))
+ return MatchConvolutionResult::NonConvolutionLoop;
+ allLoopDims.insert(filterDim);
+ continue;
+ }
+ if (inputExprWalker.unConvolvedDims.count(filterDim) &&
+ !outputDims.count(filterDim)) {
+ // Input channel dimension.
+ if (*std::next(iteratorTypesRange.begin(), filterDim) !=
+ getReductionIteratorTypeName())
+ return MatchConvolutionResult::NonOutputDimNotReduction;
+ if (allLoopDims.count(filterDim))
+ return MatchConvolutionResult::NonConvolutionLoop;
+ allLoopDims.insert(filterDim);
+ continue;
+ }
+ if (inputExprWalker.unConvolvedDims.count(filterDim) &&
+ outputDims.count(filterDim)) {
+ // Depthwise loop. Already seen.
+ continue;
+ }
+ return MatchConvolutionResult::NonConvolutionLoop;
+ }
+ // All loops must be covered now.
+ if (allLoopDims.size() != linalgOp.getNumLoops())
+ return MatchConvolutionResult::NonConvolutionLoop;
+
+ return MatchConvolutionResult::Success;
+}
+
+LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
+ auto res = isConvolutionInterfaceImpl(op);
+ if (res == MatchConvolutionResult::NotLinalgOp)
+ return op->emitError("expected a LinalgOp");
+ if (res == MatchConvolutionResult::WrongNumOperands)
+ return op->emitError("expected op with 2 inputs and 1 output");
+ if (res == MatchConvolutionResult::WrongInputIndexingMap)
+ return op->emitError("unexpected input index map for convolutions");
+ if (res == MatchConvolutionResult::NotProjectedPermutations) {
+ return op->emitError(
+ "expected output/filter indexing maps to be projected permutations");
+ }
+ if (res == MatchConvolutionResult::NonConvolutionLoop) {
+ return op->emitError("unexpected loop dimension for convolution op");
+ }
+ if (res == MatchConvolutionResult::OutputDimsNotParallel) {
+ return op->emitError(
+ "expected all iterators used to access outputs to be parallel");
+ }
+ if (res == MatchConvolutionResult::NonOutputDimNotReduction) {
+ return op->emitError(
+ "expected all iterators not used to access outputs to be reduction");
+ }
+ return success();
+}
//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//
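A hedged walk-through of the loop classification above, using the indexing
maps of linalg.conv_2d_nhwc_hwcf with unit strides and dilations
(illustrative, not part of the patch), where (d0..d6) = (n, oh, ow, f, kh,
kw, c):

  #input  = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
  #filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
  #output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
  // d0 (n):  batch - unconvolved in input, in output, not in filter.
  // d1, d2:  output image - convolved in input, in output, not in filter.
  // d3 (f):  output channel - not in input, in output and in filter.
  // d4, d5:  filter loops - convolved in input, in filter, not in output.
  // d6 (c):  input channel - unconvolved in input, in filter, not in output.

All seven loops are covered and carry the required parallel/reduction
iterator types, so isConvolutionInterfaceImpl would return Success here.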
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index f54d2a5855388..c3894002914fa 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -484,7 +484,7 @@ def __init__(self, cpp_name: str):
ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")
-
+ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
class OpMetadataDef(YAMLObject):
"""Metadata about the op (generally not behavior impacting)."""
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 7a804b7a9fff5..b78a2179737f5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -154,6 +154,7 @@ def conv_1d(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.ow, D.kw)
O[D.ow] += cast(
U, I[D.ow + D.kw]) * cast(U, K[D.kw])
@@ -168,6 +169,7 @@ def conv_2d(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.oh, D.ow, D.kh, D.kw)
O[D.oh, D.ow] += cast(
U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw])
@@ -182,6 +184,7 @@ def conv_3d(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw)
O[D.od, D.oh, D.ow] += cast(
U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kd, D.kh, D.kw])
@@ -198,6 +201,7 @@ def conv_1d_nwc_wcf(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.ow, D.f, D.kw, D.c)
O[D.n, D.ow, D.f] += cast(
U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c
@@ -219,6 +223,7 @@ def conv_2d_nhwc_hwcf(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
@@ -243,6 +248,7 @@ def conv_2d_nhwc_hwcf_q(
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += (cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
@@ -264,6 +270,7 @@ def conv_2d_nchw_fchw(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
O[D.n, D.f, D.oh, D.ow] += cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW
@@ -282,6 +289,7 @@ def conv_3d_ndhwc_dhwcf(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
O[D.n, D.od, D.oh, D.ow, D.f] += cast(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
@@ -300,6 +308,7 @@ def depthwise_conv2D_nhw(
them to the same data type as the accumulator/output. Multiplier is set to 1
which is a special case for most depthwise convolutions.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -319,6 +328,7 @@ def depthwise_conv2D_nhw_q(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic] += (
(cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -337,6 +347,7 @@ def depthwise_conv2D_nhwc(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic, D.cm] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -356,6 +367,7 @@ def depthwise_conv2D_nhwc_q(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw)
O[D.n, D.oh, D.ow, D.ic, D.cm] += (
(cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -375,6 +387,7 @@ def pooling_nhwc_sum(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
@@ -392,6 +405,7 @@ def pooling_nhwc_max(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -409,6 +423,7 @@ def pooling_nchw_max(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -426,6 +441,7 @@ def pooling_nhwc_min(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -445,6 +461,7 @@ def pooling_ndhwc_sum(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
O[D.n, D.od, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
@@ -464,6 +481,7 @@ def pooling_ndhwc_max(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
cast(
@@ -484,6 +502,7 @@ def pooling_ndhwc_min(
Numeric casting is performed on the input operand, promoting it to the same
data type as the accumulator/output.
"""
+ implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
cast(
diff --git a/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir b/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir
new file mode 100644
index 0000000000000..b46845aae265c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func @test_conv_op_not_linalg_op(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>,
+ %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{expected a LinalgOp}}
+ %0 = "test.conv_op_not_linalg_op"(%arg0, %arg1, %arg2)
+ : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check for number of operands being >= 2.
+#map = affine_map<(d0) -> (d0)>
+func @test_conv_op_wrong_num_operands(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{expected op with 2 inputs and 1 output}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg3 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @test_conv_op_wrong_input_indexing_map1(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{unexpected input index map for convolution}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0 * 2)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @test_conv_op_wrong_input_indexing_map2(%arg0 : tensor<?x?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{unexpected input index map for convolution}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d0)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @test_conv_op_filter_index_map_not_projection(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d1 + d0)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @test_conv_op_output_index_map_not_projection(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0 + d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Convolution op illegal if a loop dimension is used to access the output
+// and the filter and is also convolved in the input.
+func @test_conv_op_output_filter_convolved(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{unexpected loop dimension for convolution op}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// Convolution op illegal if a loop dimension is used only in the output.
+func @test_conv_op_output_only_dim(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{unexpected loop dimension for convolution op}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
+ affine_map<(d0, d1, d2) -> (d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// Convolution op illegal if a loop dimension is used only in the filter.
+func @test_conv_op_filter_only_dim(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{unexpected loop dimension for convolution op}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0)>],
+ iterator_types = ["parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Convolution op illegal if a loop dimension is used only in the input.
+func @test_conv_op_input_only_dim(%arg0 : tensor<?x?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{unexpected loop dimension for convolution op}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d1)>,
+ affine_map<(d0, d1, d2) -> (d0)>],
+ iterator_types = ["parallel", "reduction", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Convolution op illegal if a loop dimension not used to access the output
+// is not a reduction iterator.
+func @test_conv_op_non_output_access_loop_parallel(%arg0 : tensor<?xf32>,
+ %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{expected all iterators not used to access outputs to be reduction}}
+ %0 = test.linalg_conv_op {
+ indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ linalg.yield %arg5 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
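For contrast, a configuration this verifier accepts (a hedged sketch, not
part of the committed test file): d0 is an output image loop (parallel,
convolved in the input and present in the output) and d1 is a filter loop
(reduction, convolved in the input and present in the filter):

  %0 = test.linalg_conv_op {
      indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
      outs(%arg2 : tensor<?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
      linalg.yield %arg5 : f32
  } -> tensor<?xf32>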
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 3592d592acc3a..525e4e05e5369 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -160,9 +160,9 @@ func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>
func @generic_empty_region(%arg0: memref<f32>) {
%f0 = constant 0.0: f32
- // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
+ // expected-error @+1 {{op expected 1 region with 1 block}}
linalg.generic {
- indexing_maps = [ affine_map<() -> (0)> ],
+ indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
iterator_types = []}
ins(%arg0 : memref<f32>)
outs(%arg0 : memref<f32>) {
@@ -275,8 +275,8 @@ func @generic(%arg0: memref<?x?xi4>) {
// expected-error @+2 {{op expects regions to end with 'linalg.yield', found 'std.addf'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'linalg.yield'}}
linalg.generic {
- indexing_maps = [ affine_map<(i) -> (i)> ],
- iterator_types = ["parallel"]}
+ indexing_maps = [ affine_map<(i, j) -> (i, j)> ],
+ iterator_types = ["parallel", "parallel"]}
outs(%arg0 : memref<?x?xi4>) {
^bb(%0: i4) :
%1 = std.addf %0, %0: i4
@@ -675,18 +675,18 @@ func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
// -----
#attrs = {
- indexing_maps = [
- affine_map<(i) -> (3 - i)>,
- affine_map<(i) -> (i)>
- ],
- iterator_types = ["parallel"]
+ indexing_maps = [
+ affine_map<(i) -> (3 - i)>,
+ affine_map<(i) -> (i)>
+ ],
+ iterator_types = ["parallel"]
}
func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
// expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
- ^bb0(%a: f32, %b: f32):
- linalg.yield %a : f32
- }
- return
+ ^bb0(%a: f32, %b: f32):
+ linalg.yield %a : f32
+ }
+ return
}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 208e4f2baf5f5..c6e5dc0f25aa8 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -543,3 +543,37 @@ func @pooling_ndhwc_min(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>,
outs(%output: memref<1x2x2x2x1xf32>)
return
}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 2, d2 * 2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+func @conv_interface_wrong_input_indexing_map(
+ %arg0 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // expected-error @+1 {{unexpected input index map for convolutions}}
+ %0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ( {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32): // no predecessors
+ %1 = "std.mulf"(%arg3, %arg4) : (f32, f32) -> f32
+ %2 = "std.addf"(%arg5, %1) : (f32, f32) -> f32
+ "linalg.yield"(%2) : (f32) -> ()
+ }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>, strides = dense<2> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3, d5 + 1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+func @conv_interface_wrong_filter_indexing_map(
+ %arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
+ %0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ( {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32): // no predecessors
+ %1 = "std.mulf"(%arg3, %arg4) : (f32, f32) -> f32
+ %2 = "std.addf"(%arg5, %1) : (f32, f32) -> f32
+ "linalg.yield"(%2) : (f32) -> ()
+ }) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = dense<[2, 1]> : vector<2xi32>, strides = dense<1> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 91af79e9578ce..52d1cac3b310d 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -61,6 +61,7 @@ add_mlir_library(MLIRTestDialect
MLIRDLTI
MLIRIR
MLIRInferTypeOpInterface
+ MLIRLinalg
MLIRLinalgTransforms
MLIRLLVMIR
MLIRPass
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 5aca160c3f183..b887ff77c2dc4 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -17,6 +17,7 @@
#include "TestInterfaces.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a887adbc27055..858ce7df5de0a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -20,6 +20,7 @@ include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "TestInterfaces.td"
def Test_Dialect : Dialect {
@@ -2272,4 +2273,56 @@ def OpCrashShort : TEST_Op<"op_crash_short"> {
def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
+//===----------------------------------------------------------------------===//
+// Test LinalgConvolutionOpInterface.
+//===----------------------------------------------------------------------===//
+
+def TestLinalgConvOpNotLinalgOp : TEST_Op<"conv_op_not_linalg_op", [
+ LinalgConvolutionOpInterface]> {
+ let arguments = (ins
+ AnyType:$image, AnyType:$filter, AnyType:$output);
+ let results = (outs AnyRankedTensor:$result);
+}
+
+def TestLinalgConvOp :
+ TEST_Op<"linalg_conv_op", [AttrSizedOperandSegments,
+ LinalgStructuredInterface, LinalgConvolutionOpInterface]> {
+
+ let arguments = (ins Variadic<AnyType>:$inputs,
+ Variadic<AnyType>:$outputs);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region AnyRegion:$region);
+
+ let assemblyFormat = [{
+ attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+ `outs` `(` $outputs `:` type($outputs) `)`
+ $region (`->` type($results)^)?
+ }];
+
+ let extraClassDeclaration = [{
+ bool hasIndexSemantics() { return false; }
+
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block) {
+ b.create<mlir::linalg::YieldOp>(block.getArguments().back());
+ }
+
+ static std::function<void(mlir::ImplicitLocOpBuilder &b, mlir::Block &block)>
+ getRegionBuilder() {
+ return &regionBuilder;
+ }
+
+ mlir::ArrayAttr iterator_types() {
+ return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
+ }
+
+ mlir::ArrayAttr indexing_maps() {
+ return getOperation()->getAttrOfType<mlir::ArrayAttr>("indexing_maps");
+ }
+
+ std::string getLibraryCallName() {
+ return "";
+ }
+ }];
+}
+
#endif // TEST_OPS
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 7a095f29194d0..72134154a686b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -41,6 +41,7 @@ td_library(
"//mlir:DLTIDialectTdFiles",
"//mlir:DataLayoutInterfacesTdFiles",
"//mlir:InferTypeOpInterfaceTdFiles",
+ "//mlir:LinalgStructuredOpsTdFiles",
"//mlir:OpBaseTdFiles",
"//mlir:SideEffectTdFiles",
],
@@ -210,6 +211,8 @@ cc_library(
"//mlir:IR",
"//mlir:InferTypeOpInterface",
"//mlir:LLVMDialect",
+ "//mlir:LinalgInterfaces",
+ "//mlir:LinalgOps",
"//mlir:Pass",
"//mlir:Reducer",
"//mlir:SideEffects",