[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)

Abhishek Varma llvmlistbot at llvm.org
Wed Dec 3 00:02:15 PST 2025


================
@@ -240,27 +240,71 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
   BlockArgument blockArg = dyn_cast<BlockArgument>(val);
   if ((blockArg))
     return blockArg;
 
   Operation *defOp = val.getDefiningOp();
   if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
       !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
-      !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
     return nullptr;
   }
   return dyn_cast<BlockArgument>(defOp->getOperand(0));
 }
 
+/// Utility function to match the zero point offset body of convolution ops.
+/// It takes as input the addition op and multiplication op expected in every
+/// convolution op, and matches the following for both operands of the
+/// multiplication op:
+///     %a - %b
+///   where %a and %b can each have an optional upcast operation.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+                                           Block *body) {
+  Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+    return false;
+  Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+    return false;
+  BlockArgument inputBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+  BlockArgument inputScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+  BlockArgument filterBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+  BlockArgument filterScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+  if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+      !filterScalarBlockArg || !outBlockArg ||
+      inputBlockArg.getOwner() != body ||
+      inputScalarBlockArg.getOwner() != body ||
+      filterBlockArg.getOwner() != body ||
+      filterScalarBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+      inputScalarBlockArg.getArgNumber() != 2 ||
+      filterBlockArg.getArgNumber() != 1 ||
+      filterScalarBlockArg.getArgNumber() != 3 ||
+      outBlockArg.getArgNumber() != 4)
----------------
Abhishek-Varma wrote:

> Could you try breaking it down somehow? My preference would be multiple if statements with comments

Done.

> On a related note, it's not clear to me why we check this particular condition and why we check so many arguments.

I added these checks to make matching of the linalg body strict about where each value comes from and how it is used. That is why every operand is tied to a specific block argument of the body: `input` (0), `filter` (1), `inputZp` (2), `filterZp` (3) and `output` (4).
I'm not sure I've explained it well; the doc comment and the inline comments in the code may make it clearer. Could you please let me know? 😅
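To illustrate the positional checks described above, here is a minimal, self-contained C++ sketch of the same idea written as separate, commented `if` statements. It deliberately does not use MLIR: `ToyBlockArg`, `ArgPos`, `isBodyArgAt` and `matchZeroPointArgs` are hypothetical names, and only the expected positions are taken from the `getArgNumber()` checks in the diff above. It is not the code in the PR.

```cpp
#include <optional>

// Hypothetical stand-in for mlir::BlockArgument: only the owning block and
// the argument's position within that block are modeled.
struct ToyBlockArg {
  const void *owner;
  unsigned argNumber;
};

// Expected position of each role among the convolution body's block
// arguments, mirroring the getArgNumber() checks in the diff above.
enum ArgPos : unsigned {
  kInput = 0,
  kFilter = 1,
  kInputZp = 2,
  kFilterZp = 3,
  kOutput = 4,
};

// True if `arg` was found, is owned by `body`, and sits at position `pos`.
static bool isBodyArgAt(const std::optional<ToyBlockArg> &arg,
                        const void *body, ArgPos pos) {
  return arg && arg->owner == body && arg->argNumber == pos;
}

// One commented check per role instead of a single compound condition.
static bool matchZeroPointArgs(const void *body,
                               const std::optional<ToyBlockArg> &input,
                               const std::optional<ToyBlockArg> &inputZp,
                               const std::optional<ToyBlockArg> &filter,
                               const std::optional<ToyBlockArg> &filterZp,
                               const std::optional<ToyBlockArg> &output) {
  // (input - inputZp): the minuend must be the input element...
  if (!isBodyArgAt(input, body, kInput))
    return false;
  // ...and the subtrahend must be the input zero point.
  if (!isBodyArgAt(inputZp, body, kInputZp))
    return false;
  // (filter - filterZp): the same pattern on the filter side.
  if (!isBodyArgAt(filter, body, kFilter))
    return false;
  if (!isBodyArgAt(filterZp, body, kFilterZp))
    return false;
  // The accumulator added to the product must be the output element.
  if (!isBodyArgAt(output, body, kOutput))
    return false;
  return true;
}
```

With these checks, a body whose operands are swapped (for example `filter - inputZp`) or that refers to a value owned by a different block is rejected, even though it still has the sub/mul/add shape.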

> I don't see anything being matched here 🤔 What am I missing?

The "matmul" part of the convolution body is matched before this utility is called, so this function only checks the rest of the body of the zero point offset convolution. For example, this is one of the parts matched here:
```cpp
  // The multiplication should have two subtraction operands:
  // one for (input - inputZp) and one for (filter - filterZp).
  Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp();
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
    return false;

  Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp();
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
    return false;
```
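As a rough model of the whole pattern, the sketch below matches `out + (input - inputZp) * (filter - filterZp)` on a tiny hypothetical expression tree; `Kind`, `Node`, `argThroughOptionalCast`, `isBinary` and `matchZeroPointConvBody` are assumptions for illustration, not MLIR APIs. Like `getBlockArgumentWithOptionalCastOps`, it looks through at most one cast per operand. It additionally assumes the accumulator is the add's first operand and the product its second.

```cpp
#include <optional>
#include <vector>

// Hypothetical mini-IR used only for illustration; it is not MLIR.
// `Ext` stands in for arith.extf/extsi/extui, `SIToFP` for arith.sitofp.
enum class Kind { BlockArg, Ext, SIToFP, Sub, Mul, Add };

struct Node {
  Kind kind;
  unsigned argNumber;                  // meaningful only for Kind::BlockArg
  std::vector<const Node *> operands;  // empty for block arguments
};

// Analogue of getBlockArgumentWithOptionalCastOps: accept a block argument
// directly, or look through exactly one cast to reach one.
static std::optional<unsigned> argThroughOptionalCast(const Node *n) {
  if (!n)
    return std::nullopt;
  if (n->kind == Kind::BlockArg)
    return n->argNumber;
  bool isCast = n->kind == Kind::Ext || n->kind == Kind::SIToFP;
  if (isCast && n->operands.size() == 1 && n->operands[0] &&
      n->operands[0]->kind == Kind::BlockArg)
    return n->operands[0]->argNumber;
  return std::nullopt;
}

// True if `n` is a binary node of kind `k`.
static bool isBinary(const Node *n, Kind k) {
  return n && n->kind == k && n->operands.size() == 2;
}

// Matches out + (input - inputZp) * (filter - filterZp), expecting block
// argument positions input=0, filter=1, inputZp=2, filterZp=3, out=4.
static bool matchZeroPointConvBody(const Node *add) {
  if (!isBinary(add, Kind::Add))
    return false;
  const Node *mul = add->operands[1];
  if (!isBinary(mul, Kind::Mul))
    return false;
  // The multiplication should have two subtraction operands.
  const Node *inputSub = mul->operands[0];
  const Node *filterSub = mul->operands[1];
  if (!isBinary(inputSub, Kind::Sub) || !isBinary(filterSub, Kind::Sub))
    return false;
  // Each leaf must resolve (through an optional cast) to the expected slot.
  return argThroughOptionalCast(inputSub->operands[0]) == 0u &&
         argThroughOptionalCast(inputSub->operands[1]) == 2u &&
         argThroughOptionalCast(filterSub->operands[0]) == 1u &&
         argThroughOptionalCast(filterSub->operands[1]) == 3u &&
         argThroughOptionalCast(add->operands[0]) == 4u;
}
```

A plain (non-zero-point) body such as `out + input * filter` fails the subtraction checks here, which is why the "matmul" shape and the zero point offsets can be matched as separate steps.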
Let me know if I've missed anything and I'll do my best to address it. :)

https://github.com/llvm/llvm-project/pull/168362

