[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 17 04:55:37 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

<details>
<summary>Changes</summary>

-- This commit is the third in a series adding matchers
   for linalg.*conv*/*pool* ops. Refer: https://github.com/llvm/llvm-project/pull/163724
-- This commit adds matchers for all variants of the Conv2D
   convolution ops.

Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>

---

Patch is 62.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168362.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+15) 
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+882-91) 
- (modified) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+149-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..b52b93f8cc9b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
   CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
   CONV_OP_SPECIALIZER(linalg::Conv3DOp);
   // -----------------------------
   // Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
   // Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..57593abac7ab0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
   BlockArgument blockArg = dyn_cast<BlockArgument>(val);
   if ((blockArg))
     return blockArg;
@@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
   Operation *defOp = val.getDefiningOp();
   if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
       !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
-      !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
     return nullptr;
   }
   return dyn_cast<BlockArgument>(defOp->getOperand(0));
 }
 
+/// Utility function to match the zero point offset body of convolution ops.
+/// It takes as input the addition op and the multiplication op expected in
+/// every convolution op, and matches the following for both operands of the
+/// multiplication op :-
+///     %a - %b
+///   where %a and %b can each have an optional upcast operation.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+                                           Block *body) {
+  Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+    return false;
+  Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+    return false;
+  BlockArgument inputBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+  BlockArgument inputScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+  BlockArgument filterBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+  BlockArgument filterScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+  if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+      !filterScalarBlockArg || !outBlockArg ||
+      inputBlockArg.getOwner() != body ||
+      inputScalarBlockArg.getOwner() != body ||
+      filterBlockArg.getOwner() != body ||
+      filterScalarBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+      inputScalarBlockArg.getArgNumber() != 2 ||
+      filterBlockArg.getArgNumber() != 1 ||
+      filterScalarBlockArg.getArgNumber() != 3 ||
+      outBlockArg.getArgNumber() != 4)
+    return false;
+  return true;
+}
+
 /// Utility to match block body for convolution ops.
 /// The body is thus expected to yield :-
 ///     %out + (%lhs * %rhs)
 ///   where: %lhs, %rhs and %out are block arguments and
 ///          %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
+///       %input - %input_scalar
+///          where, %input_scalar can have optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+                                         bool zeroPointOffset = false) {
   Operation *addOp = yieldVal.getDefiningOp();
   if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
     return false;
@@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
   if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
     return false;
 
+  if (zeroPointOffset) {
+    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+  }
   BlockArgument lhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
   BlockArgument rhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
   BlockArgument outBlockArg =
-      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
       lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
       outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
     return false;
 
   BlockArgument lhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(0));
   BlockArgument rhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(1));
   if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
       rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
       rhsArg.getArgNumber() != 0)
@@ -599,49 +646,45 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
-// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
-// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
-                                              SmallVector<int64_t> *dilations,
-                                              SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DOp>(op))
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(3, 1);
-  *strides = SmallVector<int64_t>(3, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
-  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr N = getAffineDimExpr(0, context);
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(3, context);
   AffineExpr h = getAffineDimExpr(4, context);
   AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
-  // Match: D * stride + d * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
   // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
-                         H * (*strides)[1] + h * (*dilations)[1],
-                         W * (*strides)[2] + w * (*dilations)[2]},
-           /*filterMap=*/{d, h, w},
-           /*outputMap=*/{D, H, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -651,37 +694,45 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
-// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{C, w},
-           /*outputMap=*/{N, C, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, F},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -691,37 +742,196 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C},
-           /*outputMap=*/{N, W, C}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, H, W, F}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match bo...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/168362


More information about the Mlir-commits mailing list