[Mlir-commits] [mlir] [MLIR] Allowing unsupported conv2d op to fail gracefully from vectorization (PR #130181)

Zhuoran Yin llvmlistbot at llvm.org
Tue Mar 11 07:42:08 PDT 2025


================
@@ -1939,6 +1939,127 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+namespace {
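+// Returns true if `op` implements CastOpInterface and casts a single block
+// argument.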
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns true iff `reduceOp` is the reduction of a valid conv/pooling op,
+// setting `oper` to the detected kind:
+// - If the region has 2 ops (reduction + yield) or 3 ops (extension +
+//   reduction + yield) and the rhs is not used, it is the body of a pooling.
+// - If it is a conv, check for a single `mul` predecessor. The `mul` operands
+//   must be block arguments or extensions of block arguments.
+// - Otherwise, check for one or zero `ext` predecessors. The `ext` operands
+//   must be block arguments or extensions of block arguments.
+enum OperKind { Conv, Pool };
+bool getOperKind(Operation *reduceOp, OperKind &oper) {
+  int numBlockArguments =
+      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+  switch (numBlockArguments) {
+  case 1: {
+    // Will be convolution if feeder is a MulOp.
+    // A strength-reduced version of MulOp for the i1 type is AndOp, which is
+    // also supported. Otherwise, it can be pooling. This strength-reduction
+    // logic is in the `buildBinaryFn` helper in the Linalg dialect.
+    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                       llvm::IsaPred<BlockArgument>);
+    Operation *feedOp = (*feedValIt).getDefiningOp();
+    if (isCastOfBlockArgument(feedOp)) {
+      oper = Pool;
+    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                  (isa<arith::AndIOp>(feedOp) &&
+                   feedOp->getResultTypes()[0].isInteger(1))) &&
+                 llvm::all_of(feedOp->getOperands(), [](Value v) {
+                   if (isa<BlockArgument>(v))
+                     return true;
+                   if (Operation *op = v.getDefiningOp())
+                     return isCastOfBlockArgument(op);
+                   return false;
+                 }))) {
+      return false;
+    }
+    return true;
+  }
+  case 2:
+    // Must be pooling
+    oper = Pool;
+    return true;
+  default:
+    return false;
+  }
+}
+
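+// Returns true if `kind` is a combining kind supported for pooling
+// vectorization.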
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+} // namespace
+
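+/// Checks whether `convOp` meets the preconditions of the specialized
+/// conv/pooling vectorizer (operand count, shaped operand ranks, reduction
+/// body, and combining kind), so that unsupported ops fail gracefully instead
+/// of asserting during vectorization.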
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+    return failure();
+
+  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+  auto resShaped = convOp.getDpsInitOperand(0)->get();
+  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
+  if (!lhsShapedType || !rhsShapedType || !resShapedType)
+    return failure();
+  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+  // (non-channeled convolution -> LHS and RHS both have single dimensions).
+  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
+    return failure();
+
+  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+  if (!reduceOp)
+    return failure();
+
+  OperKind oper = Conv;
+  if (!getOperKind(reduceOp, oper))
+    return failure();
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  // Typically a convolution will have an `Add` CombiningKind, but for the i1
+  // type it can get strength-reduced to `OR`, which is also supported. This
+  // strength-reduction logic is in the `buildBinaryFn` helper in the Linalg
+  // dialect.
+  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                      *maybeKind != vector::CombiningKind::OR) &&
+                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
+    return failure();
+  }
+
+  auto rhsRank = rhsShapedType.getRank();
+  switch (oper) {
+  case Conv:
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
+    break;
----------------
jerryyin wrote:

Taking a second look, I've decided that the switch case is overkill, so I'll just make it if/else instead.
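
A rough sketch of the if/else form, reusing the locals already in
vectorizeConvOpPrecondition (hypothetical; the pooling branch is elided since
the quoted hunk cuts off before it):

    auto rhsRank = rhsShapedType.getRank();
    if (oper == Conv) {
      // Filters of rank 1, 2, or 3 are handled by the conv path.
      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
        return failure();
    } else {
      // Pool case, mirroring the original switch's Pool branch
      // (not shown in the quoted hunk).
    }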

https://github.com/llvm/llvm-project/pull/130181

