[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)

Abhishek Varma llvmlistbot at llvm.org
Tue Nov 11 01:59:19 PST 2025


================
@@ -240,6 +240,615 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Resolves `val` to the block argument it originates from. `val` may either
+/// be a block argument itself, or the result of a single `arith.extf`,
+/// `arith.extsi` or `arith.extui` op whose operand is a block argument.
+/// Returns a null BlockArgument in every other case.
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+  // Direct hit: `val` already is a block argument.
+  if (auto directArg = dyn_cast<BlockArgument>(val))
+    return directArg;
+
+  // Otherwise look through exactly one extension op, if there is one.
+  Operation *extOp = val.getDefiningOp();
+  if (!isa_and_present<arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp>(extOp))
+    return nullptr;
+  return dyn_cast<BlockArgument>(extOp->getOperand(0));
+}
+
+/// Matches the region body of a convolution op.
+/// `yieldVal` has to be computed as :-
+///     %out + (%lhs * %rhs)
+///   where: the addition is an arith.addi/arith.addf whose second operand is
+///          an arith.muli/arith.mulf,
+///          %lhs, %rhs and %out are arguments #0, #1 and #2 of `body`, and
+///          each of them may be wrapped in one optional ext* op (see
+///          `getBlockArgumentWithOptionalExtOps`).
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+  Operation *accumulateOp = yieldVal.getDefiningOp();
+  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accumulateOp))
+    return false;
+
+  // Only the second operand of the addition is inspected for the product.
+  Operation *productOp = accumulateOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::MulIOp, arith::MulFOp>(productOp))
+    return false;
+
+  BlockArgument lhsArg =
+      getBlockArgumentWithOptionalExtOps(productOp->getOperand(0));
+  BlockArgument rhsArg =
+      getBlockArgumentWithOptionalExtOps(productOp->getOperand(1));
+  BlockArgument outArg =
+      getBlockArgumentWithOptionalExtOps(accumulateOp->getOperand(0));
+
+  // True iff `arg` is non-null and is argument number `position` of `body`.
+  auto isBodyArgAt = [body](BlockArgument arg, unsigned position) {
+    return arg && arg.getOwner() == body && arg.getArgNumber() == position;
+  };
+  return isBodyArgAt(lhsArg, 0) && isBodyArgAt(rhsArg, 1) &&
+         isBodyArgAt(outArg, 2);
+}
+
+/// Matches the region body of linalg.pool* ops.
+/// `yieldVal` has to be produced by an op of one of the `OpTypes`, whose first
+/// operand traces back to argument #2 of `body` and whose second operand
+/// traces back to argument #0 of `body`. Both operands may be wrapped in one
+/// optional ext* op (see `getBlockArgumentWithOptionalExtOps`).
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *poolOp = yieldVal.getDefiningOp();
+  const bool isExpectedOp = (isa_and_present<OpTypes>(poolOp) || ...);
+  if (!isExpectedOp)
+    return false;
+
+  BlockArgument firstArg =
+      getBlockArgumentWithOptionalExtOps(poolOp->getOperand(0));
+  BlockArgument secondArg =
+      getBlockArgumentWithOptionalExtOps(poolOp->getOperand(1));
+
+  // True iff `arg` is non-null and is argument number `position` of `body`.
+  auto isBodyArgAt = [body](BlockArgument arg, unsigned position) {
+    return arg && arg.getOwner() == body && arg.getArgNumber() == position;
+  };
+  return isBodyArgAt(firstArg, 2) && isBodyArgAt(secondArg, 0);
+}
+
+/// Pool body matcher instantiated for `arith.maximumf` and `arith.maxsi`.
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  using arith::MaximumFOp;
+  using arith::MaxSIOp;
+  return bodyMatcherForPoolOps<MaximumFOp, MaxSIOp>(yieldVal, body);
+}
+
+/// Pool body matcher instantiated for `arith.maximumf` and `arith.maxui`.
+// max_unsigned ops should not allow float data type, yet `arith.maximumf` is
+// still accepted here.
+// TODO: Retire OPDSL logic. Refer to :
+//       https://github.com/llvm/llvm-project/issues/164800
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  using arith::MaximumFOp;
+  using arith::MaxUIOp;
+  return bodyMatcherForPoolOps<MaximumFOp, MaxUIOp>(yieldVal, body);
+}
+
+/// Pool body matcher instantiated for `arith.minimumf` and `arith.minsi`.
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  using arith::MinimumFOp;
+  using arith::MinSIOp;
+  return bodyMatcherForPoolOps<MinimumFOp, MinSIOp>(yieldVal, body);
+}
+
+/// Pool body matcher instantiated for `arith.minimumf` and `arith.minui`.
+// min_unsigned ops should not allow float data type, yet `arith.minimumf` is
+// still accepted here.
+// TODO: Retire OPDSL logic. Refer to :
+//       https://github.com/llvm/llvm-project/issues/164800
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  using arith::MinimumFOp;
+  using arith::MinUIOp;
+  return bodyMatcherForPoolOps<MinimumFOp, MinUIOp>(yieldVal, body);
+}
+
+/// Pool body matcher instantiated for `arith.addi` and `arith.addf`.
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  using arith::AddFOp;
+  using arith::AddIOp;
+  return bodyMatcherForPoolOps<AddIOp, AddFOp>(yieldVal, body);
+}
+
+/// Returns result number `dimIndex` of the affine map stored at position
+/// `mapIndex` of `indexingMaps`, or a null AffineExpr when the map has no
+/// result with that index. `mapIndex` itself is not range-checked here.
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  AffineMap map = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex >= map.getNumResults())
+    return nullptr;
+  return map.getResult(dimIndex);
+}
+
+/// Check if `expr` is either:
+/// - a dimension expr alone (implying multiplication by 1), or
+/// - a multiplication of dimension expr by any positive constant != 1
+///   (the constant may appear on either side of the multiplication).
+/// In both cases we will capture the dimension expression into `dim` and
+/// return the constant multiplier. Returns -1 in case of a match failure, in
+/// which case `dim` is left null. Non-positive multipliers are treated as a
+/// match failure so that the -1 sentinel stays unambiguous and callers never
+/// observe a zero or negative multiplier.
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+  if ((dim = dyn_cast<AffineDimExpr>(expr)))
+    return 1;
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return -1;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  // Try `dim * cst` first, then `cst * dim`.
+  AffineConstantExpr cst = nullptr;
+  bool isDimTimesCst = ((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+                        (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+                       ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+                        (cst = dyn_cast<AffineConstantExpr>(lhs)));
+  if (!isDimTimesCst)
+    return -1;
+
+  // Reject zero/negative multipliers instead of leaking them to the caller.
+  int64_t multiplier = cst.getValue();
+  if (multiplier <= 0) {
+    dim = AffineExpr();
+    return -1;
+  }
+  return multiplier;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following
+/// commutatively:-
+///   indexingMaps[0].getResult(iDim) ==
+///         indexingMaps[1].getResult(fDim) * <c0> +
+///         indexingMaps[n-1].getResult(oDim) * <c1>
+///  where,
+///       - c0 and c1 are the constant multipliers accepted by
+///         `isDimTimesConstantOrDimOnly`,
+///       - n is the size of the indexingMaps' array,
+///       - 0, 1 and n-1 are input, filter and output map indices respectively,
+///       - iDim, fDim and oDim are the input, filter and output dimension
+///         indices in their respective indexing maps
+///  On a match, `dilation` receives the filter multiplier and `stride` the
+///  output multiplier; on failure both are left untouched. An out-of-range
+///  `iDim`, `fDim` or `oDim` results in a match failure.
+///  Example:
+///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+///                     -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+///
+///   Here,
+///     #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
+///   Therefore,
+///     matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+///     would return true and update dilation = 3 and stride = 2
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned inputMapIdx = 0, filterMapIdx = 1,
+           outputMapIdx = indexingMaps.size() - 1;
+  // `getAffineMapDim` hands back a null expression when `iDim` is out of
+  // range; bail out before casting, since `dyn_cast` requires a non-null
+  // value.
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
+  if (!inpExpr)
+    return false;
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+  int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
+  if (c0 == -1 || c1 == -1)
+    return false;
+
+  // Pattern matched with dims and constants extracted. `fExpr`/`oExpr` may be
+  // null for out-of-range indices; they then compare unequal to the non-null
+  // `dim0`/`dim1` and the match fails.
+  AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
+  AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
+  if (dim0 == fExpr && dim1 == oExpr) {
+    dilation = c0;
+    stride = c1;
+    return true;
+  }
+  if (dim1 == fExpr && dim0 == oExpr) {
+    dilation = c1;
+    stride = c0;
+    return true;
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps, verify that there are exactly as many maps as
+/// entries in `expectedSizes` and that the i-th map has `expectedSizes[i]`
+/// results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  const size_t numMaps = indexingMaps.size();
+  if (numMaps != expectedSizes.size())
+    return false;
+
+  for (size_t idx = 0; idx != numMaps; ++idx) {
+    AffineMap map = cast<AffineMapAttr>(indexingMaps[idx]).getValue();
+    if (map.getNumResults() != expectedSizes[idx])
+      return false;
+  }
+  return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`.
+static void updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides,
+                                          ArrayRef<int64_t> tempDilations,
+                                          ArrayRef<int64_t> tempStrides) {
+  if (!(dilations && strides))
+    return;
+  for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+    dilations->push_back(dilation);
+    strides->push_back(stride);
+  }
----------------
Abhishek-Varma wrote:

Done.

https://github.com/llvm/llvm-project/pull/163724


More information about the Mlir-commits mailing list