[Mlir-commits] [mlir] [Linalg] Introduce ConvMatchBuilder + refactor Conv matchers (PR #169704)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 26 10:06:59 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

<details>
<summary>Changes</summary>

-- This commit is a follow-up, and the third in the series, adding
   matchers for conv/pool ops. Refer to: https://github.com/llvm/llvm-project/pull/163724
-- It introduces the ConvMatcherBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D thread](https://github.com/llvm/llvm-project/pull/168362#issuecomment-3575972133) for further context.

Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>

---

Patch is 35.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169704.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+289-454) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..f7b86e7c385de 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -416,10 +416,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
   return false;
 }
 
-// ---------------------------------------------
-// Matchers for specific convolution operation.
-// ---------------------------------------------
-
 /// Returns true if the given indexing maps matches with the expected indexing
 /// maps.
 static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
@@ -434,6 +430,92 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                           })));
 }
 
+/// Enum of all kinds of Pooling Op's type.
+enum PoolingType {
+  NONE,
+  MAX_SIGNED,
+  MAX_UNSIGNED,
+  MIN_SIGNED,
+  MIN_UNSIGNED,
+  SUM
+};
+
+/// Helper class for building convolution op matchers with minimal boilerplate.
+/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
+/// as Pooling ops.
+class ConvMatcherBuilder {
+  LinalgOp op;
+  MLIRContext *ctx;
+  SmallVector<int64_t> *dilations, *strides;
+  ArrayAttr indexingMaps;
+  PoolingType poolingType;
+  bool matched = true;
+
+public:
+  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
+                     SmallVector<int64_t> *s,
+                     PoolingType poolingType = PoolingType::NONE)
+      : op(op), ctx(op->getContext()), dilations(d), strides(s),
+        indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
+    *dilations = SmallVector<int64_t>(spatialRank, 1);
+    *strides = SmallVector<int64_t>(spatialRank, 1);
+  }
+
+  /// Get affine dimension expression for dimension i.
+  AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
+
+  /// Build strided expression: base * stride[idx] + kernel * dilation[idx]
+  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
+    return base * (*strides)[idx] + kernel * (*dilations)[idx];
+  }
+
+  /// Match stride/dilation pattern for a spatial dimension.
+  /// Returns *this for method chaining.
+  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
+                                  unsigned idx) {
+    if (matched) {
+      matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+                                           (*dilations)[idx], (*strides)[idx]);
+    }
+    return *this;
+  }
+
+  /// Match expected indexing maps layout.
+  /// Returns *this for method chaining.
+  ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+    if (matched)
+      matched = convLayoutMatches(maps, indexingMaps, ctx);
+    return *this;
+  }
+
+  /// Match body pattern. This should be called last.
+  bool matchBody() {
+    if (!matched)
+      return false;
+    Block *body = op.getBlock();
+    auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+    switch (poolingType) {
+    case PoolingType::NONE:
+      return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
+    case PoolingType::MAX_SIGNED:
+      return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::MAX_UNSIGNED:
+      return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::MIN_SIGNED:
+      return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::MIN_UNSIGNED:
+      return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
+    case PoolingType::SUM:
+      return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
+    }
+    return false;
+  }
+};
+
+// ---------------------------------------------
+// Matchers for specific convolution operation.
+// ---------------------------------------------
+
 // #inputMap = affine_map<(W, w) -> (W + w)>
 // #filterMap = affine_map<(W, w) -> (w)>
 // #outputMap = affine_map<(W, w) -> (W)>
@@ -447,29 +529,15 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr W = getAffineDimExpr(0, context);
-  AffineExpr w = getAffineDimExpr(1, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{w},
-           /*outputMap=*/{W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr W = m.dim(0);
+  AffineExpr w = m.dim(1);
+
+  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
+                   /*filterMap=*/{w},
+                   /*outputMap=*/{W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
@@ -485,32 +553,18 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr F = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  AffineExpr c = getAffineDimExpr(4, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c},
-           /*filterMap=*/{w, c, F},
-           /*outputMap=*/{N, W, F}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr F = m.dim(2);
+  AffineExpr w = m.dim(3);
+  AffineExpr c = m.dim(4);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+                   /*filterMap=*/{w, c, F},
+                   /*outputMap=*/{N, W, F}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
@@ -526,32 +580,18 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr F = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr c = getAffineDimExpr(3, context);
-  AffineExpr w = getAffineDimExpr(4, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{F, c, w},
-           /*outputMap=*/{N, F, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr F = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr c = m.dim(3);
+  AffineExpr w = m.dim(4);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+                   /*filterMap=*/{F, c, w},
+                   /*outputMap=*/{N, F, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(H, W, h, w) -> (H + h, W + w)>
@@ -567,36 +607,18 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr H = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr h = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0],
-                         W * (*strides)[1] + w * (*dilations)[1]},
-           /*filterMap=*/{h, w},
-           /*outputMap=*/{H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+  AffineExpr H = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr h = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+      .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+                   /*filterMap=*/{h, w},
+                   /*outputMap=*/{H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
@@ -612,43 +634,22 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(3, 1);
-  *strides = SmallVector<int64_t>(3, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr D = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr d = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: D * stride + d * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
-    return false;
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
-                         H * (*strides)[1] + h * (*dilations)[1],
-                         W * (*strides)[2] + w * (*dilations)[2]},
-           /*filterMap=*/{d, h, w},
-           /*outputMap=*/{D, H, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+  AffineExpr D = m.dim(0);
+  AffineExpr H = m.dim(1);
+  AffineExpr W = m.dim(2);
+  AffineExpr d = m.dim(3);
+  AffineExpr h = m.dim(4);
+  AffineExpr w = m.dim(5);
+
+  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
+      .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+                                 m.strided(W, w, 2)},
+                   /*filterMap=*/{d, h, w},
+                   /*outputMap=*/{D, H, W}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
@@ -664,31 +665,17 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{C, w},
-           /*outputMap=*/{N, C, W}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                   /*filterMap=*/{C, w},
+                   /*outputMap=*/{N, C, W}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
@@ -704,31 +691,17 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C},
-           /*outputMap=*/{N, W, C}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr w = m.dim(3);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                   /*filterMap=*/{w, C},
+                   /*outputMap=*/{N, W, C}})
+      .matchBody();
 }
 
 // #inputMap  = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
@@ -744,32 +717,18 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr CM = getAffineDimExpr(3, context);
-  AffineExpr w = getAffineDimExpr(4, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
-    return false;
-  // Match expected indexing maps
-  if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C, CM},
-           /*outputMap=*/{N, W, C, CM}},
-          indexingMaps, context))
-    return false;
-  // Match body
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-  return bodyMatcherForConvolutionOps(yieldVal, body);
+  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+  AffineExpr N = m.dim(0);
+  AffineExpr W = m.dim(1);
+  AffineExpr C = m.dim(2);
+  AffineExpr CM = m.dim(3);
+  AffineExpr w = m.dim(4);
+
+  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                   /*filterMap=*/{w, C, CM},
+                   /*outputMap=*/{N, W, C, CM}})
+      .matchBody();
 }
 
 // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
@@ -785,38 +744,20 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(2, 1);
-  *strides = SmallVector<int64_t>(2, 1);
-  MLIRContext *context = op->getContext();
-  AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr H = getAffineDimExpr(1, context);
-  AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr C = getAffineDimExpr(3, context);
-  AffineExpr h = getAffineDimExpr(4, context);
-  AffineExpr w = getAffineDimExpr(5, context);
-  ArrayAttr indexingMaps = op.getIndexingMaps();
-  // First fetch dilations/strides :-
-  // Match: H * stride + h * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/169704


More information about the Mlir-commits mailing list