[Mlir-commits] [mlir] 8388040 - [mlir][tosa] Add NaN Propagation Mode Support (#121951)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 23 02:14:03 PST 2025
Author: Jack Frankland
Date: 2025-01-23T10:14:00Z
New Revision: 8388040fc9e75d49cd000b3371e2610c6c3548ba
URL: https://github.com/llvm/llvm-project/commit/8388040fc9e75d49cd000b3371e2610c6c3548ba
DIFF: https://github.com/llvm/llvm-project/commit/8388040fc9e75d49cd000b3371e2610c6c3548ba.diff
LOG: [mlir][tosa] Add NaN Propagation Mode Support (#121951)
The TOSA-V1.0 specification adds "NaN propagation" modes as attributes
for several operators. Adjust the ODS definitions of the relevant
operations to include this attribute.
The defined modes are "PROPAGATE" and "IGNORE"; "PROPAGATE" is the
default.
The MAXIMUM, MINIMUM, REDUCE_MAX, REDUCE_MIN, MAX_POOL2D, CLAMP, and ARGMAX
operators support this attribute.
The CLAMP(CLAMP) canonicalization is updated to respect the NaN
propagation modes and to fold only clamps whose ranges intersect.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
Co-authored-by: TatWai Chong <tatwai.chong at arm.com>
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2953e006bbe8d1..92ab729f5b933a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -42,7 +42,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let arguments = (ins
Tosa_Tensor: $input,
- I32Attr: $axis
+ I32Attr: $axis,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
@@ -287,7 +288,8 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
Tosa_IntArrayAttr2:$kernel,
Tosa_IntArrayAttr2:$stride,
- Tosa_IntArrayAttr4:$pad
+ Tosa_IntArrayAttr4:$pad,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
@@ -388,7 +390,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {
I64Attr:$min_int,
I64Attr:$max_int,
Tosa_FloatAttr:$min_fp,
- Tosa_FloatAttr:$max_fp
+ Tosa_FloatAttr:$max_fp,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
@@ -752,7 +755,8 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Tensor:$input2
+ Tosa_Tensor:$input2,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
@@ -775,7 +779,8 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Tensor:$input2
+ Tosa_Tensor:$input2,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
@@ -1382,7 +1387,8 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
let arguments = (ins
Tosa_Tensor:$input,
- I32Attr:$axis
+ I32Attr:$axis,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
@@ -1417,7 +1423,8 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
let arguments = (ins
Tosa_Tensor:$input,
- I32Attr:$axis
+ I32Attr:$axis,
+ DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 13325fb0ab9a20..5693acf3a01db4 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -205,12 +205,20 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
+// Defined in `section 3. Enumerations` of the TOSA specification.
+
// Supported regimes for tosa.resize.
def Tosa_ResizeTypeAttr : StringBasedAttr<
CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"BILINEAR\" || " #
"::llvm::cast<StringAttr>($_self).getValue() == \"NEAREST_NEIGHBOR\"">,
"Supported resize/upsampling strategies">;
+// Supported NaN propagation strategies: a string attribute whose value must
+// be exactly "PROPAGATE" or "IGNORE".
+def Tosa_NanPropagationAttr : StringBasedAttr<
+ CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"PROPAGATE\" || " #
+ "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
+ "Supported NaN propagation strategies">;
+
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// Tensor to buffer types.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f7a596f1ccb192..8b883487d1659b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
};
+// Attempts the following transformation:
+//
+// For clamp bounds a, b, a', and b' (integer or floating-point) whose closed
+// ranges satisfy [a, b] ∩ [a', b'] ≠ ∅, and for any input tensor X:
+//
+// CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
+//
+// subject to the following valid NaN propagation semantics (INVALID means the
+// pattern does not fire and both clamps are kept):
+// |-------------|-------------|-------------|
+// | OUTER CLAMP | INNER CLAMP | RESULT MODE |
+// |-------------|-------------|-------------|
+// | PROPAGATE   | PROPAGATE   | PROPAGATE   |
+// | PROPAGATE   | IGNORE      | IGNORE      |
+// | IGNORE      | PROPAGATE   | INVALID     |
+// | IGNORE      | IGNORE      | IGNORE      |
+// |-------------|-------------|-------------|
+
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
+ // Closed interval [start, end] spanned by the bounds of a clamp operation.
+ template <typename T>
+ struct ClampRange {
+ ClampRange(const T &start, const T &end) : start(start), end(end) {}
+ T start;
+ T end;
+
+ // True if the two closed intervals share a point; touching endpoints count.
+ bool intersects(const ClampRange<T> &otherRange) const {
+ return start <= otherRange.end && otherRange.start <= end;
+ }
+ };
+
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
- Value input = op.getInput();
-
- Operation *definingOp = input.getDefiningOp();
- if (!definingOp)
+ // Check the input to the CLAMP op is itself a CLAMP.
+ auto clampOp =
+ dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
+ if (!clampOp)
return failure();
- if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
- auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
- auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();
+ // An outer IGNORE clamp over an inner PROPAGATE clamp has no single-clamp equivalent.
+ const auto opNanMode = op.getNanMode();
+ const auto clampNanMode = clampOp.getNanMode();
+ if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
+ return failure();
- auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
- auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
+ // Bail out unless both the integer and float ranges intersect; otherwise the merged bounds would have min > max.
+ const auto opMinInt = op.getMinInt();
+ const auto opMaxInt = op.getMaxInt();
+ const auto clampOpMinInt = clampOp.getMinInt();
+ const auto clampOpMaxInt = clampOp.getMaxInt();
+ ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
+ ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
+ if (!opRangeIntRange.intersects(clampRangeIntRange))
+ return failure();
- rewriter.replaceOpWithNewOp<tosa::ClampOp>(
- op, op.getType(), clampOp.getInput(),
- rewriter.getI64IntegerAttr(minInt),
- rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
- rewriter.getF32FloatAttr(maxFp));
- return success();
- }
+ const auto opMinFloat = op.getMinFp();
+ const auto opMaxFloat = op.getMaxFp();
+ const auto clampOpMinFloat = clampOp.getMinFp();
+ const auto clampOpMaxFloat = clampOp.getMaxFp();
+ ClampRange opRangeFloatRange(opMinFloat, opMaxFloat);
+ ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
+ if (!opRangeFloatRange.intersects(clampRangeFloatRange))
+ return failure();
- return failure();
+ // Merge the bounds (float bounds are re-emitted as f32); differing NaN modes merge to IGNORE, equal modes are kept.
+ const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
+ const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
+ const auto minInt = std::max(opMinInt, clampOpMinInt);
+ const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
+ rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+ op, op.getType(), clampOp.getInput(),
+ rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
+ rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
+ rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
+ : opNanMode));
+ return success();
}
};
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e394188e9a9311..6f47f041b9199a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -138,6 +138,58 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// -----
+// CHECK: @disjoint_clamp_twice_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
+func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+ // Both the integer and the float ranges of the two clamps are disjoint (and each range is well-formed), so they must not be merged.
+ // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = -5.000000e+00 : f32, max_int = -5 : i64, min_fp = -1.000000e+01 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+ // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 1.000000e+00 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+ %0 = tosa.clamp %arg0 {max_fp = -5.0 : f32, max_int = -5 : i64, min_fp = -10.0 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 1.0 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp
+func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+ // Inner PROPAGATE, outer PROPAGATE: folds into one clamp over the intersected bounds.
+ // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+ %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp
+func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+ // Inner IGNORE, outer IGNORE: folds into one IGNORE clamp over the intersected bounds.
+ // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
+ %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp
+func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+ // Inner IGNORE, outer PROPAGATE: folds into one clamp whose NaN mode is IGNORE.
+ // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"}
+ %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>)
+func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+ // Inner PROPAGATE, outer IGNORE has no single-clamp equivalent, so both clamps are kept.
+ // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = 3.000000e+00 : f32, max_int = 4 : i64, min_fp = -5.000000e+00 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+ // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8>
+ %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8>
+ return %1 : tensor<4xi8>
+}
+
+// -----
+
// CHECK-LABEL: @concat_fold
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 563c5fa457d351..19b93d7611854d 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -180,6 +180,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: clamp_propagate
+func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ // Exercises tosa.clamp with an explicit nan_mode = "PROPAGATE" attribute.
+ %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp_ignore
+func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ // Exercises tosa.clamp with an explicit nan_mode = "IGNORE" attribute.
+ %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: clamp_f16
func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
More information about the Mlir-commits
mailing list