[Mlir-commits] [mlir] [mlir][tosa] Refactor convolution infer return type (PR #178869)
Iliyan Georgiev
llvmlistbot at llvm.org
Tue Feb 10 08:52:32 PST 2026
================
@@ -3435,162 +3435,241 @@ static LogicalResult poolingInferReturnTypes(
return success();
}
-LogicalResult Conv2DOp::inferReturnTypeComponents(
- MLIRContext *context, ::std::optional<Location> location,
- Conv2DOp::Adaptor adaptor,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
+template <typename AdaptorT> // Only forward-declared here; shape logic lives in per-op explicit specializations
+class ConvInferShapeAdaptor; // (e.g. ConvInferShapeAdaptor<Conv2DOp::Adaptor>, which fills output/input-spatial dims).
- int64_t inputWidth = ShapedType::kDynamic;
- int64_t inputHeight = ShapedType::kDynamic;
- int64_t weightWidth = ShapedType::kDynamic;
- int64_t weightHeight = ShapedType::kDynamic;
-
- // Input shape describes input width/height and batch.
+/// Helpers shared by the per-op ConvInferShapeAdaptor specializations.
+class ConvInferShapeAdaptorBase {
+protected:
+  /// Overwrites `current` with `candidate` only while `current` is still
+  /// dynamic (ShapedType::isDynamic); an already-static value is kept.
+  /// If `candidate` is itself dynamic, `current` simply stays dynamic.
+  static void updateIfDynamic(int64_t &current, int64_t candidate) {
+    if (ShapedType::isDynamic(current))
+      current = candidate;
+  }
+};
- ShapeAdaptor inputShape(adaptor.getInput().getType());
- if (inputShape.hasRank()) {
+template <>
+class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
+ : public ConvInferShapeAdaptorBase {
+public:
+ explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
+ : adaptor(adaptor) {}
+
+ void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &inputSpatial) {
+ const ShapeAdaptor inputShape(adaptor.getInput().getType());
+ if (!inputShape.hasRank())
+ return;
outputShape[0] = inputShape.getDimSize(0);
- inputHeight = inputShape.getDimSize(1);
- inputWidth = inputShape.getDimSize(2);
+ inputSpatial[0] = inputShape.getDimSize(1);
+ inputSpatial[1] = inputShape.getDimSize(2);
----------------
iliyan-georgiev-arm wrote:
nit:
It could be worth leaving a comment here and in the cases below, so it's clearer at a glance why and how the indices correlate.
https://github.com/llvm/llvm-project/pull/178869
More information about the Mlir-commits
mailing list