[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)
Abhishek Varma
llvmlistbot at llvm.org
Fri Nov 7 12:53:32 PST 2025
================
@@ -240,6 +240,555 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+ /// Returns `val` itself if it is a BlockArgument, or the BlockArgument that
+ /// feeds `val` through exactly one arith::ExtFOp / arith::ExtSIOp /
+ /// arith::ExtUIOp. Returns a null BlockArgument otherwise (including when the
+ /// ext op's operand is not itself a BlockArgument).
+ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+   if (auto blockArg = dyn_cast<BlockArgument>(val))
+     return blockArg;
+
+   // Look through a single extension op; any other producer is a mismatch.
+   Operation *defOp = val.getDefiningOp();
+   if (!isa_and_present<arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp>(defOp))
+     return nullptr;
+   return dyn_cast<BlockArgument>(defOp->getOperand(0));
+ }
+
+ /// Returns true if `yieldVal` is computed as `add(out, mul(lhs, rhs))`
+ /// (arith int or float variants), where `lhs`, `rhs` and `out` are block
+ /// arguments #0, #1 and #2 of `body` respectively, each optionally wrapped in
+ /// a single arith ext* op. Only this exact operand order is recognized.
+ static bool bodyMatcherForMatmulLikeOps(Value yieldVal, Block *body) {
+   Operation *accumulate = yieldVal.getDefiningOp();
+   if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accumulate))
+     return false;
+
+   Operation *product = accumulate->getOperand(1).getDefiningOp();
+   if (!isa_and_present<arith::MulIOp, arith::MulFOp>(product))
+     return false;
+
+   // Each operand must trace back to the expected argument of `body`.
+   auto isBodyArg = [body](Value v, unsigned expectedIdx) {
+     BlockArgument arg = getBlockArgumentWithOptionalExtOps(v);
+     return arg && arg.getOwner() == body && arg.getArgNumber() == expectedIdx;
+   };
+   return isBodyArg(product->getOperand(0), 0) &&
+          isBodyArg(product->getOperand(1), 1) &&
+          isBodyArg(accumulate->getOperand(0), 2);
+ }
+
+ /// Returns true if `yieldVal` is produced by one of `OpTypes` whose operand
+ /// #0 is block argument #2 of `body` and whose operand #1 is block argument
+ /// #0 of `body`. Unlike the matmul-like matcher, no ext* ops are looked
+ /// through.
+ template <typename... OpTypes>
+ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+   Operation *combiner = yieldVal.getDefiningOp();
+   if (!isa_and_present<OpTypes...>(combiner))
+     return false;
+
+   auto isBodyArg = [body](Value v, unsigned expectedIdx) {
+     auto arg = dyn_cast<BlockArgument>(v);
+     return arg && arg.getOwner() == body && arg.getArgNumber() == expectedIdx;
+   };
+   return isBodyArg(combiner->getOperand(0), 2) &&
+          isBodyArg(combiner->getOperand(1), 0);
+ }
+
+ /// Pool body matcher accepting arith::MaxSIOp or arith::MaximumFOp as the
+ /// combining op.
+ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+   return bodyMatcherForPoolOps<arith::MaxSIOp, arith::MaximumFOp>(yieldVal,
+                                                                   body);
+ }
+
+ /// Pool body matcher accepting arith::MaxUIOp or arith::MaximumFOp as the
+ /// combining op.
+ // max_unsigned ops should not allow float data type, yet MaximumFOp is still
+ // accepted here (presumably to mirror the current OpDSL definition).
+ // TODO: Retire OPDSL logic. Refer to :
+ // https://github.com/llvm/llvm-project/issues/164800
+ static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+   return bodyMatcherForPoolOps<arith::MaxUIOp, arith::MaximumFOp>(yieldVal,
+                                                                   body);
+ }
+
+ /// Pool body matcher accepting arith::MinSIOp or arith::MinimumFOp as the
+ /// combining op.
+ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+   return bodyMatcherForPoolOps<arith::MinSIOp, arith::MinimumFOp>(yieldVal,
+                                                                   body);
+ }
+
+ /// Pool body matcher accepting arith::MinUIOp or arith::MinimumFOp as the
+ /// combining op.
+ // min_unsigned ops should not allow float data type, yet MinimumFOp is still
+ // accepted here (presumably to mirror the current OpDSL definition).
+ // TODO: Retire OPDSL logic. Refer to :
+ // https://github.com/llvm/llvm-project/issues/164800
+ static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+   return bodyMatcherForPoolOps<arith::MinUIOp, arith::MinimumFOp>(yieldVal,
+                                                                   body);
+ }
+
+ /// Pool body matcher accepting arith::AddFOp or arith::AddIOp as the
+ /// combining op.
+ static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+   return bodyMatcherForPoolOps<arith::AddFOp, arith::AddIOp>(yieldVal, body);
+ }
+
+ /// Returns result expression #`dimIndex` of the affine map stored at
+ /// `indexingMaps[mapIndex]`, or a null AffineExpr when that map has fewer
+ /// results. `mapIndex` itself is not range-checked here.
+ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                         uint32_t mapIndex, uint32_t dimIndex) {
+   AffineMap map = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+   return dimIndex < map.getNumResults() ? map.getResult(dimIndex)
+                                         : AffineExpr();
+ }
+
+ /// Check if `expr` is either:
+ ///  - a dimension expr alone (implying multiplication by 1), or
+ ///  - a dimension expr multiplied by a positive integer constant (the
+ ///    constant may be on either side of the `*`).
+ /// On a match the dimension expression is captured into `dim` and the
+ /// constant multiplier is returned. Returns -1 on a match failure. A zero or
+ /// negative multiplier is also treated as a failure: it is not a meaningful
+ /// stride/dilation, and a literal `* -1` would otherwise be indistinguishable
+ /// from the failure sentinel. `dim` is unspecified when -1 is returned.
+ static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+   if ((dim = dyn_cast<AffineDimExpr>(expr)))
+     return 1;
+
+   auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+   if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+     return -1;
+
+   AffineExpr lhs = mulExpr.getLHS();
+   AffineExpr rhs = mulExpr.getRHS();
+
+   AffineConstantExpr cst = nullptr;
+   bool matched = ((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+                   (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+                  ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+                   (cst = dyn_cast<AffineConstantExpr>(lhs)));
+   if (!matched || cst.getValue() <= 0)
+     return -1;
+   return cst.getValue();
+ }
+
+ /// Given an array of AffineMaps `indexingMaps`, verify that, with the two
+ /// addends in either order:
+ ///   indexingMaps[0].getResult(iDim) ==
+ ///       indexingMaps[n-1].getResult(oDim) * <stride> +
+ ///       indexingMaps[1].getResult(fDim) * <dilation>
+ /// where,
+ ///   - stride and dilation are positive integer constants (an absent
+ ///     multiplier counts as 1),
+ ///   - n is the size of the indexingMaps' array (at least 3 is required),
+ ///   - 0, 1 and n-1 are the input, filter and output map indices,
+ ///   - iDim, fDim and oDim are the input, filter and output result indices
+ ///     in their respective indexing maps,
+ ///   - the filter and output results are two distinct dimensions.
+ /// On success `dilation` and `stride` are written; on failure they are left
+ /// untouched.
+ /// Example:
+ ///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+ ///                          -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+ ///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+ ///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ ///
+ /// Here #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3, so
+ ///   matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+ /// returns true and sets dilation = 3 and stride = 2.
+ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                        unsigned fDim, unsigned oDim,
+                                        int64_t &dilation, int64_t &stride) {
+   // Input, filter and output maps are all required.
+   if (indexingMaps.size() < 3)
+     return false;
+   unsigned inputMapIdx = 0, filterMapIdx = 1,
+            outputMapIdx = indexingMaps.size() - 1;
+
+   // `inpExpr` is null when `iDim` is out of range for the input map.
+   AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
+   auto addExpr = dyn_cast_if_present<AffineBinaryOpExpr>(inpExpr);
+   if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+     return false;
+
+   AffineExpr dim0, dim1;
+   int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+   int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
+   // Strides and dilations must be positive; this also rejects the -1
+   // failure sentinel.
+   if (c0 <= 0 || c1 <= 0)
+     return false;
+   // The same dimension on both sides of the `+` cannot stand for a distinct
+   // filter dimension and output dimension.
+   if (dim0 == dim1)
+     return false;
+
+   AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
+   AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
+   if (dim0 == fExpr && dim1 == oExpr) {
+     dilation = c0;
+     stride = c1;
+     return true;
+   }
+   if (dim1 == fExpr && dim0 == oExpr) {
+     dilation = c1;
+     stride = c0;
+     return true;
+   }
+   return false;
+ }
+
+ /// Given an array of AffineMaps `indexingMaps`, verify that
+ ///   indexingMaps[aIndex].getResult(aDim) ==
+ ///   indexingMaps[bIndex].getResult(bDim)
+ /// Returns false when either result index is out of range for its map. This
+ /// prevents two missing (null) results from comparing equal.
+ static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                     unsigned aDim, unsigned bIndex,
+                                     unsigned bDim) {
+   AffineExpr aExpr = getAffineMapDim(indexingMaps, aIndex, aDim);
+   AffineExpr bExpr = getAffineMapDim(indexingMaps, bIndex, bDim);
+   return aExpr && aExpr == bExpr;
+ }
+
+ /// Given an array of AffineMaps, verify that it holds exactly one map per
+ /// entry of `expectedSizes`, and that map #i has `expectedSizes[i]` results.
+ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                        ArrayRef<int64_t> expectedSizes) {
+   if (indexingMaps.size() != expectedSizes.size())
+     return false;
+
+   for (size_t i = 0, e = expectedSizes.size(); i < e; ++i) {
+     AffineMap map = cast<AffineMapAttr>(indexingMaps[i]).getValue();
+     if (map.getNumResults() != expectedSizes[i])
+       return false;
+   }
+   return true;
+ }
+
+ /// Utility to append the matched `tempDilations` / `tempStrides` to the
+ /// caller-provided `dilations` / `strides` vectors. Each output is handled
+ /// independently: a null pointer simply skips that output, so a caller that
+ /// asks only for dilations (or only for strides) still receives them.
+ static void updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+                                           SmallVector<int64_t> *strides,
+                                           ArrayRef<int64_t> tempDilations,
+                                           ArrayRef<int64_t> tempStrides) {
+   if (dilations)
+     llvm::append_range(*dilations, tempDilations);
+   if (strides)
+     llvm::append_range(*strides, tempStrides);
+ }
+
+// ---------------------------------------------
+// Matchers for specific convolution operation.
+// ---------------------------------------------
+
+ // Expected structure (iteration space (N, W, C, w)):
+ //   #inputMap  = affine_map<(N, W, C, w) -> (N, W + w, C)>
+ //   #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+ //   #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+ // `W + w` may carry a stride multiplier on W and a dilation multiplier on w;
+ // on a match these are handed to updateConvDilationsAndStrides.
+ template <>
+ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+     LinalgOp op, SmallVector<int64_t> *dilations,
+     SmallVector<int64_t> *strides) {
+   // NOTE(review): the named op short-circuits to true without populating
+   // `dilations` / `strides`.
+   if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+     return true;
+
+   assert(isaConvolutionOpInterface(op) &&
+          "expected linalgOp to implement ConvolutionOpInterface");
+
+   ArrayAttr maps = op.getIndexingMaps();
+   if (!verifyConvIndexingMapSizes(maps, {3, 2, 3}))
+     return false;
+
+   Block *body = op.getBlock();
+   Value yielded = cast<linalg::YieldOp>(body->getTerminator()).getOperand(0);
+   const unsigned inIdx = 0;
+   const unsigned filterIdx = 1;
+   const unsigned outIdx = 2;
+
+   SmallVector<int64_t> foundDilations(1, 1);
+   SmallVector<int64_t> foundStrides(1, 1);
+   // N is shared by input/output; C by input, filter and output.
+   if (!matchConvDimExprPattern(maps, inIdx, 0, outIdx, 0) ||
+       !matchConvDimExprPattern(maps, inIdx, 2, filterIdx, 1) ||
+       !matchConvDimExprPattern(maps, inIdx, 2, outIdx, 2))
+     return false;
+   if (!matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1,
+                                   foundDilations[0], foundStrides[0]) ||
+       !bodyMatcherForMatmulLikeOps(yielded, body))
+     return false;
+
+   updateConvDilationsAndStrides(dilations, strides, foundDilations,
+                                 foundStrides);
+   return true;
+ }
+
+ // Expected structure (iteration space (N, H, W, C, h, w)):
+ //   #inputMap  = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+ //   #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+ //   #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+ // `H + h` and `W + w` may carry stride multipliers on H/W and dilation
+ // multipliers on h/w; on a match these are handed to
+ // updateConvDilationsAndStrides.
+ template <>
+ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+     LinalgOp op, SmallVector<int64_t> *dilations,
+     SmallVector<int64_t> *strides) {
+   // NOTE(review): the named op short-circuits to true without populating
+   // `dilations` / `strides`.
+   if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+     return true;
+
+   assert(isaConvolutionOpInterface(op) &&
+          "expected linalgOp to implement ConvolutionOpInterface");
+
+   ArrayAttr maps = op.getIndexingMaps();
+   if (!verifyConvIndexingMapSizes(maps, {4, 3, 4}))
+     return false;
+
+   Block *body = op.getBlock();
+   Value yielded = cast<linalg::YieldOp>(body->getTerminator()).getOperand(0);
+   const unsigned inIdx = 0;
+   const unsigned filterIdx = 1;
+   const unsigned outIdx = 2;
+
+   SmallVector<int64_t> foundDilations(2, 1);
+   SmallVector<int64_t> foundStrides(2, 1);
+   // N is shared by input/output; C by input, filter and output.
+   if (!matchConvDimExprPattern(maps, inIdx, 0, outIdx, 0) ||
+       !matchConvDimExprPattern(maps, inIdx, 1, filterIdx, 0) ||
+       !matchConvDimExprPattern(maps, inIdx, 1, outIdx, 1))
+     return false;
+   // Spatial dims: input results 2..3, filter results 1..2, output 2..3.
+   for (unsigned s = 0; s < 2; ++s) {
+     if (!matchConvDimAddExprPattern(maps, /*iDim=*/2 + s, /*fDim=*/1 + s,
+                                     /*oDim=*/2 + s, foundDilations[s],
+                                     foundStrides[s]))
+       return false;
+   }
+   if (!bodyMatcherForMatmulLikeOps(yielded, body))
+     return false;
+
+   updateConvDilationsAndStrides(dilations, strides, foundDilations,
+                                 foundStrides);
+   return true;
+ }
+
+ // Expected structure (iteration space (N, D, H, W, CM, d, h, w, C)):
+ //   #inputMap  = affine_map<(N, D, H, W, CM, d, h, w, C)
+ //                           -> (N, D + d, H + h, W + w, C)>
+ //   #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+ //                           -> (d, h, w, C, CM)>
+ //   #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+ //                           -> (N, D, H, W, C, CM)>
+ // Each spatial sum may carry a stride multiplier on the output dim and a
+ // dilation multiplier on the filter dim; on a match these are handed to
+ // updateConvDilationsAndStrides.
+ template <>
+ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+     LinalgOp op, SmallVector<int64_t> *dilations,
+     SmallVector<int64_t> *strides) {
+   // NOTE(review): the named op short-circuits to true without populating
+   // `dilations` / `strides`.
+   if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+     return true;
+
+   assert(isaConvolutionOpInterface(op) &&
+          "expected linalgOp to implement ConvolutionOpInterface");
+
+   ArrayAttr maps = op.getIndexingMaps();
+   if (!verifyConvIndexingMapSizes(maps, {5, 5, 6}))
+     return false;
+
+   Block *body = op.getBlock();
+   Value yielded = cast<linalg::YieldOp>(body->getTerminator()).getOperand(0);
+   const unsigned inIdx = 0;
+   const unsigned filterIdx = 1;
+   const unsigned outIdx = 2;
+
+   SmallVector<int64_t> foundDilations(3, 1);
+   SmallVector<int64_t> foundStrides(3, 1);
+   if (!matchConvDimExprPattern(maps, inIdx, 0, outIdx, 0))
+     return false;
+   // Spatial dims: input results 1..3, filter results 0..2, output 1..3.
+   for (unsigned s = 0; s < 3; ++s) {
+     if (!matchConvDimAddExprPattern(maps, /*iDim=*/1 + s, /*fDim=*/s,
+                                     /*oDim=*/1 + s, foundDilations[s],
+                                     foundStrides[s]))
+       return false;
+   }
+   // C is shared by input, filter and output; CM by filter and output.
+   if (!matchConvDimExprPattern(maps, inIdx, 4, filterIdx, 3) ||
+       !matchConvDimExprPattern(maps, inIdx, 4, outIdx, 4) ||
+       !matchConvDimExprPattern(maps, filterIdx, 4, outIdx, 5) ||
+       !bodyMatcherForMatmulLikeOps(yielded, body))
+     return false;
+
+   updateConvDilationsAndStrides(dilations, strides, foundDilations,
+                                 foundStrides);
+   return true;
+ }
+
+ // Expected structure (iteration space (N, H, W, C, h, w)):
+ //   #inputMap  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ //   #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ //   #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+ // The body must combine output and input with a signed-int / float maximum
+ // (see bodyMatcherForMaxSignedPoolOps).
+ template <>
+ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+     LinalgOp op, SmallVector<int64_t> *dilations,
+     SmallVector<int64_t> *strides) {
+   // NOTE(review): the named op short-circuits to true without populating
+   // `dilations` / `strides`.
+   if (isa<linalg::PoolingNhwcMaxOp>(op))
+     return true;
+
+   assert(isaConvolutionOpInterface(op) &&
+          "expected linalgOp to implement ConvolutionOpInterface");
+
+   ArrayAttr maps = op.getIndexingMaps();
+   if (!verifyConvIndexingMapSizes(maps, {4, 2, 4}))
+     return false;
+
+   Block *body = op.getBlock();
+   Value yielded = cast<linalg::YieldOp>(body->getTerminator()).getOperand(0);
+   const unsigned inIdx = 0;
+   const unsigned outIdx = 2;
+
+   SmallVector<int64_t> foundDilations(2, 1);
+   SmallVector<int64_t> foundStrides(2, 1);
+   if (!matchConvDimExprPattern(maps, inIdx, 0, outIdx, 0))
+     return false;
+   // Spatial dims: input results 1..2, window results 0..1, output 1..2.
+   for (unsigned s = 0; s < 2; ++s) {
+     if (!matchConvDimAddExprPattern(maps, /*iDim=*/1 + s, /*fDim=*/s,
+                                     /*oDim=*/1 + s, foundDilations[s],
+                                     foundStrides[s]))
+       return false;
+   }
+   if (!matchConvDimExprPattern(maps, inIdx, 3, outIdx, 3) ||
+       !bodyMatcherForMaxSignedPoolOps(yielded, body))
+     return false;
+
+   updateConvDilationsAndStrides(dilations, strides, foundDilations,
+                                 foundStrides);
+   return true;
+ }
+
+ // Expected structure (iteration space (N, H, W, C, h, w)):
+ //   #inputMap  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ //   #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ //   #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+ // The body must combine output and input with a signed-int / float minimum
+ // (see bodyMatcherForMinSignedPoolOps).
+ template <>
+ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+     LinalgOp op, SmallVector<int64_t> *dilations,
+     SmallVector<int64_t> *strides) {
+   // NOTE(review): the named op short-circuits to true without populating
+   // `dilations` / `strides`.
+   if (isa<linalg::PoolingNhwcMinOp>(op))
+     return true;
+
+   assert(isaConvolutionOpInterface(op) &&
+          "expected linalgOp to implement ConvolutionOpInterface");
+
+   ArrayAttr maps = op.getIndexingMaps();
+   if (!verifyConvIndexingMapSizes(maps, {4, 2, 4}))
+     return false;
+
+   Block *body = op.getBlock();
+   Value yielded = cast<linalg::YieldOp>(body->getTerminator()).getOperand(0);
+   const unsigned inIdx = 0;
+   const unsigned outIdx = 2;
+
+   SmallVector<int64_t> foundDilations(2, 1);
+   SmallVector<int64_t> foundStrides(2, 1);
+   if (!matchConvDimExprPattern(maps, inIdx, 0, outIdx, 0))
+     return false;
+   // Spatial dims: input results 1..2, window results 0..1, output 1..2.
+   for (unsigned s = 0; s < 2; ++s) {
+     if (!matchConvDimAddExprPattern(maps, /*iDim=*/1 + s, /*fDim=*/s,
+                                     /*oDim=*/1 + s, foundDilations[s],
+                                     foundStrides[s]))
+       return false;
+   }
+   if (!matchConvDimExprPattern(maps, inIdx, 3, outIdx, 3) ||
+       !bodyMatcherForMinSignedPoolOps(yielded, body))
+     return false;
+
+   updateConvDilationsAndStrides(dilations, strides, foundDilations,
+                                 foundStrides);
+   return true;
+ }
+
+ // Expected structure (iteration space (N, H, W, C, h, w)):
+ //   #inputMap  = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ //   #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ //   #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+ // The body must accumulate the input into the output with an int / float
+ // add (see bodyMatcherForSumPoolOps).
+ template <>
+ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+     LinalgOp op, SmallVector<int64_t> *dilations,
+     SmallVector<int64_t> *strides) {
+   // NOTE(review): the named op short-circuits to true without populating
+   // `dilations` / `strides`.
+   if (isa<linalg::PoolingNhwcSumOp>(op))
+     return true;
+
+   assert(isaConvolutionOpInterface(op) &&
+          "expected linalgOp to implement ConvolutionOpInterface");
+
+   ArrayAttr maps = op.getIndexingMaps();
+   if (!verifyConvIndexingMapSizes(maps, {4, 2, 4}))
+     return false;
+
+   Block *body = op.getBlock();
+   Value yielded = cast<linalg::YieldOp>(body->getTerminator()).getOperand(0);
+   const unsigned inIdx = 0;
+   const unsigned outIdx = 2;
+
+   SmallVector<int64_t> foundDilations(2, 1);
+   SmallVector<int64_t> foundStrides(2, 1);
+   if (!matchConvDimExprPattern(maps, inIdx, 0, outIdx, 0))
+     return false;
+   // Spatial dims: input results 1..2, window results 0..1, output 1..2.
+   for (unsigned s = 0; s < 2; ++s) {
+     if (!matchConvDimAddExprPattern(maps, /*iDim=*/1 + s, /*fDim=*/s,
+                                     /*oDim=*/1 + s, foundDilations[s],
+                                     foundStrides[s]))
+       return false;
+   }
+   if (!matchConvDimExprPattern(maps, inIdx, 3, outIdx, 3) ||
+       !bodyMatcherForSumPoolOps(yielded, body))
+     return false;
+
+   updateConvDilationsAndStrides(dilations, strides, foundDilations,
+                                 foundStrides);
+   return true;
+ }
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned inputMapIdx = 0, outputMapIdx = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) &&
+ bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+ if (returnVal)
+ updateConvDilationsAndStrides(dilations, strides, tempDilations,
+ tempStrides);
----------------
Abhishek-Varma wrote:
Done.
https://github.com/llvm/llvm-project/pull/163724
More information about the Mlir-commits
mailing list