[Mlir-commits] [mlir] [mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator (PR #172294)
Luke Hutton
llvmlistbot at llvm.org
Thu Jan 15 10:29:49 PST 2026
================
@@ -3493,6 +3501,232 @@ LogicalResult Conv2DOp::verify() {
return success();
}
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ Conv2DBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+ int64_t inputWidth = ShapedType::kDynamic;
+ int64_t inputHeight = ShapedType::kDynamic;
+ int64_t weightWidth = ShapedType::kDynamic;
+ int64_t weightHeight = ShapedType::kDynamic;
+
+ // Input shape describes input width/height and batch.
+ const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+ if (inputDataShape.hasRank()) {
+ outShape[0] = inputDataShape.getDimSize(0);
+ inputHeight = inputDataShape.getDimSize(1);
+ inputWidth = inputDataShape.getDimSize(2);
+ }
+ const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+ if (inputScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0])
+ ? inputScaleShape.getDimSize(0)
+ : outShape[0];
+ inputHeight = ShapedType::isDynamic(inputHeight)
+ ? inputScaleShape.getDimSize(1)
+ : inputHeight;
+ inputWidth = ShapedType::isDynamic(inputWidth)
+ ? inputScaleShape.getDimSize(2)
+ : inputWidth;
+ }
+
+ // Weight shapes describes the filter width/height and the output channels.
+ const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
+ if (weightDataShape.hasRank()) {
+ outShape[3] = weightDataShape.getDimSize(0);
+ weightHeight = weightDataShape.getDimSize(1);
+ weightWidth = weightDataShape.getDimSize(2);
+ }
+ const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
+ if (weightScaleShape.hasRank()) {
+ outShape[3] = ShapedType::isDynamic(outShape[3])
+ ? weightScaleShape.getDimSize(0)
+ : outShape[3];
+ weightHeight = ShapedType::isDynamic(weightHeight)
+ ? weightScaleShape.getDimSize(1)
+ : weightHeight;
+ weightWidth = ShapedType::isDynamic(weightWidth)
+ ? weightScaleShape.getDimSize(2)
+ : weightWidth;
+ }
+
+ // Bias shape can describe the output channels.
+ const ShapeAdaptor biasShape(adaptor.getBias().getType());
+ if (biasShape.hasRank()) {
+ const int64_t biasSize = biasShape.getDimSize(0);
+ // Bias of size 1 may be broadcast
+ if (biasSize != 1) {
+ outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
+ }
+ }
+
+ SmallVector<int64_t> padValues;
+ SmallVector<int64_t> strideValues;
+ SmallVector<int64_t> dilationValues;
+ if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
+ !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+ strideValues) ||
+ !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+ dilationValues)) {
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ return success();
+ }
+
+ if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
+ const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
+ const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
+ const int64_t unstridedResult = inputSize - filterSize + 1;
+ outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
+ }
+
+ if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
+ const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
+ const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
+ const int64_t unstridedResult = inputSize - filterSize + 1;
+ outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
+ }
+
----------------
lhutton1 wrote:
Conv2DBlockScaled is slightly different due to the `input_scale` and `weight_scale` parameters, but I agree they could be consolidated. Would you mind if I did this in a follow-up patch? There might be some overlap with Conv3D, etc., as well.
https://github.com/llvm/llvm-project/pull/172294
More information about the Mlir-commits
mailing list