[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)
Abhishek Varma
llvmlistbot at llvm.org
Tue Nov 25 00:33:36 PST 2025
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/168362
>From 7b4d4b714b43684a87f5a35dea377402f8262801 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 17 Nov 2025 06:44:55 -0600
Subject: [PATCH 1/2] [Linalg] Add *Conv2D* matchers
-- This commit is the second in a series adding matchers for
linalg.*conv*/*pool* ops. Refer: https://github.com/llvm/llvm-project/pull/163724
-- This commit adds matchers for all variants of the Conv2D convolution ops,
including the depthwise Conv2D ops.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 15 +
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 973 ++++++++++++++++--
.../convolution/roundtrip-convolution.mlir | 155 ++-
3 files changed, 1046 insertions(+), 97 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..b52b93f8cc9b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
CONV_OP_SPECIALIZER(linalg::Conv3DOp);
// -----------------------------
// Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
// -----------------------------
// Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..57593abac7ab0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
//===----------------------------------------------------------------------===//
/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
BlockArgument blockArg = dyn_cast<BlockArgument>(val);
if ((blockArg))
return blockArg;
@@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
Operation *defOp = val.getDefiningOp();
if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
!dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
- !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+ !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+ !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
return nullptr;
}
return dyn_cast<BlockArgument>(defOp->getOperand(0));
}
+/// Matches the zero point offset body of convolution ops, given the addition
+/// and multiplication ops expected in every convolution op. Both operands of
+/// the multiplication op must be subtractions :-
+///   %input - %input_scalar    (block arguments 0 and 2)
+///   %filter - %filter_scalar  (block arguments 1 and 3)
+/// while the addition accumulates into block argument 4. All can be upcast.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+ Block *body) {
+ Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+ if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+ return false;
+ Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+ if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+ return false;
+ BlockArgument inputBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+ BlockArgument inputScalarBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+ BlockArgument filterBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+ BlockArgument filterScalarBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+ BlockArgument outBlockArg =
+ getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+ if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+ !filterScalarBlockArg || !outBlockArg ||
+ inputBlockArg.getOwner() != body ||
+ inputScalarBlockArg.getOwner() != body ||
+ filterBlockArg.getOwner() != body ||
+ filterScalarBlockArg.getOwner() != body ||
+ outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+ inputScalarBlockArg.getArgNumber() != 2 ||
+ filterBlockArg.getArgNumber() != 1 ||
+ filterScalarBlockArg.getArgNumber() != 3 ||
+ outBlockArg.getArgNumber() != 4)
+ return false;
+ return true;
+}
+
/// Utility to match block body for convolution ops.
/// The body is thus expected to yield :-
/// %out + (%lhs * %rhs)
/// where: %lhs, %rhs and %out are block arguments and
/// %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: For zero point offset convolution ops, %lhs and %rhs are instead :-
+/// %input - %input_scalar and %filter - %filter_scalar respectively,
+/// where every operand can have an optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+ bool zeroPointOffset = false) {
Operation *addOp = yieldVal.getDefiningOp();
if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
return false;
@@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
return false;
+ if (zeroPointOffset) {
+ return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+ }
BlockArgument lhsBlockArg =
- getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
BlockArgument rhsBlockArg =
- getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
BlockArgument outBlockArg =
- getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
return false;
BlockArgument lhsArg =
- getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(defOp->getOperand(0));
BlockArgument rhsArg =
- getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(defOp->getOperand(1));
if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
rhsArg.getArgNumber() != 0)
@@ -599,49 +646,45 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
-// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
-// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv3DOp>(op))
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcFhwcOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(3, 1);
- *strides = SmallVector<int64_t>(3, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
- AffineExpr D = getAffineDimExpr(0, context);
+ AffineExpr N = getAffineDimExpr(0, context);
AffineExpr H = getAffineDimExpr(1, context);
AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr d = getAffineDimExpr(3, context);
+ AffineExpr F = getAffineDimExpr(3, context);
AffineExpr h = getAffineDimExpr(4, context);
AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
- // Match: D * stride + d * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
- /*oDim=*/0, (*dilations)[0], (*strides)[0]))
- return false;
// Match: H * stride + h * dilation
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
- /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+ /*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
// Match: W * stride + w * dilation
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
- /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
- H * (*strides)[1] + h * (*dilations)[1],
- W * (*strides)[2] + w * (*dilations)[2]},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{D, H, W}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{F, h, w, c},
+ /*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
// Match body
@@ -651,37 +694,45 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
-// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+ if (isa<linalg::Conv2DNhwcHwcfOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
+ AffineExpr H = getAffineDimExpr(1, context);
+ AffineExpr W = getAffineDimExpr(2, context);
+ AffineExpr F = getAffineDimExpr(3, context);
+ AffineExpr h = getAffineDimExpr(4, context);
+ AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+ return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
- /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
- /*filterMap=*/{C, w},
- /*outputMap=*/{N, C, W}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{h, w, c, F},
+ /*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
// Match body
@@ -691,37 +742,196 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ if (isa<linalg::Conv2DNchwFchwOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
+ AffineExpr F = getAffineDimExpr(1, context);
+ AffineExpr H = getAffineDimExpr(2, context);
+ AffineExpr W = getAffineDimExpr(3, context);
+ AffineExpr C = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+ return false;
+ // Match expected indexing maps
+ if (!convLayoutMatches(
+ {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{F, C, h, w},
+ /*outputMap=*/{N, F, H, W}},
+ indexingMaps, context))
+ return false;
+ // Match body
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
+ MLIRContext *context = op->getContext();
+ AffineExpr N = getAffineDimExpr(0, context);
+ AffineExpr H = getAffineDimExpr(1, context);
+ AffineExpr W = getAffineDimExpr(2, context);
+ AffineExpr F = getAffineDimExpr(3, context);
+ AffineExpr h = getAffineDimExpr(4, context);
+ AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ // First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
+ // Match: W * stride + w * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+ return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
- /*filterMap=*/{w, C},
- /*outputMap=*/{N, W, C}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{F, h, w, c},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, F}},
+ indexingMaps, context))
+ return false;
+ // Match body
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNchwFchwQOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
+ MLIRContext *context = op->getContext();
+ AffineExpr N = getAffineDimExpr(0, context);
+ AffineExpr F = getAffineDimExpr(1, context);
+ AffineExpr H = getAffineDimExpr(2, context);
+ AffineExpr W = getAffineDimExpr(3, context);
+ AffineExpr C = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ // First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ return false;
+ // Match: W * stride + w * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+ return false;
+ // Match expected indexing maps
+ if (!convLayoutMatches(
+ {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{F, C, h, w},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, F, H, W}},
+ indexingMaps, context))
+ return false;
+ // Match body
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (FG, G, C, h, w)>
+// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
+ MLIRContext *context = op->getContext();
+ AffineExpr N = getAffineDimExpr(0, context);
+ AffineExpr G = getAffineDimExpr(1, context);
+ AffineExpr FG = getAffineDimExpr(2, context);
+ AffineExpr H = getAffineDimExpr(3, context);
+ AffineExpr W = getAffineDimExpr(4, context);
+ AffineExpr C = getAffineDimExpr(5, context);
+ AffineExpr h = getAffineDimExpr(6, context);
+ AffineExpr w = getAffineDimExpr(7, context);
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ // First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[0], (*strides)[0]))
+ return false;
+ // Match: W * stride + w * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+ /*oDim=*/4, (*dilations)[1], (*strides)[1]))
+ return false;
+ // Match expected indexing maps
+ if (!convLayoutMatches(
+ {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{FG, G, C, h, w},
+ /*outputMap=*/{N, G, FG, H, W}},
indexingMaps, context))
return false;
// Match body
@@ -731,38 +941,46 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
-// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+// #inputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
+// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+ if (isa<linalg::Conv2DNgchwGfchwOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr CM = getAffineDimExpr(3, context);
- AffineExpr w = getAffineDimExpr(4, context);
+ AffineExpr G = getAffineDimExpr(1, context);
+ AffineExpr FG = getAffineDimExpr(2, context);
+ AffineExpr H = getAffineDimExpr(3, context);
+ AffineExpr W = getAffineDimExpr(4, context);
+ AffineExpr C = getAffineDimExpr(5, context);
+ AffineExpr h = getAffineDimExpr(6, context);
+ AffineExpr w = getAffineDimExpr(7, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[0], (*strides)[0]))
+ return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
- /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+ /*oDim=*/4, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
- /*filterMap=*/{w, C, CM},
- /*outputMap=*/{N, W, C, CM}},
+ {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{G, FG, C, h, w},
+ /*outputMap=*/{N, G, FG, H, W}},
indexingMaps, context))
return false;
// Match body
@@ -772,14 +990,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
-// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
-// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
    return true;
  assert(isaConvolutionOpInterface(op) &&
@@ -791,32 +1010,604 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
  AffineExpr N = getAffineDimExpr(0, context);
  AffineExpr H = getAffineDimExpr(1, context);
  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(3, context);
  AffineExpr h = getAffineDimExpr(4, context);
  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
  ArrayAttr indexingMaps = op.getIndexingMaps();
  // First fetch dilations/strides :-
  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
    return false;
  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
-                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
    return false;
  // Match expected indexing maps
  if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{C, h, w},
-           /*outputMap=*/{N, C, H, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, F},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, F}},
          indexingMaps, context))
    return false;
  // Match body
  Block *body = op.getBlock();
  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
+// #filterMap = affine_map<(N, H, W, G, FG, h, w, c) -> (G, FG, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, G, FG, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H, W, G, FG)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr G = getAffineDimExpr(3, context);
+  AffineExpr FG = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr c = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], G, c},
+           /*filterMap=*/{G, FG, h, w, c},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, G, FG}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
+// #scalarMap = affine_map<(N, G, FG, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr G = getAffineDimExpr(1, context);
+  AffineExpr FG = getAffineDimExpr(2, context);
+  AffineExpr H = getAffineDimExpr(3, context);
+  AffineExpr W = getAffineDimExpr(4, context);
+  AffineExpr C = getAffineDimExpr(5, context);
+  AffineExpr h = getAffineDimExpr(6, context);
+  AffineExpr w = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                  /*oDim=*/4, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{G, FG, C, h, w},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, G, FG, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H + h, W + w, G, C)>
+// #filterMap = affine_map<(N, H, W, G, FG, h, w, C) -> (G, FG, h, w, C)>
+// #outputMap = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H, W, G, FG)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr G = getAffineDimExpr(3, context);
+  AffineExpr FG = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  AffineExpr C = getAffineDimExpr(7, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation (h is filter dim 2)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation (w is filter dim 3)
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], G, C},
+           /*filterMap=*/{G, FG, h, w, C},
+           /*outputMap=*/{N, H, W, G, FG}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
+// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
+// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+                                              SmallVector<int64_t> *dilations,
+                                              SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv3DOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(3, 1);
+  *strides = SmallVector<int64_t>(3, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: D * stride + d * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
+                         H * (*strides)[1] + h * (*dilations)[1],
+                         W * (*strides)[2] + w * (*dilations)[2]},
+           /*filterMap=*/{d, h, w},
+           /*outputMap=*/{D, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
+// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
+           /*filterMap=*/{C, w},
+           /*outputMap=*/{N, C, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C},
+           /*outputMap=*/{N, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
+// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr CM = getAffineDimExpr(3, context);
+  AffineExpr w = getAffineDimExpr(4, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C, CM},
+           /*outputMap=*/{N, W, C, CM}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{C, h, w},
+           /*outputMap=*/{N, C, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
+// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c},
+           /*outputMap=*/{N, H, W, c}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
+// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr cm = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, cm},
+           /*outputMap=*/{N, H, W, c, cm}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
+// #scalarMap = affine_map<(N, H, W, c, h, w) -> ()>
+// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, c}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
+// #scalarMap = affine_map<(N, H, W, c, cm, h, w) -> ()>
+// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr c = getAffineDimExpr(3, context);
+  AffineExpr cm = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides, rejecting the op on mismatch:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, cm},
+           /*inputZpMap=*/{},
+           /*filterZpMap=*/{},
+           /*outputMap=*/{N, H, W, c, cm}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
}
// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 4b2d42a3ae4e0..11a2672d04632 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -4,9 +4,9 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
-// -----------------------------
+//====================
// Convolution ops.
-// -----------------------------
+//====================
func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> {
%0 = linalg.conv_1d
ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
@@ -55,6 +55,97 @@ func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tens
// -----
+func.func @conv_2d_nchw_fchw(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nchw_fchw
+    {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>}
+    ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nchw_fchw
+// CHECK: linalg.conv_2d_nchw_fchw
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.conv_2d_nchw_fchw_q
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+    outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @conv_2d_nchw_fchw_q
+// CHECK: linalg.conv_2d_nchw_fchw_q
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @conv_2d_ngchw_fgchw
+// CHECK: linalg.conv_2d_ngchw_fgchw
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_2d_ngchw_gfchw
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>)
+    outs (%output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+// CHECK: @conv_2d_ngchw_gfchw
+// CHECK: linalg.conv_2d_ngchw_gfchw
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_2d_ngchw_gfchw_q
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32)
+    outs (%output: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+// CHECK: @conv_2d_ngchw_gfchw_q
+// CHECK: linalg.conv_2d_ngchw_gfchw_q
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf_q
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+    outs(%output : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwc_hwcf_q
+// CHECK: linalg.conv_2d_nhwc_hwcf_q
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwgc_gfhwc_q
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, i32, i32)
+    outs(%output : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwgc_gfhwc_q
+// CHECK: linalg.conv_2d_nhwgc_gfhwc_q
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = linalg.conv_3d
ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
@@ -66,9 +157,9 @@ func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out :
// -----
-// -----------------------------
+//====================
// Depthwise Convolution ops.
-// -----------------------------
+//====================
func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = linalg.depthwise_conv_1d_ncw_cw
{dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
@@ -121,6 +212,58 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tens
// -----
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwc
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %res = linalg.depthwise_conv_2d_nhwc_hwc_q
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32)
+    outs(%output : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %res : tensor<?x?x?x?xi32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwc_q
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwcm
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x?xi8>, %arg2: tensor<?x?x?x?x?xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q
+    {dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>}
+    ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+    outs(%arg2 : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwcm_q
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm_q
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
%0 = linalg.depthwise_conv_3d_ndhwc_dhwcm
{dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
@@ -134,9 +277,9 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter:
// -----
-// -----------------------------
+//====================
// Pooling ops.
-// -----------------------------
+//====================
func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%0 = linalg.pooling_nhwc_max
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
>From ba9e125e29b11b742c43847e310a2af9e25e91d6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 25 Nov 2025 02:32:39 -0600
Subject: [PATCH 2/2] Review comment v1.0 + Add missing tests + refactor
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 14 +-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 266 +++++++++---------
.../convolution/roundtrip-convolution.mlir | 39 +++
3 files changed, 179 insertions(+), 140 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index b52b93f8cc9b9..ae71e822afee7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -271,7 +271,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
#define CONV_OP_SPECIALIZER(ConvOpTy) \
if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
- strides); \
+ strides);
// -----------------------------
// Convolution ops.
// -----------------------------
@@ -279,17 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
CONV_OP_SPECIALIZER(linalg::Conv2DOp);
- CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
- CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
- CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
- CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
- CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
CONV_OP_SPECIALIZER(linalg::Conv3DOp);
// -----------------------------
// Depthwise Convolution ops.
@@ -299,8 +299,8 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
- CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
// -----------------------------
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 57593abac7ab0..0491f1b332d37 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -647,13 +647,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
}
// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+ if (isa<linalg::Conv2DNhwcHwcfOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -683,7 +683,7 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
if (!convLayoutMatches(
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1], c},
- /*filterMap=*/{F, h, w, c},
+ /*filterMap=*/{h, w, c, F},
/*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
@@ -696,12 +696,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+ if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -720,11 +721,11 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
// Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
@@ -732,6 +733,8 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1], c},
/*filterMap=*/{h, w, c, F},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
/*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
@@ -739,17 +742,17 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
}
-// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
-// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
-// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNchwFchwOp>(op))
+ if (isa<linalg::Conv2DNhwcFhwcOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -759,28 +762,28 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
*strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr F = getAffineDimExpr(1, context);
- AffineExpr H = getAffineDimExpr(2, context);
- AffineExpr W = getAffineDimExpr(3, context);
- AffineExpr C = getAffineDimExpr(4, context);
- AffineExpr h = getAffineDimExpr(5, context);
- AffineExpr w = getAffineDimExpr(6, context);
+ AffineExpr H = getAffineDimExpr(1, context);
+ AffineExpr W = getAffineDimExpr(2, context);
+ AffineExpr F = getAffineDimExpr(3, context);
+ AffineExpr h = getAffineDimExpr(4, context);
+ AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
// Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
- /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
- /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
- W * (*strides)[1] + w * (*dilations)[1]},
- /*filterMap=*/{F, C, h, w},
- /*outputMap=*/{N, F, H, W}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{F, h, w, c},
+ /*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
// Match body
@@ -841,6 +844,54 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
}
+// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNchwFchwOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
+ MLIRContext *context = op->getContext();
+ AffineExpr N = getAffineDimExpr(0, context);
+ AffineExpr F = getAffineDimExpr(1, context);
+ AffineExpr H = getAffineDimExpr(2, context);
+ AffineExpr W = getAffineDimExpr(3, context);
+ AffineExpr C = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ // First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ return false;
+ // Match: W * stride + w * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+ return false;
+ // Match expected indexing maps
+ if (!convLayoutMatches(
+ {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{F, C, h, w},
+ /*outputMap=*/{N, F, H, W}},
+ indexingMaps, context))
+ return false;
+ // Match body
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
// #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
@@ -893,13 +944,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
}
// #inputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
-// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (FG, G, C, h, w)>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+ if (isa<linalg::Conv2DNgchwGfchwOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -930,7 +981,7 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
if (!convLayoutMatches(
{/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1]},
- /*filterMap=*/{FG, G, C, h, w},
+ /*filterMap=*/{G, FG, C, h, w},
/*outputMap=*/{N, G, FG, H, W}},
indexingMaps, context))
return false;
@@ -943,12 +994,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
// #inputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
+// #scalarMap = affine_map<(N, G, FG, H, W, C, h, w) -> ()>
// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+ if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -980,6 +1032,8 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
{/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1]},
/*filterMap=*/{G, FG, C, h, w},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
/*outputMap=*/{N, G, FG, H, W}},
indexingMaps, context))
return false;
@@ -987,18 +1041,17 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
}
-// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
-// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
-// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+// #inputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
+// #filterMap = affine_map<(N, H, W, G, FG, h, w, c) -> (G, FG, h, w, c)>
+// #outputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H, W, G, FG)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+ if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -1010,35 +1063,34 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
AffineExpr N = getAffineDimExpr(0, context);
AffineExpr H = getAffineDimExpr(1, context);
AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr F = getAffineDimExpr(3, context);
- AffineExpr h = getAffineDimExpr(4, context);
- AffineExpr w = getAffineDimExpr(5, context);
- AffineExpr c = getAffineDimExpr(6, context);
+ AffineExpr G = getAffineDimExpr(3, context);
+ AffineExpr FG = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
+ AffineExpr c = getAffineDimExpr(7, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
// Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
- W * (*strides)[1] + w * (*dilations)[1], c},
- /*filterMap=*/{h, w, c, F},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, F}},
+ W * (*strides)[1] + w * (*dilations)[1], G, c},
+ /*filterMap=*/{G, FG, h, w, c},
+ /*outputMap=*/{N, H, W, G, FG}},
indexingMaps, context))
return false;
// Match body
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+ return bodyMatcherForConvolutionOps(yieldVal, body);
}
// #inputMap = affine_map<(N, H, W, G, FG, h, w, c) -> (N, H + h, W + w, G, c)>
@@ -1094,14 +1146,13 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
}
// #inputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, C, H + h, W + w)>
-// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (G, FG, C, h, w)>
-// #scalarMap = affine_map<(N, G, FG, H, W, C, h, w) -> ()>
+// #filterMap = affine_map<(N, G, FG, H, W, C, h, w) -> (FG, G, C, h, w)>
// #outputMap = affine_map<(N, G, FG, H, W, C, h, w) -> (N, G, FG, H, W)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+ if (isa<linalg::Conv2DNgchwFgchwOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -1132,9 +1183,7 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
if (!convLayoutMatches(
{/*inputMap=*/{N, G, C, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1]},
- /*filterMap=*/{G, FG, C, h, w},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
+ /*filterMap=*/{FG, G, C, h, w},
/*outputMap=*/{N, G, FG, H, W}},
indexingMaps, context))
return false;
@@ -1142,55 +1191,6 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
-}
-
-// #inputMap = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H + h, W + w, G, C)>
-// #filterMap = affine_map<(N, H, W, G, FG, h, w, C) -> (G, FG, h, w, C)>
-// #outputMap = affine_map<(N, H, W, G, FG, h, w, C) -> (N, H, W, G, FG)>
-template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
- return true;
-
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
-
- *dilations = SmallVector<int64_t>(2, 1);
- *strides = SmallVector<int64_t>(2, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr H = getAffineDimExpr(1, context);
- AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr G = getAffineDimExpr(3, context);
- AffineExpr FG = getAffineDimExpr(4, context);
- AffineExpr h = getAffineDimExpr(5, context);
- AffineExpr w = getAffineDimExpr(6, context);
- AffineExpr C = getAffineDimExpr(7, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
- /*oDim=*/1, (*dilations)[0], (*strides)[0]))
- return false;
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
- /*oDim=*/2, (*dilations)[1], (*strides)[1]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
- W * (*strides)[1] + w * (*dilations)[1], G, C},
- /*filterMap=*/{G, FG, h, w, C},
- /*outputMap=*/{N, H, W, G, FG}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
return bodyMatcherForConvolutionOps(yieldVal, body);
}
@@ -1461,14 +1461,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
-// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
+// #inputMap = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
+// #scalarMap = affine_map<(N, H, W, c, h, w) -> ()>
+// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+ if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -1481,9 +1482,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
AffineExpr H = getAffineDimExpr(1, context);
AffineExpr W = getAffineDimExpr(2, context);
AffineExpr c = getAffineDimExpr(3, context);
- AffineExpr cm = getAffineDimExpr(4, context);
- AffineExpr h = getAffineDimExpr(5, context);
- AffineExpr w = getAffineDimExpr(6, context);
+ AffineExpr h = getAffineDimExpr(4, context);
+ AffineExpr w = getAffineDimExpr(5, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
// Match: H * stride + h * dilation
@@ -1498,26 +1498,27 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
if (!convLayoutMatches(
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1], c},
- /*filterMap=*/{h, w, c, cm},
- /*outputMap=*/{N, H, W, c, cm}},
+ /*filterMap=*/{h, w, c},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, c}},
indexingMaps, context))
return false;
// Match body
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
}
-// #inputMap = affine_map<(N, H, W, c, h, w) -> (N, H + h, W + w, c)>
-// #filterMap = affine_map<(N, H, W, c, h, w) -> (h, w, c)>
-// #filterMap = affine_map<(N, H, W, c, h, w) -> ()>
-// #outputMap = affine_map<(N, H, W, c, h, w) -> (N, H, W, c)>
+// #inputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, c, cm, h, w) -> (h, w, c, cm)>
+// #outputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H, W, c, cm)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+ if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
@@ -1530,8 +1531,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
AffineExpr H = getAffineDimExpr(1, context);
AffineExpr W = getAffineDimExpr(2, context);
AffineExpr c = getAffineDimExpr(3, context);
- AffineExpr h = getAffineDimExpr(4, context);
- AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr cm = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
// Match: H * stride + h * dilation
@@ -1546,17 +1548,15 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
if (!convLayoutMatches(
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
W * (*strides)[1] + w * (*dilations)[1], c},
- /*filterMap=*/{h, w, c},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, c}},
+ /*filterMap=*/{h, w, c, cm},
+ /*outputMap=*/{N, H, W, c, cm}},
indexingMaps, context))
return false;
// Match body
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+ return bodyMatcherForConvolutionOps(yieldVal, body);
}
// #inputMap = affine_map<(N, H, W, c, cm, h, w) -> (N, H + h, W + w, c)>
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 11a2672d04632..f602668af7daa 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -81,6 +81,32 @@ func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?
// -----
+func.func @conv_2d_nhwc_fhwc(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>}
+ ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwc_fhwc
+// CHECK: linalg.conv_2d_nhwc_fhwc
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+ %0 = linalg.conv_2d_nhwc_fhwc_q
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+ outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @conv_2d_nhwc_fhwc_q
+// CHECK: linalg.conv_2d_nhwc_fhwc_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
%0 = linalg.conv_2d_ngchw_fgchw
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
@@ -133,6 +159,19 @@ func.func @conv_2d_nhwc_hwcf_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x
// -----
+func.func @conv_2d_nhwgc_gfhwc(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwgc_gfhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %filter : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs(%output : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwgc_gfhwc
+// CHECK: linalg.conv_2d_nhwgc_gfhwc
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
%0 = linalg.conv_2d_nhwgc_gfhwc_q
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
More information about the Mlir-commits
mailing list