[Mlir-commits] [mlir] [NFC][Linalg] Follow-up on ConvMatcherBuilder (PR #170080)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 23:33:10 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)

<details>
<summary>Changes</summary>

-- This commit addresses the [follow-up review comments on PR #169704](https://github.com/llvm/llvm-project/pull/169704#pullrequestreview-3521785548).
-- Contains NFC nit/minor changes.

Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>

---
Full diff: https://github.com/llvm/llvm-project/pull/170080.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+85-71) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e85a2ab26bd32..01e6e1e248658 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                           })));
 }
 
-/// Enum of all kinds of Pooling Op's type.
-enum PoolingType {
-  NONE,
-  MAX_SIGNED,
-  MAX_UNSIGNED,
-  MIN_SIGNED,
-  MIN_UNSIGNED,
-  SUM
+/// Enum representing pooling operation types used by ConvMatcherBuilder.
+enum class PoolingType {
+  None,
+  MaxSigned,
+  MaxUnsigned,
+  MinSigned,
+  MinUnsigned,
+  Sum
 };
 
 /// Helper class for building convolution op matchers with minimal boilerplate.
 /// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
 /// as Pooling ops.
+///
+/// Usage: Create an instance with the op, spatial rank, and output pointers for
+/// extracted dilations/strides. Then chain matchStride() calls for each spatial
+/// dimension, followed by matchMaps() to verify indexing maps, and finally
+/// matchBody() to verify the operation body pattern.
+///
+/// The `matched` flag starts as `true` and is set to `false` if any match step
+/// fails. This allows chaining multiple match calls; once any match fails, all
+/// subsequent calls become no-ops and the final result is `false`.
+///
+/// The `dilations` and `strides` pointers are output parameters that get
+/// populated with the extracted dilation and stride values from the operation's
+/// indexing maps during matchStride() calls. These values are initially set to
+/// 1 for each spatial dimension and updated as patterns are matched.
 class ConvMatcherBuilder {
   LinalgOp op;
   MLIRContext *ctx;
@@ -454,7 +468,7 @@ class ConvMatcherBuilder {
 public:
   ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
                      SmallVector<int64_t> *s,
-                     PoolingType poolingType = PoolingType::NONE)
+                     PoolingType poolingType = PoolingType::None)
       : op(op), ctx(op->getContext()), dilations(d), strides(s),
         indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
     *dilations = SmallVector<int64_t>(spatialRank, 1);
@@ -474,16 +488,16 @@ class ConvMatcherBuilder {
   ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
                                   unsigned idx) {
     if (matched) {
-      matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
-                                           (*dilations)[idx], (*strides)[idx]);
+      matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+                                            (*dilations)[idx], (*strides)[idx]);
     }
     return *this;
   }
 
   /// Match expected indexing maps layout. Returns *this for method chaining.
-  ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+  ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
     if (matched)
-      matched = convLayoutMatches(maps, indexingMaps, ctx);
+      matched &= convLayoutMatches(maps, indexingMaps, ctx);
     return *this;
   }
 
@@ -494,17 +508,17 @@ class ConvMatcherBuilder {
     Block *body = op.getBlock();
     auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
     switch (poolingType) {
-    case PoolingType::NONE:
+    case PoolingType::None:
       return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
-    case PoolingType::MAX_SIGNED:
+    case PoolingType::MaxSigned:
       return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::MAX_UNSIGNED:
+    case PoolingType::MaxUnsigned:
       return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::MIN_SIGNED:
+    case PoolingType::MinSigned:
       return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::MIN_UNSIGNED:
+    case PoolingType::MinUnsigned:
       return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::SUM:
+    case PoolingType::Sum:
       return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
     }
     return false;
@@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
   AffineExpr w = m.dim(1);
 
   return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
-                   /*filterMap=*/{w},
-                   /*outputMap=*/{W}})
+      .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{W}})
       .matchBody();
 }
 
@@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
   AffineExpr c = m.dim(4);
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
-                   /*filterMap=*/{w, c, F},
-                   /*outputMap=*/{N, W, F}})
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+                  /*filterMap=*/{w, c, F},
+                  /*outputMap=*/{N, W, F}})
       .matchBody();
 }
 
@@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
   AffineExpr w = m.dim(4);
 
   return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
-                   /*filterMap=*/{F, c, w},
-                   /*outputMap=*/{N, F, W}})
+      .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+                  /*filterMap=*/{F, c, w},
+                  /*outputMap=*/{N, F, W}})
       .matchBody();
 }
 
@@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
 
   return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
       .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{H, W}})
+      .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{H, W}})
       .matchBody();
 }
 
@@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
       .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
       .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
-      .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
-                                 m.strided(W, w, 2)},
-                   /*filterMap=*/{d, h, w},
-                   /*outputMap=*/{D, H, W}})
+      .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2)},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{D, H, W}})
       .matchBody();
 }
 
@@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   AffineExpr w = m.dim(3);
 
   return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
-                   /*filterMap=*/{C, w},
-                   /*outputMap=*/{N, C, W}})
+      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                  /*filterMap=*/{C, w},
+                  /*outputMap=*/{N, C, W}})
       .matchBody();
 }
 
@@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   AffineExpr w = m.dim(3);
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
-                   /*filterMap=*/{w, C},
-                   /*outputMap=*/{N, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w, C},
+                  /*outputMap=*/{N, W, C}})
       .matchBody();
 }
 
@@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
   AffineExpr w = m.dim(4);
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
-                   /*filterMap=*/{w, C, CM},
-                   /*outputMap=*/{N, W, C, CM}})
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w, C, CM},
+                  /*outputMap=*/{N, W, C, CM}})
       .matchBody();
 }
 
@@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
 
   return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
       .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
-                   /*filterMap=*/{C, h, w},
-                   /*outputMap=*/{N, C, H, W}})
+      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{C, h, w},
+                  /*outputMap=*/{N, C, H, W}})
       .matchBody();
 }
 
@@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
       .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
-      .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
-                                 m.strided(W, w, 2), C},
-                   /*filterMap=*/{d, h, w, C, CM},
-                   /*outputMap=*/{N, D, H, W, C, CM}})
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w, C, CM},
+                  /*outputMap=*/{N, D, H, W, C, CM}})
       .matchBody();
 }
 
@@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MAX_SIGNED);
+                       PoolingType::MaxSigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MIN_SIGNED);
+                       PoolingType::MinSigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::SUM);
+                       PoolingType::Sum);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MAX_UNSIGNED);
+                       PoolingType::MaxUnsigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MIN_UNSIGNED);
+                       PoolingType::MinUnsigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/170080


More information about the Mlir-commits mailing list