[Mlir-commits] [mlir] [mlir][tosa] Convert `StringBasedAttr` TOSA enumerations to `Tosa_I32EnumAttr` (PR #152856)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 9 06:37:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Annu Singh (AnnuCode)

<details>
<summary>Changes</summary>

Fixes #<!-- -->152129 

---

Patch is 157.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152856.diff


29 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+30-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+7-7) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-23) 
- (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+8-6) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-18) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+9-7) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+8-3) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+3-3) 
- (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir (+1-1) 
- (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir (+4-4) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+1-1) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+4-4) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir (+16-16) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+42-42) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+14-14) 
- (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+4-4) 
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+15-15) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+28-28) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+7-7) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+4-4) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+8-8) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir (+4-4) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+10-10) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-valid.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+1-1) 
- (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+3-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index e048f8af7cc33..2aafed26a4e29 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -207,7 +207,7 @@ def Tosa_VariableOpBuilder : OpBuilder<
   }]>;
 
 
-// Wrapper over base I32EnumAttr to set common fields.
+ // Wrapper over base I32EnumAttr to set common fields.
 class Tosa_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
      : I32EnumAttr<name, description, cases> {
    let genSpecializedAttr = 0;
@@ -276,6 +276,7 @@ def Tosa_ProfileAttr
 def Tosa_ProfileArrayAttr
     : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
 
+
 // The base class for defining op availability dimensions.
 class Availability {
   // The following are fields for controlling the generated C++ OpInterface.
@@ -381,6 +382,34 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
   let instance = "ref";
 }
 
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
+def Tosa_RESIZE_BILINEAR          : I32EnumAttrCase<"BILINEAR", 1>;
+def Tosa_RESIZE_NEAREST_NEIGHBOR  : I32EnumAttrCase<"NEAREST_NEIGHBOR", 2>;
+
+def Tosa_ResizeTypeAttr
+    : Tosa_I32EnumAttr<"ResizeType", "Supported resize/upsampling strategies", "resize_type",
+                    [Tosa_RESIZE_BILINEAR, Tosa_RESIZE_NEAREST_NEIGHBOR]>;       
+
+def Tosa_NANPROPAGATION_PROPAGATE : I32EnumAttrCase<"PROPAGATE", 1>;
+def Tosa_NANPROPAGATION_IGNORE    : I32EnumAttrCase<"IGNORE", 2>;
+
+def Tosa_NanPropagationAttr
+    : Tosa_I32EnumAttr<"NanPropagation", "Supported NaN propagation strategies", "nan",
+                    [Tosa_NANPROPAGATION_PROPAGATE, Tosa_NANPROPAGATION_IGNORE]>;  
+
+def Tosa_ROUNDING_SINGLE_ROUND    : I32EnumAttrCase<"SINGLE_ROUND", 1>;
+def Tosa_ROUNDING_INEXACT_ROUND   : I32EnumAttrCase<"INEXACT_ROUND", 2>;
+def Tosa_ROUNDING_DOUBLE_ROUND    : I32EnumAttrCase<"DOUBLE_ROUND", 3>;
+
+def Tosa_RoundingTypeAttr
+    : Tosa_I32EnumAttr<"RoundingType", "Supported rounding modes", "rounding_type",
+                    [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;                         
+
+
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..fdb8f472dc060 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -43,7 +43,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -357,7 +357,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -487,7 +487,7 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
     Tosa_Tensor:$input,
     Tosa_IntOrFloatAttr:$min_val,
     Tosa_IntOrFloatAttr:$max_val,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -935,7 +935,7 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -964,7 +964,7 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -1711,7 +1711,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
@@ -1751,7 +1751,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   let arguments = (ins
     Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
-    DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
+    DefaultValuedAttr<Tosa_NanPropagationAttr, "::mlir::tosa::NanPropagation::PROPAGATE">:$nan_mode
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 754640dca6561..b21ce51eb03b1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,11 +13,13 @@
 #ifndef TOSA_TYPES_BASE
 #define TOSA_TYPES_BASE
 
+
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
 
 include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
 
+
 //===----------------------------------------------------------------------===//
 // Tosa Type Definitions.
 //===----------------------------------------------------------------------===//
@@ -234,29 +236,6 @@ def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
 
 def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;
 
-//===----------------------------------------------------------------------===//
-// Iterable attributes.
-//===----------------------------------------------------------------------===//
-// Defined in `section 3. Enumerations` of the TOSA specification.
-
-// Supported regimes for tosa.resize.
-def Tosa_ResizeTypeAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
-    "Supported resize/upsampling strategies">;
-
-// Supported NaN propagation strategies.
-def Tosa_NanPropagationAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
-    "Supported NaN propagation strategies">;
-
-// Rounding mode for tosa.rescale
-def Tosa_RoundingTypeAttr : StringBasedAttr<
-    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\"  || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
-          "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
-    "Supported rounding modes">;
 
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 044b725c7d805..4a027ccdadd61 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -64,8 +64,9 @@ class ApplyScaleGenericOpConverter
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
-    StringRef roundingMode = op.getRoundingMode();
-    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+    RoundingType roundingMode = op.getRoundingMode();
+    if (roundingMode != RoundingType::DOUBLE_ROUND &&
+        roundingMode != RoundingType::SINGLE_ROUND) {
       return failure();
     }
 
@@ -100,7 +101,7 @@ class ApplyScaleGenericOpConverter
     multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round);
 
     // Apply double rounding if necessary.
-    if (op.getRoundingMode() == "DOUBLE_ROUND") {
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
       int64_t roundInt = 1 << 30;
       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -129,8 +130,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
-    StringRef roundingMode = op.getRoundingMode();
-    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+    RoundingType roundingMode = op.getRoundingMode();
+    if (roundingMode != RoundingType::DOUBLE_ROUND &&
+        roundingMode != RoundingType::SINGLE_ROUND) {
       return failure();
     }
 
@@ -179,7 +181,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
         arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32);
 
     // Conditionally perform our double round.
-    if (op.getRoundingMode() == "DOUBLE_ROUND") {
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND) {
       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
       Value valuePositive = arith::CmpIOp::create(
           rewriter, loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0e3de067736c5..35deeb43f51c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
     return result;
 
   auto nanMode = op.getNanMode();
-  if (nanMode == "PROPAGATE")
+  if (nanMode == NanPropagation::PROPAGATE)
     return result;
 
   // Unordered comparison of NaN against itself will always return true.
@@ -156,9 +156,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         if (!b.getType().isInteger(32))
           b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
 
-        auto result = tosa::ApplyScaleOp::create(
-            rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
-            rewriter.getStringAttr("SINGLE_ROUND"));
+        auto roundingAttr = RoundingTypeAttr::get(rewriter.getContext(),
+                                                  RoundingType::SINGLE_ROUND);
+        auto result =
+            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
+                                       b, shiftConst, roundingAttr);
 
         if (elementTy.isInteger(32))
           return result;
@@ -464,7 +466,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
     // In the case of "PROPAGATE" semantics no compare and selection is
     // required.
-    if (nanMode == "PROPAGATE")
+    if (nanMode == NanPropagation::PROPAGATE)
       return result;
 
     // In the case of "IGNORE" semantics materialize a comparison
@@ -1173,7 +1175,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
     // NaN propagation has no meaning for non floating point types.
-    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
+    if (isa<FloatType>(elementTy) &&
+        op.getNanMode() == NanPropagation::IGNORE) {
       isNanIgnoreMode = true;
       // Because the TOSA spec requires the result be NaN iff all elements in
       // the reduction are NaN we can't simply perform a compare and select.
@@ -1336,11 +1339,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     unsigned rank = inputTy.getRank();
 
     // This is an illegal configuration. terminate and log an error
-    if (op.getRoundingMode() == "INEXACT_ROUND")
+    if (op.getRoundingMode() == RoundingType::INEXACT_ROUND)
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
               "currently supported");
-    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
+    if (op.getRoundingMode() == RoundingType::DOUBLE_ROUND && !op.getScale32())
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
@@ -1386,11 +1389,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // is ever true.
 
     bool doubleRound =
-        op.getRoundingMode() == "DOUBLE_ROUND" &&
+        op.getRoundingMode() == RoundingType::DOUBLE_ROUND &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    RoundingType roundingMode =
+        doubleRound ? RoundingType::DOUBLE_ROUND : RoundingType::SINGLE_ROUND;
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1573,7 +1575,7 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
     auto input = op.getInput();
     auto inputTy = cast<RankedTensorType>(input.getType());
     auto resultTy = cast<RankedTensorType>(op.getType());
-    const bool isBilinear = op.getMode() == "BILINEAR";
+    const bool isBilinear = op.getMode() == ResizeType::BILINEAR;
 
     auto inputH = inputTy.getDimSize(1);
     auto inputW = inputTy.getDimSize(2);
@@ -1585,7 +1587,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
           op, "tosa.resize is not a pure 1x1->1x1 image operation");
 
     // TODO(suderman): These string values should be declared the TOSA dialect.
-    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+    if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+        op.getMode() != ResizeType::BILINEAR)
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
@@ -1785,7 +1788,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       return rewriter.notifyMatchFailure(
           op, "unable to get dynamic dimensions of tosa.resize");
 
-    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
+    if (op.getMode() != ResizeType::NEAREST_NEIGHBOR &&
+        op.getMode() != ResizeType::BILINEAR)
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
@@ -1890,7 +1894,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
       }
 
-      if (op.getMode() == "NEAREST_NEIGHBOR") {
+      if (op.getMode() == ResizeType::NEAREST_NEIGHBOR) {
         auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
 
         auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
@@ -1926,7 +1930,7 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
         linalg::YieldOp::create(b, result);
       } else {
         // The mode here must be BILINEAR.
-        assert(op.getMode() == "BILINEAR");
+        assert(op.getMode() == ResizeType::BILINEAR);
 
         auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));
 
@@ -2291,7 +2295,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
 
           Value predicate;
           if (isa<FloatType>(inElementTy)) {
-            if (argmaxOp.getNanMode() == "IGNORE") {
+            if (argmaxOp.getNanMode() == NanPropagation::IGNORE) {
               // Only update index & max value for non NaN values. If all
               // values are NaNs, the initial index will be return which is 0.
               predicate = arith::CmpFOp::create(rewriter, nestedLoc,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 12d85ca3768dd..0f738a848fcb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -803,7 +803,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
         dilationAttr);
 
     rewriter.setInsertionPointAfter(op);
-    StringRef nanMode = op.getNanMode();
+    NanPropagation nanMode = op.getNanMode();
     rewriter.replaceOp(op, resultOp);
 
     // NaN propagation has no meaning for non floating point types.
@@ -817,7 +817,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
     // we've already produced a named op we will just take its body and modify
     // it to include the appropriate checks. If the current value is NaN the
     // old value of pool will be taken otherwise we use the result.
-    if (nanMode == "IGNORE") {
+    if (nanMode == NanPropagation::IGNORE) {
       auto genericOp = linalg::GenericOp::create(
           rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
           resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
@@ -1040,11 +1040,13 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
                 rewriter, loc, rewriter.getI8IntegerAttr(30));
             Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
 
-            auto scaled =
-                tosa::ApplyScaleOp::create(
-                    rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
-                    shift, rewriter.getStringAttr("SINGLE_ROUND"))
-                    .getResult();
+            auto roundingAttr = RoundingTypeAttr::get(
+                rewriter.getContext(), RoundingType::SINGLE_ROUND);
+
+            auto scaled = tosa::ApplyScaleOp::create(
+                              rewriter, loc, rewriter.getI32Type(), poolVal,
+                              multiplier, shift, roundingAttr)
+                              .getResult();
 
             // If we have quantization information we need to apply output
             // zeropoint.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba38871909..8e27b267c83d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,7 +555,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
     // Check we have a valid NaN propagation combination.
     const auto opNanMode = op.getNanMode();
     const auto clampNanMode = clampOp.getNanMode();
-    if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+    if (opNanMode == NanPropagation::IGNORE &&
+        clampNanMode == NanPropagation::PROPAGATE)
       return failure();
 
     auto maxValAttr = op.getMaxValAttr();
@@ -636,10 +637,14 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
       }
     }
 
+    auto newMode =
+        (opNanMode != clampNanMode) ? tosa::NanPropagation::IGNORE : opNanMode;
+
+    auto newModeAttr = NanPropagationAttr::get(rewriter.getContext(), newMode);
+
     rewriter.replaceOpWithNewOp<tosa::ClampOp>(
         op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr,
-        rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
-                                                           : opNanMode));
+        newModeAttr);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index c7b9534f9e744..5c04874e494c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -508,13 +508,13...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/152856


More information about the Mlir-commits mailing list