[Mlir-commits] [mlir] [MLIR] Refactor to create vectorization convOp precondition check (PR #130181)
Zhuoran Yin
llvmlistbot at llvm.org
Wed Mar 12 12:30:10 PDT 2025
================
@@ -1939,6 +1939,132 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+namespace {
+enum class ConvOperationKind { Conv, Pool };
+} // namespace
+
+static bool isCastOfBlockArgument(Operation *op) {
+ return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+ isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns the ConvOperationKind of the op using reduceOp of the generic
+// payload. Returns std::nullopt if the payload is neither a convolution
+// nor a pooling.
+//
+// If the region has 2 ops (reduction + yield), or 3 ops (extension +
+// reduction + yield) where the rhs block argument is not used, then it is
+// the body of a pooling.
+// If it is a convolution, check for a single `mul` predecessor; the `mul`
+// operands must be block arguments or extensions of block arguments.
+// Otherwise, check for one or zero `ext` predecessors; the `ext` operands
+// must be block arguments or extensions of block arguments.
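+//
+// For illustration only (hedged example, not exhaustive), a
+// convolution-style generic payload might look like:
+//   ^bb0(%in: f32, %filter: f32, %acc: f32):
+//     %0 = arith.mulf %in, %filter : f32
+//     %1 = arith.addf %acc, %0 : f32
+//     linalg.yield %1 : f32
+// whereas a pooling-style payload reduces the lhs directly and leaves the
+// rhs block argument unused:
+//   ^bb0(%in: f32, %shape: f32, %acc: f32):
+//     %0 = arith.addf %acc, %in : f32
+//     linalg.yield %0 : f32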
+static std::optional<ConvOperationKind>
+getConvOperationKind(Operation *reduceOp) {
+ int numBlockArguments =
+ llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+ switch (numBlockArguments) {
+ case 1: {
+ // Will be convolution if feeder is a MulOp.
+ // A strength reduced version of MulOp for i1 type is AndOp which is also
+ // supported. Otherwise, it can be pooling. This strength reduction logic
+ // is in `buildBinaryFn` helper in the Linalg dialect.
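+ // For example, on i1 operands multiplication is equivalent to a logical
+ // and, so the feeder may be `arith.andi %a, %b : i1` instead of
+ // `arith.muli`.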
+ auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+ llvm::IsaPred<BlockArgument>);
+ Operation *feedOp = (*feedValIt).getDefiningOp();
+ if (isCastOfBlockArgument(feedOp)) {
+ return ConvOperationKind::Pool;
+ }
+
+ if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+ (isa<arith::AndIOp>(feedOp) &&
+ feedOp->getResultTypes()[0].isInteger(1))) &&
+ llvm::all_of(feedOp->getOperands(), [](Value v) {
+ if (isa<BlockArgument>(v))
+ return true;
+ if (Operation *op = v.getDefiningOp())
+ return isCastOfBlockArgument(op);
+ return false;
+ }))) {
+ return std::nullopt;
+ }
+
+ return ConvOperationKind::Conv;
+ }
+ case 2:
+ // Must be pooling
+ return ConvOperationKind::Pool;
+ default:
+ return std::nullopt;
+ }
+}
+
+static bool isSupportedPoolKind(vector::CombiningKind kind) {
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ case vector::CombiningKind::MAXNUMF:
+ case vector::CombiningKind::MAXIMUMF:
+ case vector::CombiningKind::MAXSI:
+ case vector::CombiningKind::MAXUI:
+ case vector::CombiningKind::MINNUMF:
+ case vector::CombiningKind::MINIMUMF:
+ case vector::CombiningKind::MINSI:
+ case vector::CombiningKind::MINUI:
+ return true;
+ default:
+ return false;
+ }
+}
+
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+ if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+ return failure();
+
+ auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+ auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+ auto resShaped = convOp.getDpsInitOperand(0)->get();
----------------
jerryyin wrote:
Those are actually used in a few places like:
- Conv1DGenerator::conv()
- Conv1DGenerator::depthwiseConv()
It sure was a decision to store all those fields instead of the linalgOp :-p
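For illustration only, a hypothetical sketch (not code from this PR) of the alternative: cache just the LinalgOp and derive the operands on demand, e.g.

  struct Conv1DGenerator {
    linalg::LinalgOp linalgOp;
    // Derive the shaped operands lazily instead of storing them as fields.
    Value getLhsShaped() { return linalgOp.getDpsInputOperand(0)->get(); }
    Value getRhsShaped() { return linalgOp.getDpsInputOperand(1)->get(); }
    Value getResShaped() { return linalgOp.getDpsInitOperand(0)->get(); }
  };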
https://github.com/llvm/llvm-project/pull/130181