[Mlir-commits] [mlir] [mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator (PR #172294)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 15 05:29:57 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit adds support for CONV2D_BLOCK_SCALED, a 2D convolution operator for MXFP (microscaling) formats that was added to the TOSA specification in https://github.com/arm/tosa-specification/commit/408a5e53f5a7357adef7121ba3cc88e2225d4231.
This includes:
- Operator definition
- Addition of the EXT_MXFP_CONV extension
- Verification logic for the operator
- Output shape inference for the operator
- Validation checks to ensure compliance with the TOSA specification
---
Patch is 70.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172294.diff
17 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+7)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+5-3)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+10)
- (modified) mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp (+1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+309-79)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+13)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+52-2)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+12-1)
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+13-2)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+71)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+20)
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+13-3)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+48)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+12-1)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+132)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..e452723f193b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -512,6 +512,13 @@ extensionComplianceMap = {
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
+ {"tosa.conv2d_block_scaled",
+ {{{Extension::mxfp_conv},
+ {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.conv3d",
{{{Extension::int4},
{{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..421abc939b2e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
// DYNAMIC : Removes all Compile Time Constant state for CTC inputs.
// MXFP : Microscaling formats.
+// MXFP_CONV : Microscaling format convolution.
//===----------------------------------------------------------------------===//
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
def Tosa_EXT_INT64 : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_MXFP_CONV : I32EnumAttrCase<"mxfp_conv", 14>;
def Tosa_ExtensionAttr
@@ -281,16 +283,16 @@ def Tosa_ExtensionAttr
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
- Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+ Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
]> {
let extraClassDeclaration = [{
- static llvm::SmallVector<Extension, 13> getAllValues() {
+ static llvm::SmallVector<Extension, 14> getAllValues() {
return {
Extension::int16, Extension::int4, Extension::bf16,
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
Extension::variable, Extension::controlflow, Extension::doubleround,
Extension::inexactround, Extension::dynamic, Extension::mxfp,
- Extension::int64
+ Extension::int64, Extension::mxfp_conv
};
}
}];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 370ce8c161d0b..edd8f0fc266bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: conv2d_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
+ let summary = "Performs two dimensional convolution using block scaled tensors.";
+
+ let description = [{
+ Performs a 2D convolution over the given input data and scales, using
+ the weight data and scales. Implementations may choose to skip calculation
+ of multiplies in the padding area.
+ }];
+
+ let arguments = (ins
+ Tosa_MXFPDataTensor4D:$input_data,
+ Tosa_MXFPScaleTensor4D:$input_scale,
+ Tosa_MXFPDataTensor4D:$weight_data,
+ Tosa_MXFPScaleTensor4D:$weight_scale,
+ Tosa_Tensor1D:$bias,
+ Rank4TosaShape:$pad,
+ Rank2TosaShape:$stride,
+ Rank2TosaShape:$dilation,
+ Tosa_BlockSizeAttr:$block_size
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_MXFP_CONV]>,
+ ];
+
+ let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..5c77bd701e416 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
case Extension::fp8e5m2:
case Extension::fft:
case Extension::mxfp:
+ case Extension::mxfp_conv:
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..0468ca29e10ac 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
def Tosa_TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+def Tosa_IndexTensor1D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
def Tosa_IndexTensor2D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
@@ -216,6 +218,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
]>;
+def Tosa_MXFPDataTensor4D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+ TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
+]>;
+def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+ TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
+]>;
def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..2e0a0d85d7dbe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
return TosaSpecificationVersion(1, 0);
case Extension::mxfp:
case Extension::int64:
+ case Extension::mxfp_conv:
return TosaSpecificationVersion(1, 1);
case Extension::none:
return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..6382c28ed4312 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -550,6 +550,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -612,6 +621,55 @@ unsigned mlir::tosa::getBitWidth(Type type) {
return type.getIntOrFloatBitWidth();
}
+// Update dim size if current dim is dynamic, otherwise raise an error if sizes
+// do not match
+LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+ const int64_t newDim,
+ const StringRef operandName,
+ const StringRef dimName) {
+ if (ShapedType::isDynamic(currDim)) {
+ currDim = newDim;
+ return success();
+ } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+ return op->emitOpError("expected ")
+ << dimName << " of " << operandName << " to match size " << currDim
+ << ", got " << newDim;
+ }
+ return success();
+}
+
+LogicalResult verifyConvOutputSize(
+ Operation *op, const int64_t inputSize, const int64_t kernelSize,
+ const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+ const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+ const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+ const llvm::StringRef padAfterName) {
+ if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
+ return success();
+
+ // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+ const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+ inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+ stride);
+ if (!calculatedOutSizeMinusOne.has_value())
+ return op->emitOpError("expected input_")
+ << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+ << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+ << dimAxis << " to be wholly divisible by stride_" << dimAxis
+ << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+ << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+ << ") / " << stride;
+
+ const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+ if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+ return op->emitOpError("calculated output ")
+ << dimName << " did not match expected: "
+ << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
@@ -791,53 +849,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
if (inputType && weightType) {
- const auto verifyOutputSize =
- [&op](const int64_t inputSize, const int64_t kernelSize,
- const int64_t outputSize, const int64_t padBefore,
- const int64_t padAfter, const int64_t stride,
- const int64_t dilation, const llvm::StringRef dimName,
- const llvm::StringRef dimAxis,
- const llvm::StringRef padBeforeName,
- const llvm::StringRef padAfterName) -> LogicalResult {
- if (inputSize == ShapedType::kDynamic ||
- kernelSize == ShapedType::kDynamic)
- return success();
-
- // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
- const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
- inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
- stride);
- if (!calculatedOutSizeMinusOne.has_value())
- return op.emitOpError("expected input_")
- << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
- << padAfterName << " - (kernel_" << dimName
- << " - 1) * dilation_" << dimAxis
- << " to be wholly divisible by stride_" << dimAxis << ", got ("
- << inputSize << " - 1 + " << padBefore << " + " << padAfter
- << " - (" << kernelSize << " - 1) * " << dilation << ") / "
- << stride;
-
- const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
- if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
- return op.emitOpError("calculated output ")
- << dimName << " did not match expected: "
- << "calculated=" << calculatedOutSize
- << ", expected=" << outputSize;
-
- return success();
- };
-
// input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(2),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
@@ -845,14 +866,14 @@ static LogicalResult verifyConvOpErrorIf(T op) {
// input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(0),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(0),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(1),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
@@ -860,20 +881,20 @@ static LogicalResult verifyConvOpErrorIf(T op) {
// input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "depth", "d", "front", "back")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(2),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(3), weightType.getDimSize(3),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(3), weightType.getDimSize(3),
outputType.getDimSize(3), padding[4], padding[5], strides[2],
dilations[2], "width", "x", "left", "right")))
return failure();
@@ -1954,20 +1975,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
"B_data")))
return failure();
- auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
- const StringRef operandName,
- const StringRef dimName) -> LogicalResult {
- if (ShapedType::isDynamic(currDim)) {
- currDim = newDim;
- return success();
- } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
- return emitOpError("expected ")
- << dimName << " of " << operandName << " to match size " << currDim
- << ", got " << newDim;
- }
- return success();
- };
-
// Verify input shape compatibility
int64_t N = ShapedType::kDynamic;
int64_t D = ShapedType::kDynamic;
@@ -1985,32 +1992,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
if (aScaleShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
- "batch")) ||
- failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
- "height")))
+ if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+ "a_scale", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+ "a_scale", "height")))
return failure();
multiplesOfC = aScaleShape.getDimSize(2);
}
const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
if (bDataShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
- "batch")) ||
- failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
- "channels")))
+ if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+ "b_data", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+ "b_data", "channels")))
return failure();
W = bDataShape.getDimSize(1);
}
const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
if (bScaleShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
- "batch")) ||
- failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
- "width")) ||
- failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
- "b_scale", "C/block_size")))
+ if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+ "b_scale", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+ "b_scale", "width")) ||
+ failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+ bScaleShape.getDimSize(2), "b_scale",
+ "C/block_size")))
return failure();
}
@@ -3485,6 +3493,228 @@ LogicalResult Conv2DOp::verify() {
return success();
}
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ Conv2DBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+ int64_t inputWidth = ShapedType::kDynamic;
+ int64_t inputHeight = ShapedType::kDynamic;
+ int64_t weightWidth = ShapedType::kDynamic;
+ int64_t weightHeight = ShapedType::kDynamic;
+
+ // Input shape describes input width/height and batch.
+ const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+ if (inputDataShape.hasRank()) {
+ outShape[0] = inputDataShape.getDimSize(0);
+ inputHeight = inputDataShape.getDimSize(1);
+ inputWidth = inputDataShape.getDimSize(2);
+ }
+ const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+ if (inputScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0])
+ ? inputScaleShape.getDimSize(0)
+ : outShape[0];
+ inputHeight = ShapedType::isDynamic(inputHeight)
+ ? ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/172294
More information about the Mlir-commits
mailing list