[Mlir-commits] [mlir] [mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator (PR #172294)

Luke Hutton llvmlistbot at llvm.org
Thu Jan 15 10:29:49 PST 2026


================
@@ -3493,6 +3501,232 @@ LogicalResult Conv2DOp::verify() {
   return success();
 }
 
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+  int64_t inputWidth = ShapedType::kDynamic;
+  int64_t inputHeight = ShapedType::kDynamic;
+  int64_t weightWidth = ShapedType::kDynamic;
+  int64_t weightHeight = ShapedType::kDynamic;
+
+  // Input shape describes input width/height and batch.
+  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    outShape[0] = inputDataShape.getDimSize(0);
+    inputHeight = inputDataShape.getDimSize(1);
+    inputWidth = inputDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0])
+                      ? inputScaleShape.getDimSize(0)
+                      : outShape[0];
+    inputHeight = ShapedType::isDynamic(inputHeight)
+                      ? inputScaleShape.getDimSize(1)
+                      : inputHeight;
+    inputWidth = ShapedType::isDynamic(inputWidth)
+                     ? inputScaleShape.getDimSize(2)
+                     : inputWidth;
+  }
+
+  // Weight shapes describes the filter width/height and the output channels.
+  const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
+  if (weightDataShape.hasRank()) {
+    outShape[3] = weightDataShape.getDimSize(0);
+    weightHeight = weightDataShape.getDimSize(1);
+    weightWidth = weightDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
+  if (weightScaleShape.hasRank()) {
+    outShape[3] = ShapedType::isDynamic(outShape[3])
+                      ? weightScaleShape.getDimSize(0)
+                      : outShape[3];
+    weightHeight = ShapedType::isDynamic(weightHeight)
+                       ? weightScaleShape.getDimSize(1)
+                       : weightHeight;
+    weightWidth = ShapedType::isDynamic(weightWidth)
+                      ? weightScaleShape.getDimSize(2)
+                      : weightWidth;
+  }
+
+  // Bias shape can describe the output channels.
+  const ShapeAdaptor biasShape(adaptor.getBias().getType());
+  if (biasShape.hasRank()) {
+    const int64_t biasSize = biasShape.getDimSize(0);
+    // Bias of size 1 may be broadcast
+    if (biasSize != 1) {
+      outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
+    }
+  }
+
+  SmallVector<int64_t> padValues;
+  SmallVector<int64_t> strideValues;
+  SmallVector<int64_t> dilationValues;
+  if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
+      !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                 strideValues) ||
+      !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+                                 dilationValues)) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+    return success();
+  }
+
+  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
+    const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
+    const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
+    const int64_t unstridedResult = inputSize - filterSize + 1;
+    outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
+  }
+
+  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
+    const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
+    const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
+    const int64_t unstridedResult = inputSize - filterSize + 1;
+    outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
+  }
+
----------------
lhutton1 wrote:

Conv2DBlockScaled is slightly different due to its `input_scale` and `weight_scale` parameters, but I agree they could be consolidated. Would you mind if I did this in a follow-up patch? There might be some overlap with Conv3D, etc., as well.

https://github.com/llvm/llvm-project/pull/172294


More information about the Mlir-commits mailing list