[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Nov 25 06:39:24 PST 2025
hanhanW wrote:
@Abhishek-Varma I've been thinking about how to make the implementation simpler. Below is a plan that I brainstormed with an AI agent, and I think we can do this refactoring first. Then this PR will be much easier to review.
The main benefit is that the code itself documents the expected indexing maps, and, in my view, it simplifies the code a lot. Can you take a look and let me know what you think?
---
**Problem:** Each template specialization follows the exact same pattern (~50-60 lines each), leading to ~800+ lines of highly repetitive code across the 15 new matchers. This pattern is consistent across ALL conv types (1D, 2D, 3D, regular, depthwise, quantized).
The repeated boilerplate in each matcher:
```cpp
// Abridged shape shared by every hand-written matcher today. "..." marks
// elided arguments and statements, so this sketch is not meant to compile.
template <>
bool isaConvolutionOpOfType<linalg::SomeConvOp>(...) {
if (isa<linalg::SomeConvOp>(op)) return true; // 1. type check: already the named op, nothing to infer
assert(isaConvolutionOpInterface(op) && ...); // 2. assert the caller's precondition
*dilations = SmallVector<int64_t>(N, 1); // 3. init out-params to N (spatial rank) unit entries
*strides = SmallVector<int64_t>(N, 1);
MLIRContext *context = op->getContext();
AffineExpr A = getAffineDimExpr(0, context); // 4. one AffineExpr per loop dim (5-8 lines)
AffineExpr B = getAffineDimExpr(1, context);
// ... more dims ...
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!matchConvDimAddExprPattern(...)) // 5. infer dilation/stride for one spatial dim
return false;
// ... more stride matching ... (one call per spatial dim)
if (!convLayoutMatches({...}, indexingMaps, context)) // 6. compare against the full expected indexing maps
return false;
Block *body = op.getBlock(); // 7. body match on the value yielded by the terminator
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
return bodyMatcherForConvolutionOps(yieldVal, body);
}
```
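**Suggested Solution:** Introduce a `ConvMatcherBuilder` helper class to reduce each matcher from ~50 lines to ~12 lines while improving readability. Each step in the chain becomes a no-op once an earlier step has failed, so no intermediate `if (!...) return false;` checks are needed. One subtlety: `strided()` reads the stride/dilation values that `matchStride()` writes, so it must be evaluated after the corresponding `matchStride()` call. In a single fluent expression this holds only because C++17 sequences the object expression of a member call before its arguments.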
```cpp
/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants.
///
/// Intended call sequence:
///   1. Construct the builder.
///   2. Call `matchStride` once per spatial dimension.
///   3. Call `expectMaps`.
///   4. Finish with `matchBody`.
/// Once any step fails, the remaining steps are skipped and `matchBody`
/// returns false. Callers therefore need no intermediate
/// `if (!...) return false;` checks.
///
/// Ordering contract: `strided(base, kernel, idx)` reads `(*strides)[idx]` and
/// `(*dilations)[idx]` at the moment it is called. It must therefore be
/// evaluated after `matchStride(..., idx)` has stored the inferred values.
/// A single fluent expression `m.matchStride(...).expectMaps({m.strided(...)})`
/// satisfies this only because C++17 sequences the object expression of a
/// member call before its arguments. Issuing the `matchStride` calls as their
/// own statement makes the dependency explicit.
///
/// After a failed match, `*dilations` / `*strides` may be partially filled and
/// should not be relied upon.
class ConvMatcherBuilder {
  LinalgOp op;
  MLIRContext *ctx;
  /// Caller-owned out-params, written by the constructor and `matchStride`.
  SmallVector<int64_t> *dilations;
  SmallVector<int64_t> *strides;
  ArrayAttr indexingMaps;
  /// Cleared by the first failing step; later steps then become no-ops.
  bool matched = true;

public:
  /// Resets `*d` and `*s` to `spatialRank` unit entries. Both pointers must be
  /// non-null: they are dereferenced here and by later steps.
  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank,
                     SmallVector<int64_t> *d, SmallVector<int64_t> *s)
      : op(op), ctx(op->getContext()), dilations(d), strides(s),
        indexingMaps(op.getIndexingMaps()) {
    assert(dilations && strides && "expected non-null dilations and strides");
    *dilations = SmallVector<int64_t>(spatialRank, 1);
    *strides = SmallVector<int64_t>(spatialRank, 1);
  }

  /// Get affine dimension expression for dimension i.
  AffineExpr dim(unsigned i) const { return getAffineDimExpr(i, ctx); }

  /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
  /// Reads the current stride/dilation for `idx`; call it only after
  /// `matchStride(..., idx)` (see the ordering contract above).
  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) const {
    return base * (*strides)[idx] + kernel * (*dilations)[idx];
  }

  /// Match stride/dilation pattern for spatial dimension `idx`. `iDim`,
  /// `fDim` and `oDim` are result positions in the input, filter and output
  /// indexing maps. On success the inferred values are stored into
  /// `(*dilations)[idx]` / `(*strides)[idx]`.
  /// Returns *this for method chaining.
  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
                                  unsigned idx) {
    if (matched) {
      matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
                                           (*dilations)[idx], (*strides)[idx]);
    }
    return *this;
  }

  /// Match expected indexing maps layout, one expression list per operand.
  /// Returns *this for method chaining.
  ConvMatcherBuilder &
  expectMaps(std::initializer_list<SmallVector<AffineExpr>> maps) {
    if (matched)
      matched = convLayoutMatches(maps, indexingMaps, ctx);
    return *this;
  }

  /// Match body pattern. This should be called last, and its result is the
  /// overall verdict of the matcher, so it must not be ignored.
  [[nodiscard]] bool matchBody(bool zeroPointOffset = false) {
    if (!matched)
      return false;
    Block *body = op.getBlock();
    auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
    return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
                                        zeroPointOffset);
  }
};
```
**Example usage - Conv3DOp (before: ~55 lines, after: ~12 lines):**
```cpp
template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(
    LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
  // Already the named op: no structural matching is needed.
  if (isa<linalg::Conv3DOp>(op))
    return true;
  assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  // Dims 0-2 index the output spatial dims; dims 3-5 the filter window.
  auto od = m.dim(0);
  auto oh = m.dim(1);
  auto ow = m.dim(2);
  auto kd = m.dim(3);
  auto kh = m.dim(4);
  auto kw = m.dim(5);

  // Infer stride/dilation first, as a separate statement, so the strided()
  // calls below read the inferred values.
  m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2);

  return m
      .expectMaps({/*in=*/{m.strided(od, kd, 0), m.strided(oh, kh, 1),
                           m.strided(ow, kw, 2)},
                   /*filter=*/{kd, kh, kw},
                   /*out=*/{od, oh, ow}})
      .matchBody();
}
```
**Example - DepthwiseConv2DNhwcHwcQOp (quantized, before: ~55 lines, after: ~14 lines):**
```cpp
template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
auto N = m.dim(0), H = m.dim(1), W = m.dim(2), c = m.dim(3);
auto h = m.dim(4), w = m.dim(5);
return m.matchStride(1, 0, 1, 0)
.matchStride(2, 1, 2, 1)
.expectMaps({/*in=*/{N, m.strided(H,h,0), m.strided(W,w,1), c},
/*filter=*/{h, w, c},
/*inputScalar=*/{}, /*filterScalar=*/{},
/*out=*/{N, H, W, c}})
.matchBody(/*zeroPointOffset=*/true);
}
```
**Example - Conv2DNgchwGfchwOp (grouped conv, before: ~55 lines, after: ~14 lines):**
```cpp
template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
auto N = m.dim(0), G = m.dim(1), FG = m.dim(2);
auto H = m.dim(3), W = m.dim(4), C = m.dim(5);
auto h = m.dim(6), w = m.dim(7);
return m.matchStride(3, 3, 3, 0)
.matchStride(4, 4, 4, 1)
.expectMaps({/*in=*/{N, G, C, m.strided(H,h,0), m.strided(W,w,1)},
/*filter=*/{G, FG, C, h, w},
/*out=*/{N, G, FG, H, W}})
.matchBody();
}
```
**Benefits:**
| Metric | Before | After |
|--------|--------|-------|
| Lines per matcher | ~50-60 | ~12-15 |
| Total lines for 15 new ops | ~800 | ~200 |
| Readability | Low (wall of boilerplate) | High (declarative layout visible) |
| Error-prone copy-paste | High | Low |
| Works for all conv types | N/A | Yes (1D, 2D, 3D, regular, depthwise, quantized) |
https://github.com/llvm/llvm-project/pull/168362
More information about the Mlir-commits mailing list