[Mlir-commits] [mlir] [Linalg] Add *Conv2D* matchers (PR #168362)

Han-Chung Wang llvmlistbot at llvm.org
Sun Dec 7 00:30:42 PST 2025


================
@@ -240,27 +240,91 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
   BlockArgument blockArg = dyn_cast<BlockArgument>(val);
   if ((blockArg))
     return blockArg;
 
   Operation *defOp = val.getDefiningOp();
   if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
       !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
-      !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
     return nullptr;
   }
   return dyn_cast<BlockArgument>(defOp->getOperand(0));
 }
 
+/// Utility function to match the zero point offset body of quantized
+/// convolution ops.
+///
+/// Quantized convolutions have a body of the form:
+///   %out + ((%input - %inputZp) * (%filter - %filterZp))
+/// where:
+///   - %input is the input tensor element (block arg 0)
+///   - %filter is the filter tensor element (block arg 1)
+///   - %inputZp is the input zero-point scalar (block arg 2)
+///   - %filterZp is the filter zero-point scalar (block arg 3)
+///   - %out is the output accumulator (block arg 4)
+///
+/// This function verifies that the multiplication operands are subtraction
+/// operations matching this pattern.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+                                           Block *body) {
+  // The multiplication should have two subtraction operands:
+  // one for (input - inputZp) and one for (filter - filterZp).
+  Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
+    return false;
+
+  Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
+    return false;
+
+  // Extract block arguments from subtraction operands.
+  BlockArgument inputBlockArg =
+      getBlockArgumentWithOptionalCastOps(inputSubOp->getOperand(0));
+  BlockArgument inputZpBlockArg =
+      getBlockArgumentWithOptionalCastOps(inputSubOp->getOperand(1));
+  BlockArgument filterBlockArg =
+      getBlockArgumentWithOptionalCastOps(filterSubOp->getOperand(0));
+  BlockArgument filterZpBlockArg =
+      getBlockArgumentWithOptionalCastOps(filterSubOp->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+
+  // Verify all block arguments are valid.
+  if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
+      !filterZpBlockArg || !outBlockArg)
+    return false;
+
+  // Verify all block arguments belong to the convolution body.
+  if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body ||
+      filterBlockArg.getOwner() != body ||
+      filterZpBlockArg.getOwner() != body || outBlockArg.getOwner() != body)
+    return false;
+
+  // Verify block arguments have expected indices:
+  // arg0: input, arg1: filter, arg2: inputZp, arg3: filterZp, arg4: output
+  if (inputBlockArg.getArgNumber() != 0 || filterBlockArg.getArgNumber() != 1 ||
+      inputZpBlockArg.getArgNumber() != 2 ||
+      filterZpBlockArg.getArgNumber() != 3 || outBlockArg.getArgNumber() != 4)
+    return false;
+
+  return true;
+}
+
 /// Utility to match block body for convolution ops.
 /// The body is thus expected to yield :-
 ///     %out + (%lhs * %rhs)
 ///   where: %lhs, %rhs and %out are block arguments and
 ///          %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
+///       %input - %input_scalar
+///          where, %input_scalar can have optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+                                         bool zeroPointOffset = false) {
----------------
hanhanW wrote:

I just learned that Abhishek is out for a while, so I went ahead and addressed the comment myself. Shall we land this before the PR goes stale?

https://github.com/llvm/llvm-project/pull/168362


More information about the Mlir-commits mailing list