[Mlir-commits] [mlir] [NFC][Linalg] Follow-up on ConvMatchBuilder (PR #170080)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 30 23:33:10 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg
Author: Abhishek Varma (Abhishek-Varma)
<details>
<summary>Changes</summary>
-- This commit addresses [follow-up review comments on PR #169704](https://github.com/llvm/llvm-project/pull/169704#pullrequestreview-3521785548).
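-- Contains only NFC nit/minor changes: converts `PoolingType` to an `enum class` with CamelCase enumerators, renames `ConvMatcherBuilder::expectMaps()` to `matchMaps()`, switches the match steps to `matched &= ...`, and documents how `ConvMatcherBuilder` is meant to be used.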
Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>
---
Full diff: https://github.com/llvm/llvm-project/pull/170080.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+85-71)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e85a2ab26bd32..01e6e1e248658 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
})));
}
-/// Enum of all kinds of Pooling Op's type.
-enum PoolingType {
- NONE,
- MAX_SIGNED,
- MAX_UNSIGNED,
- MIN_SIGNED,
- MIN_UNSIGNED,
- SUM
+/// Enum representing pooling operation types used by ConvMatcherBuilder.
+enum class PoolingType {
+ None,
+ MaxSigned,
+ MaxUnsigned,
+ MinSigned,
+ MinUnsigned,
+ Sum
};
/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
/// as Pooling ops.
+///
+/// Usage: Create an instance with the op, spatial rank, and output pointers for
+/// extracted dilations/strides. Then chain matchStride() calls for each spatial
+/// dimension, followed by matchMaps() to verify indexing maps, and finally
+/// matchBody() to verify the operation body pattern.
+///
+/// The `matched` flag starts as `true` and is set to `false` if any match step
+/// fails. This allows chaining multiple match calls; once any match fails, all
+/// subsequent calls become no-ops and the final result is `false`.
+///
+/// The `dilations` and `strides` pointers are output parameters that get
+/// populated with the extracted dilation and stride values from the operation's
+/// indexing maps during matchStride() calls. These values are initially set to
+/// 1 for each spatial dimension and updated as patterns are matched.
class ConvMatcherBuilder {
LinalgOp op;
MLIRContext *ctx;
@@ -454,7 +468,7 @@ class ConvMatcherBuilder {
public:
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
SmallVector<int64_t> *s,
- PoolingType poolingType = PoolingType::NONE)
+ PoolingType poolingType = PoolingType::None)
: op(op), ctx(op->getContext()), dilations(d), strides(s),
indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
*dilations = SmallVector<int64_t>(spatialRank, 1);
@@ -474,16 +488,16 @@ class ConvMatcherBuilder {
ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
unsigned idx) {
if (matched) {
- matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
- (*dilations)[idx], (*strides)[idx]);
+ matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+ (*dilations)[idx], (*strides)[idx]);
}
return *this;
}
/// Match expected indexing maps layout. Returns *this for method chaining.
- ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+ ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
if (matched)
- matched = convLayoutMatches(maps, indexingMaps, ctx);
+ matched &= convLayoutMatches(maps, indexingMaps, ctx);
return *this;
}
@@ -494,17 +508,17 @@ class ConvMatcherBuilder {
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
switch (poolingType) {
- case PoolingType::NONE:
+ case PoolingType::None:
return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
- case PoolingType::MAX_SIGNED:
+ case PoolingType::MaxSigned:
return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::MAX_UNSIGNED:
+ case PoolingType::MaxUnsigned:
return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::MIN_SIGNED:
+ case PoolingType::MinSigned:
return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::MIN_UNSIGNED:
+ case PoolingType::MinUnsigned:
return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::SUM:
+ case PoolingType::Sum:
return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
}
return false;
@@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
AffineExpr w = m.dim(1);
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
- .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
- /*filterMap=*/{w},
- /*outputMap=*/{W}})
+ .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{W}})
.matchBody();
}
@@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
AffineExpr c = m.dim(4);
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
- /*filterMap=*/{w, c, F},
- /*outputMap=*/{N, W, F}})
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+ /*filterMap=*/{w, c, F},
+ /*outputMap=*/{N, W, F}})
.matchBody();
}
@@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
AffineExpr w = m.dim(4);
return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
- /*filterMap=*/{F, c, w},
- /*outputMap=*/{N, F, W}})
+ .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+ /*filterMap=*/{F, c, w},
+ /*outputMap=*/{N, F, W}})
.matchBody();
}
@@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
- .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{h, w},
- /*outputMap=*/{H, W}})
+ .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{H, W}})
.matchBody();
}
@@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
- .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2)},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{D, H, W}})
+ .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2)},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{D, H, W}})
.matchBody();
}
@@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
AffineExpr w = m.dim(3);
return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
- /*filterMap=*/{C, w},
- /*outputMap=*/{N, C, W}})
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{C, w},
+ /*outputMap=*/{N, C, W}})
.matchBody();
}
@@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
AffineExpr w = m.dim(3);
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w, C},
- /*outputMap=*/{N, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C},
+ /*outputMap=*/{N, W, C}})
.matchBody();
}
@@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
AffineExpr w = m.dim(4);
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w, C, CM},
- /*outputMap=*/{N, W, C, CM}})
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C, CM},
+ /*outputMap=*/{N, W, C, CM}})
.matchBody();
}
@@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{C, h, w},
- /*outputMap=*/{N, C, H, W}})
+ .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{C, h, w},
+ /*outputMap=*/{N, C, H, W}})
.matchBody();
}
@@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w, C, CM},
- /*outputMap=*/{N, D, H, W, C, CM}})
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w, C, CM},
+ /*outputMap=*/{N, D, H, W, C, CM}})
.matchBody();
}
@@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MAX_SIGNED);
+ PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MIN_SIGNED);
+ PoolingType::MinSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::SUM);
+ PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MAX_UNSIGNED);
+ PoolingType::MaxUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MIN_UNSIGNED);
+ PoolingType::MinUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/170080
More information about the Mlir-commits mailing list