[Mlir-commits] [mlir] [mlir][tosa] `StringBasedAttr` TOSA enumerations to `Tosa_I32EnumAttr` (PR #152856)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 9 06:37:10 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Annu Singh (AnnuCode)
<details>
<summary>Changes</summary>
Fixes #<!-- -->152129
---
Patch is 157.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152856.diff
29 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+30-1)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+7-7)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-23)
- (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+8-6)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-18)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+9-7)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+8-3)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3-3)
- (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir (+1-1)
- (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir (+4-4)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+1-1)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+4-4)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+16-16)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+42-42)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+14-14)
- (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+15-15)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+28-28)
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+7-7)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+8-8)
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1)
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir (+4-4)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+10-10)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-valid.mlir (+2-2)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+1-1)
- (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+3-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..2aafed26a4e29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -207,7 +207,7 @@ def Tosa_VariableOpBuilder : OpBuilder<
}]>;
// Wrapper over base I32EnumAttr to set common fields.
class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
@@ -276,6 +276,7 @@ def Tosa_ProfileAttr
def Tosa_ProfileArrayAttr
: TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -381,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification. The integer
+// value given to each case below mirrors the value assigned to it by the
+// specification, so the numbering must not be changed independently of it.
+
+// Supported regimes for tosa.resize.
+def Tosa_RESIZE_NEAREST_NEIGHBOR : I32EnumAttrCase<"NEAREST_NEIGHBOR", 1>;
+def Tosa_RESIZE_BILINEAR : I32EnumAttrCase<"BILINEAR", 2>;
+
+def Tosa_ResizeTypeAttr
+    : Tosa_I32EnumAttr<"ResizeType",
+                       "Supported resize/upsampling strategies", "resize_type",
+                       [Tosa_RESIZE_NEAREST_NEIGHBOR, Tosa_RESIZE_BILINEAR]>;
+
+// Supported NaN propagation strategies (the `nan_mode` attribute of argmax,
+// max_pool2d, clamp, maximum, minimum, reduce_max and reduce_min).
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationAttr
+    : Tosa_I32EnumAttr<"NanPropagation",
+                       "Supported NaN propagation strategies", "nan",
+                       [Tosa_NANPROPAGATION_PROPAGATE,
+                        Tosa_NANPROPAGATION_IGNORE]>;
+
+// Rounding modes for tosa.rescale and tosa.apply_scale.
+def Tosa_ROUNDING_SINGLE_ROUND : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingTypeAttr
+    : Tosa_I32EnumAttr<"RoundingType", "Supported rounding modes",
+                       "rounding_type",
+                       [Tosa_ROUNDING_SINGLE_ROUND,
+                        Tosa_ROUNDING_INEXACT_ROUND,
+                        Tosa_ROUNDING_DOUBLE_ROUND]>;
+
//===----------------------------------------------------------------------===//
// TOSA Interfaces.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..fdb8f472dc060 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let arguments = (ins
Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$pad,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
Tosa_Tensor:$input,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Tensor:$input2,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
let arguments = (ins
Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
- DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640dca6561..b21ce51eb03b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,11 +13,13 @@
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE
+
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
@@ -234,29 +236,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
- "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
- "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
- CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
- "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
- "Supported rounding modes">;
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725c7d805..4a027ccdadd61 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingType roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingType::DOUBLE_ROUND &&
+ roundingMode != RoundingType::SINGLE_ROUND) {
return failure();
}
@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
- StringRef roundingMode = op.getRoundingMode();
- if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ RoundingType roundingMode = op.getRoundingMode();
+ if (roundingMode != RoundingType::DOUBLE_ROUND &&
+ roundingMode != RoundingType::SINGLE_ROUND) {
return failure();
}
@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getRoundingMode() == "DOUBLE_ROUND") {
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..35deeb43f51c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
return result;
auto nanMode = op.getNanMode();
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagation::PROPAGATE)
return result;
// Unordered comparison of NaN against itself will always return true.
@@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
- auto result = tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getStringAttr("SINGLE_ROUND"));
+ auto roundingAttr = RoundingTypeAttr::get(rewriter.getContext(),
+ RoundingType::SINGLE_ROUND);
+ auto result =
+ tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+ b, shiftConst, roundingAttr);
if (elementTy.isInteger(32))
return result;
@@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// In the case of "PROPAGATE" semantics no compare and selection is
// required.
- if (nanMode == "PROPAGATE")
+ if (nanMode == NanPropagation::PROPAGATE)
return result;
// In the case of "IGNORE" semantics materialize a comparison
@@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
// NaN propagation has no meaning for non floating point types.
- if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+ if (isa<FloatType>(elementTy) &&
+ op.getNanMode() == NanPropagation::IGNORE) {
isNanIgnoreMode = true;
// Because the TOSA spec requires the result be NaN iff all elements in
// the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getRoundingMode() == "INEXACT_ROUND")
+ if (op.getRoundingMode() == RoundingType::INEXACT_ROUND)
return rewriter.notifyMatchFailure(
op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
"currently supported");
- if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+ if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// is ever true.
bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
+ op.getRoundingMode() == RoundingType::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ RoundingType roundingMode =
+ doubleRound ? RoundingType::DOUBLE_ROUND : RoundingType::SINGLE_ROUND;
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
auto input = op.getInput();
auto inputTy = cast<RankedTensorType>(input.getType());
auto resultTy = cast<RankedTensorType>(op.getType());
- const bool isBilinear = op.getMode() == "BILINEAR";
+ const bool isBilinear = op.getMode() == ResizeType::BILINEAR;
auto inputH = inputTy.getDimSize(1);
auto inputW = inputTy.getDimSize(2);
@@ -1585,7 +1587,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
op, "tosa.resize is not a pure 1x1->1x1 image operation");
// TODO(suderman): These string values should be declared the TOSA dialect.
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeType::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1785,7 +1788,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "unable to get dynamic dimensions of tosa.resize");
- if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+ if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+ op.getMode() != ResizeType::BILINEAR)
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
@@ -1890,7 +1894,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
}
- if (op.getMode() == "NEAREST_NEIGHBOR") {
+ if (op.getMode() == ResizeType::NEAREST_NEIGHBOR) {
auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1930,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
linalg::YieldOp::create(b, result);
} else {
// The mode here must be BILINEAR.
- assert(op.getMode() == "BILINEAR");
+ assert(op.getMode() == ResizeType::BILINEAR);
auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
@@ -2291,7 +2295,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
Value predicate;
if (isa<FloatType>(inElementTy)) {
- if (argmaxOp.getNanMode() == "IGNORE") {
+ if (argmaxOp.getNanMode() == NanPropagation::IGNORE) {
// Only update index & max value for non NaN values. If all
// values are NaNs, the initial index will be return which is 0.
predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca3768dd..0f738a848fcb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
dilationAttr);
rewriter.setInsertionPointAfter(op);
- StringRef nanMode = op.getNanMode();
+ NanPropagation nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
// NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
// we've already produced a named op we will just take its body and modify
// it to include the appropriate checks. If the current value is NaN the
// old value of pool will be taken otherwise we use the result.
- if (nanMode == "IGNORE") {
+ if (nanMode == NanPropagation::IGNORE) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
- auto scaled =
- tosa::ApplyScaleOp::create(
- rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
- shift, rewriter.getStringAttr("SINGLE_ROUND"))
- .getResult();
+ auto roundingAttr = RoundingTypeAttr::get(
+ rewriter.getContext(), RoundingType::SINGLE_ROUND);
+
+ auto scaled = tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal,
+ multiplier, shift, roundingAttr)
+ .getResult();
// If we have quantization information we need to apply output
// zeropoint.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..8e27b267c83d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
// Check we have a valid NaN propagation combination.
const auto opNanMode = op.getNanMode();
const auto clampNanMode = clampOp.getNanMode();
- if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ if (opNanMode == NanPropagation::IGNORE &&
+ clampNanMode == NanPropagation::PROPAGATE)
return failure();
auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +637,14 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
}
}
+ auto newMode =
+ (opNanMode != clampNanMode) ? tosa::NanPropagation::IGNORE : opNanMode;
+
+ auto newModeAttr = NanPropagationAttr::get(rewriter.getContext(), newMode);
+
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
- rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
- : opNanMode));
+ newModeAttr);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..5c04874e494c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,13 +508,13...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/152856
More information about the Mlir-commits
mailing list