[Mlir-commits] [mlir] [mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator (PR #172294)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 15 05:29:57 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit adds support for CONV2D_BLOCK_SCALED, an MXFP CONV2D operation that was added to the TOSA specification in https://github.com/arm/tosa-specification/commit/408a5e53f5a7357adef7121ba3cc88e2225d4231.

This includes:
- Operator definition
- Addition of the EXT_MXFP_CONV extension
- Verification logic for the operator
- Output shape inference for the operator
- Validation checks to ensure compliance with the TOSA specification

---

Patch is 70.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172294.diff


17 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+7) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+5-3) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+10) 
- (modified) mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp (+1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+309-79) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+13) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+52-2) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+12-1) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+13-2) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+71) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+20) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+13-3) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+48) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+12-1) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+132) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..e452723f193b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -512,6 +512,13 @@ extensionComplianceMap = {
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
          SpecificationVersion::V_1_0}}}}},
+    {"tosa.conv2d_block_scaled",
+     {{{Extension::mxfp_conv},
+       {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.conv3d",
      {{{Extension::int4},
        {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..421abc939b2e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 // DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 // MXFP         : Microscaling formats.
+// MXFP_CONV    : Microscaling format convolution.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 def Tosa_EXT_MXFP         : I32EnumAttrCase<"mxfp", 12>;
 def Tosa_EXT_INT64        : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_MXFP_CONV    : I32EnumAttrCase<"mxfp_conv", 14>;
 
 
 def Tosa_ExtensionAttr
@@ -281,16 +283,16 @@ def Tosa_ExtensionAttr
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
     ]> {
   let extraClassDeclaration = [{
-    static llvm::SmallVector<Extension, 13> getAllValues() {
+    static llvm::SmallVector<Extension, 14> getAllValues() {
       return {
         Extension::int16, Extension::int4, Extension::bf16,
         Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
         Extension::variable, Extension::controlflow, Extension::doubleround,
         Extension::inexactround, Extension::dynamic, Extension::mxfp,
-        Extension::int64
+        Extension::int64, Extension::mxfp_conv
       };
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 370ce8c161d0b..edd8f0fc266bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: conv2d_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
+  let summary = "Performs two dimensional convolution using block scaled tensors.";
+
+  let description = [{
+    Performs a 2D convolution over the given input data and scales, using
+    the weight data and scales. Implementations may choose to skip calculation
+    of multiplies in the padding area.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensor4D:$input_data,
+    Tosa_MXFPScaleTensor4D:$input_scale,
+    Tosa_MXFPDataTensor4D:$weight_data,
+    Tosa_MXFPScaleTensor4D:$weight_scale,
+    Tosa_Tensor1D:$bias,
+    Rank4TosaShape:$pad,
+    Rank2TosaShape:$stride,
+    Rank2TosaShape:$dilation,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_MXFP_CONV]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..5c77bd701e416 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
     case Extension::fp8e5m2:
     case Extension::fft:
     case Extension::mxfp:
+    case Extension::mxfp_conv:
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..0468ca29e10ac 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
 def Tosa_TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
+def Tosa_IndexTensor1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
 def Tosa_IndexTensor2D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
 
@@ -216,6 +218,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
   TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
 ]>;
+def Tosa_MXFPDataTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
+]>;
+def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
+]>;
 def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
   TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..2e0a0d85d7dbe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
     return TosaSpecificationVersion(1, 0);
   case Extension::mxfp:
   case Extension::int64:
+  case Extension::mxfp_conv:
     return TosaSpecificationVersion(1, 1);
   case Extension::none:
     return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..6382c28ed4312 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -550,6 +550,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
   printWithEnumHandling(parser, *this);
 }
 
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -612,6 +621,55 @@ unsigned mlir::tosa::getBitWidth(Type type) {
   return type.getIntOrFloatBitWidth();
 }
 
+// Update dim size if current dim is dynamic, otherwise raise an error if sizes
+// do not match
+LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+                                    const int64_t newDim,
+                                    const StringRef operandName,
+                                    const StringRef dimName) {
+  if (ShapedType::isDynamic(currDim)) {
+    currDim = newDim;
+    return success();
+  } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+    return op->emitOpError("expected ")
+           << dimName << " of " << operandName << " to match size " << currDim
+           << ", got " << newDim;
+  }
+  return success();
+}
+
+LogicalResult verifyConvOutputSize(
+    Operation *op, const int64_t inputSize, const int64_t kernelSize,
+    const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+    const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+    const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+    const llvm::StringRef padAfterName) {
+  if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
+    return success();
+
+  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+      stride);
+  if (!calculatedOutSizeMinusOne.has_value())
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+           << ") / " << stride;
+
+  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+    return op->emitOpError("calculated output ")
+           << dimName << " did not match expected: "
+           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -791,53 +849,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
       llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [&op](const int64_t inputSize, const int64_t kernelSize,
-              const int64_t outputSize, const int64_t padBefore,
-              const int64_t padAfter, const int64_t stride,
-              const int64_t dilation, const llvm::StringRef dimName,
-              const llvm::StringRef dimAxis,
-              const llvm::StringRef padBeforeName,
-              const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return op.emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return op.emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
     // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -845,14 +866,14 @@ static LogicalResult verifyConvOpErrorIf(T op) {
 
     // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(0),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(0),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(1),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -860,20 +881,20 @@ static LogicalResult verifyConvOpErrorIf(T op) {
 
     // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "depth", "d", "front", "back")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(3), weightType.getDimSize(3),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(3), weightType.getDimSize(3),
               outputType.getDimSize(3), padding[4], padding[5], strides[2],
               dilations[2], "width", "x", "left", "right")))
         return failure();
@@ -1954,20 +1975,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
                                     "B_data")))
     return failure();
 
-  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
-                                   const StringRef operandName,
-                                   const StringRef dimName) -> LogicalResult {
-    if (ShapedType::isDynamic(currDim)) {
-      currDim = newDim;
-      return success();
-    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
-      return emitOpError("expected ")
-             << dimName << " of " << operandName << " to match size " << currDim
-             << ", got " << newDim;
-    }
-    return success();
-  };
-
   // Verify input shape compatibility
   int64_t N = ShapedType::kDynamic;
   int64_t D = ShapedType::kDynamic;
@@ -1985,32 +1992,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
   if (aScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
-                                     "height")))
+    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+                                     "a_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+                                     "a_scale", "height")))
       return failure();
     multiplesOfC = aScaleShape.getDimSize(2);
   }
 
   const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
   if (bDataShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
-                                     "channels")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+                                     "b_data", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+                                     "b_data", "channels")))
       return failure();
     W = bDataShape.getDimSize(1);
   }
 
   const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
   if (bScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
-                                     "width")) ||
-        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
-                                     "b_scale", "C/block_size")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+                                     "b_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+                                     "b_scale", "width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+                                     bScaleShape.getDimSize(2), "b_scale",
+                                     "C/block_size")))
       return failure();
   }
 
@@ -3485,6 +3493,228 @@ LogicalResult Conv2DOp::verify() {
   return success();
 }
 
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+  int64_t inputWidth = ShapedType::kDynamic;
+  int64_t inputHeight = ShapedType::kDynamic;
+  int64_t weightWidth = ShapedType::kDynamic;
+  int64_t weightHeight = ShapedType::kDynamic;
+
+  // Input shape describes input width/height and batch.
+  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    outShape[0] = inputDataShape.getDimSize(0);
+    inputHeight = inputDataShape.getDimSize(1);
+    inputWidth = inputDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0])
+                      ? inputScaleShape.getDimSize(0)
+                      : outShape[0];
+    inputHeight = ShapedType::isDynamic(inputHeight)
+                      ? ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172294


More information about the Mlir-commits mailing list