[Mlir-commits] [mlir] [mlir][tosa] Add a pass to downgrade TOSA `1.1.draft` to `1.0` (PR #194971)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 30 04:51:12 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit adds a pass that will allow 1.1.draft operations to be rewritten to their 1.0 counterparts where possible. The pass currently covers the following operations:
- avg/max_pool2d_adaptive -> avg/max_pool2d when kernel/stride/pad are direct const_shape operands
- bool <-> fp32 casts via i8 bridge casts
- bool gather/scatter with i32 indices via i8 payload rewrites

Note that the downgrade is 'best-effort' and the pass does not perform any validation itself. The validation pass should be run after downgrading to check that the resulting IR was downgraded successfully.

Motivation: This decouples the target specification version used by legalizations from the version consumed by backends. Legalizations from higher-level frameworks may be updated to produce TOSA 1.1.draft variants of operations, while backends can continue to consume TOSA 1.0 IR after running the downgrade pass.

---
Full diff: https://github.com/llvm/llvm-project/pull/194971.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+11) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Tosa/Transforms/TosaDowngrade1p1To1p0.cpp (+217) 
- (added) mlir/test/Dialect/Tosa/tosa-downgrade-1-1-to-1-0.mlir (+161) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 5979ce4962e55..005cbfab782df 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -185,6 +185,17 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
   ];
 }
 
+def TosaDowngrade1p1To1p0Pass
+    : Pass<"tosa-downgrade-1-1-to-1-0", "func::FuncOp"> {
+  let summary = "Downgrade TOSA 1.1 specification constructs to TOSA 1.0";
+  let description = [{
+    Rewrites constructs which are only valid in TOSA specification 1.1 and
+    later to their TOSA 1.0 counterparts where possible. The following
+    rewrites are currently performed:
+
+    - `tosa.avg_pool2d_adaptive` / `tosa.max_pool2d_adaptive` are converted to
+      `tosa.avg_pool2d` / `tosa.max_pool2d` when the kernel, stride and pad
+      operands are all defined directly by `tosa.const_shape`.
+    - `tosa.cast` between bool and f32 is split into two casts through an i8
+      intermediate tensor.
+    - `tosa.gather` / `tosa.scatter` on bool values with i32 indices are
+      performed on i8 copies of the values, with casts to and from bool.
+
+    Downgrading is best-effort: constructs that cannot be rewritten are left
+    unchanged and the pass performs no validation itself. Run the validation
+    pass afterwards to ensure compatibility with the TOSA 1.0 specification.
+  }];
+}
+
 def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> {
   let summary = "Narrow I64 TOSA operations to I32";
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index e8a76fa3a1d21..1fd18bb5a395b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaAttachTarget.cpp
   TosaArithConstantToConst.cpp
   TosaConvertIntegerTypeToSignless.cpp
+  TosaDowngrade1p1To1p0.cpp
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeDepthwise.cpp
   TosaFolders.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDowngrade1p1To1p0.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDowngrade1p1To1p0.cpp
new file mode 100644
index 0000000000000..d48c4fbdf6dd1
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDowngrade1p1To1p0.cpp
@@ -0,0 +1,217 @@
+//===- TosaDowngrade1p1To1p0.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Rewrites constructs which are only compatible in TOSA specification 1.1 and
+// above to their TOSA 1.0 counterparts where possible. Downgrading is
+// best-effort and validation should be performed afterwards to ensure
+// compatibility with the TOSA 1.0 specification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSADOWNGRADE1P1TO1P0PASS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+/// Compile-time kernel/stride/pad values of an adaptive pooling op, in the
+/// form expected by the attribute-based TOSA 1.0 pooling ops.
+struct ConstPoolParams {
+  llvm::SmallVector<int64_t> kernel;
+  llvm::SmallVector<int64_t> stride;
+  llvm::SmallVector<int64_t> pad;
+};
+
+/// Collects the kernel, stride and pad values of an adaptive pooling op when
+/// each of those operands is defined directly by a `tosa.const_shape` op, and
+/// returns failure otherwise (e.g. when a shape is assembled with
+/// `tosa.concat_shape`). Shared by the avg and max pool patterns so that both
+/// agree on which operands count as constant.
+template <typename AdaptivePoolOp>
+FailureOr<ConstPoolParams> matchConstPoolParams(AdaptivePoolOp op) {
+  ConstPoolParams params;
+  if (!tosa::getConstShapeValues(op.getKernel().getDefiningOp(),
+                                 params.kernel) ||
+      !tosa::getConstShapeValues(op.getStride().getDefiningOp(),
+                                 params.stride) ||
+      !tosa::getConstShapeValues(op.getPad().getDefiningOp(), params.pad))
+    return failure();
+  return params;
+}
+
+/// Rewrites `tosa.avg_pool2d_adaptive` to `tosa.avg_pool2d` when kernel, stride
+/// and pad are compile-time constants. Input, zero points, accumulator type
+/// and result type are forwarded unchanged; the constant shapes become array
+/// attributes.
+class AvgPool2dAdaptiveToAvgPool2d
+    : public OpRewritePattern<tosa::AvgPool2dAdaptiveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AvgPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ConstPoolParams> params = matchConstPoolParams(op);
+    if (failed(params))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    auto replacement = tosa::AvgPool2dOp::create(
+        rewriter, op.getLoc(), op.getType(), op.getInput(), op.getInputZp(),
+        op.getOutputZp(), rewriter.getDenseI64ArrayAttr(params->kernel),
+        rewriter.getDenseI64ArrayAttr(params->stride),
+        rewriter.getDenseI64ArrayAttr(params->pad), op.getAccTypeAttr());
+    rewriter.replaceOp(op, replacement.getOutput());
+    return success();
+  }
+};
+
+/// Rewrites `tosa.max_pool2d_adaptive` to `tosa.max_pool2d` when kernel, stride
+/// and pad are compile-time constants. Input, NaN mode and result type are
+/// forwarded unchanged; the constant shapes become array attributes.
+class MaxPool2dAdaptiveToMaxPool2d
+    : public OpRewritePattern<tosa::MaxPool2dAdaptiveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dAdaptiveOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ConstPoolParams> params = matchConstPoolParams(op);
+    if (failed(params))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant kernel, stride, and pad operands");
+
+    auto replacement = tosa::MaxPool2dOp::create(
+        rewriter, op.getLoc(), op.getType(), op.getInput(),
+        rewriter.getDenseI64ArrayAttr(params->kernel),
+        rewriter.getDenseI64ArrayAttr(params->stride),
+        rewriter.getDenseI64ArrayAttr(params->pad), op.getNanModeAttr());
+    rewriter.replaceOp(op, replacement.getOutput());
+    return success();
+  }
+};
+
+/// Rewrites a `tosa.cast` between bool (i1) and f32 element types into two
+/// casts that go through an i8 tensor with the same shape as the result:
+///   f32 -> i8 -> i1   or   i1 -> i8 -> f32
+/// Every other cast is left unchanged.
+///
+/// NOTE(review): the i1 -> i8 -> f32 path should yield exactly 0.0 / 1.0. For
+/// f32 -> i1, assuming the intermediate f32 -> i8 cast rounds to the nearest
+/// integer (as TOSA 1.0 CAST does), non-zero inputs that round to 0 (e.g.
+/// 0.25) would become `false`. If the TOSA 1.1 f32 -> bool cast is defined as
+/// `in != 0`, this bridge is not value-preserving for such inputs -- confirm
+/// against the specification.
+class BoolFp32CastRewrite : public OpRewritePattern<tosa::CastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    const Value src = op.getInput();
+    const Type resultTy = op.getType();
+
+    const Type boolTy = rewriter.getI1Type();
+    const Type fp32Ty = rewriter.getF32Type();
+    const Type srcElemTy = getElementTypeOrSelf(src.getType());
+    const Type dstElemTy = getElementTypeOrSelf(resultTy);
+
+    // Only the bool <-> f32 directions need an i8 bridge.
+    const bool needsBridge = (srcElemTy == fp32Ty && dstElemTy == boolTy) ||
+                             (srcElemTy == boolTy && dstElemTy == fp32Ty);
+    if (!needsBridge)
+      return rewriter.notifyMatchFailure(op,
+                                         "expected cast between bool and f32");
+
+    // The bridge keeps the result's shape (static, dynamic or unranked) and
+    // only swaps the element type to i8.
+    const Type bridgeTy =
+        cast<TensorType>(resultTy).clone(rewriter.getI8Type());
+    const Location loc = op.getLoc();
+
+    auto toBridge = tosa::CastOp::create(rewriter, loc, bridgeTy, src);
+    auto fromBridge =
+        tosa::CastOp::create(rewriter, loc, resultTy, toBridge.getOutput());
+    rewriter.replaceOp(op, fromBridge.getOutput());
+    return success();
+  }
+};
+
+/// Rewrites a `tosa.gather` with bool (i1) values and i32 indices into a
+/// gather over an i8 copy of the values, followed by a cast back to bool:
+///   cast(values -> i8) -> gather -> cast(result -> i1)
+/// Gathers with any other value/index element types are left unchanged.
+class BoolGatherRewrite : public OpRewritePattern<tosa::GatherOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    const Value boolValues = op.getValues();
+    const Value idx = op.getIndices();
+
+    const bool hasBoolValues =
+        getElementTypeOrSelf(boolValues.getType()) == rewriter.getI1Type();
+    const bool hasI32Indices =
+        getElementTypeOrSelf(idx.getType()) == rewriter.getI32Type();
+    if (!(hasBoolValues && hasI32Indices))
+      return rewriter.notifyMatchFailure(
+          op, "expected values of bool type and indices of i32 type");
+
+    // Only the element type changes; every tensor keeps its original shape.
+    const Type byteTy = rewriter.getI8Type();
+    const auto asByteTensor = [&](Type t) -> Type {
+      return cast<TensorType>(t).clone(byteTy);
+    };
+    const Type valuesByteTy = asByteTensor(boolValues.getType());
+    const Type resultByteTy = asByteTensor(op.getType());
+    const Location loc = op.getLoc();
+
+    auto valuesAsBytes =
+        tosa::CastOp::create(rewriter, loc, valuesByteTy, boolValues);
+    auto gathered = tosa::GatherOp::create(rewriter, loc, resultByteTy,
+                                           valuesAsBytes.getOutput(), idx);
+    auto backToBool =
+        tosa::CastOp::create(rewriter, loc, op.getType(), gathered.getOutput());
+    rewriter.replaceOp(op, backToBool.getOutput());
+    return success();
+  }
+};
+
+/// Rewrites a `tosa.scatter` with bool (i1) values_in and i32 indices into a
+/// scatter over i8 copies of values_in and input, followed by a cast back to
+/// bool:
+///   cast(values_in -> i8), cast(input -> i8) -> scatter -> cast(-> i1)
+/// Scatters with any other values_in/index element types are left unchanged.
+class BoolScatterRewrite : public OpRewritePattern<tosa::ScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    const Value boolValuesIn = op.getValuesIn();
+    const Value idx = op.getIndices();
+
+    const bool hasBoolValues =
+        getElementTypeOrSelf(boolValuesIn.getType()) == rewriter.getI1Type();
+    const bool hasI32Indices =
+        getElementTypeOrSelf(idx.getType()) == rewriter.getI32Type();
+    if (!hasBoolValues || !hasI32Indices)
+      return rewriter.notifyMatchFailure(
+          op, "expected values of bool type and indices of i32 type");
+
+    // NOTE(review): only values_in's element type is checked; `input` is
+    // presumably required by the op verifier to share it -- confirm.
+    const Value boolInput = op.getInput();
+
+    // Only the element type changes; every tensor keeps its original shape.
+    const Type byteTy = rewriter.getI8Type();
+    const auto asByteTensor = [&](Type t) -> Type {
+      return cast<TensorType>(t).clone(byteTy);
+    };
+    const Type valuesInByteTy = asByteTensor(boolValuesIn.getType());
+    const Type inputByteTy = asByteTensor(boolInput.getType());
+    const Type resultByteTy = asByteTensor(op.getType());
+    const Location loc = op.getLoc();
+
+    auto valuesInAsBytes =
+        tosa::CastOp::create(rewriter, loc, valuesInByteTy, boolValuesIn);
+    auto inputAsBytes =
+        tosa::CastOp::create(rewriter, loc, inputByteTy, boolInput);
+    auto scattered = tosa::ScatterOp::create(
+        rewriter, loc, resultByteTy, valuesInAsBytes.getOutput(), idx,
+        inputAsBytes.getOutput());
+    auto backToBool = tosa::CastOp::create(rewriter, loc, op.getType(),
+                                           scattered.getValuesOut());
+    rewriter.replaceOp(op, backToBool.getOutput());
+    return success();
+  }
+};
+
+/// Function-level driver: greedily applies every 1.1 -> 1.0 downgrade pattern
+/// defined in this file, and signals pass failure if the greedy driver does
+/// not succeed.
+struct TosaDowngrade1p1To1p0Pass
+    : public tosa::impl::TosaDowngrade1p1To1p0PassBase<
+          TosaDowngrade1p1To1p0Pass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    RewritePatternSet downgradePatterns(ctx);
+    downgradePatterns
+        .add<AvgPool2dAdaptiveToAvgPool2d, MaxPool2dAdaptiveToMaxPool2d,
+             BoolFp32CastRewrite, BoolGatherRewrite, BoolScatterRewrite>(ctx);
+
+    const FrozenRewritePatternSet frozen(std::move(downgradePatterns));
+    if (succeeded(applyPatternsGreedily(getOperation(), frozen)))
+      return;
+    signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-downgrade-1-1-to-1-0.mlir b/mlir/test/Dialect/Tosa/tosa-downgrade-1-1-to-1-0.mlir
new file mode 100644
index 0000000000000..44bc7a63f9558
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-downgrade-1-1-to-1-0.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt --split-input-file --tosa-downgrade-1-1-to-1-0 %s | FileCheck %s
+
+// Adaptive pooling ops whose kernel, stride and pad operands are all defined
+// by tosa.const_shape are rewritten to the TOSA 1.0 pooling ops, with the
+// shape values moved into array attributes.
+
+// CHECK-LABEL: @test_max_pool
+// CHECK: %[[POOL:.+]] = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, nan_mode = IGNORE, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+// CHECK: return %[[POOL]]
+func.func @test_max_pool(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+         (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+
+// The zero-point operands and acc_type are forwarded unchanged to avg_pool2d.
+
+// CHECK-LABEL: @test_avg_pool
+// CHECK: %[[POOL:.+]] = tosa.avg_pool2d %arg0, %{{.*}}, %{{.*}} {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
+// CHECK: return %[[POOL]]
+func.func @test_avg_pool(%arg0: tensor<1x7x7x9xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
+  %kernel = tosa.const_shape {values = dense<[3, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+         (tensor<1x7x7x9xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+// bool <-> f32 casts are split into two casts through an i8 tensor with the
+// result's shape; this holds for static, dynamic and unranked shapes.
+
+// CHECK-LABEL: @test_bool_to_fp32
+// CHECK: %[[BOOL_TO_I8:.+]] = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+// CHECK: %[[I8_TO_F32:.+]] = tosa.cast %[[BOOL_TO_I8]] : (tensor<13x21x3xi8>) -> tensor<13x21x3xf32>
+// CHECK: return %[[I8_TO_F32]]
+func.func @test_bool_to_fp32(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_bool_to_fp32_unranked
+// CHECK: %[[BOOL_TO_I8:.+]] = tosa.cast %arg0 : (tensor<*xi1>) -> tensor<*xi8>
+// CHECK: %[[I8_TO_F32:.+]] = tosa.cast %[[BOOL_TO_I8]] : (tensor<*xi8>) -> tensor<*xf32>
+// CHECK: return %[[I8_TO_F32]]
+func.func @test_bool_to_fp32_unranked(%arg0: tensor<*xi1>) -> tensor<*xf32> {
+  %0 = tosa.cast %arg0 : (tensor<*xi1>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_fp32_to_bool_ranked_dynamic
+// CHECK: %[[FP32_TO_I8:.+]] = tosa.cast %arg0 : (tensor<13x?x3xf32>) -> tensor<13x?x3xi8>
+// CHECK: %[[I8_TO_BOOL:.+]] = tosa.cast %[[FP32_TO_I8]] : (tensor<13x?x3xi8>) -> tensor<13x?x3xi1>
+// CHECK: return %[[I8_TO_BOOL]]
+func.func @test_fp32_to_bool_ranked_dynamic(%arg0: tensor<13x?x3xf32>) -> tensor<13x?x3xi1> {
+  %0 = tosa.cast %arg0 : (tensor<13x?x3xf32>) -> tensor<13x?x3xi1>
+  return %0 : tensor<13x?x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_unranked_fp32_to_bool
+// CHECK: %[[FP32_TO_I8:.+]] = tosa.cast %arg0 : (tensor<*xf32>) -> tensor<*xi8>
+// CHECK: %[[I8_TO_BOOL:.+]] = tosa.cast %[[FP32_TO_I8]] : (tensor<*xi8>) -> tensor<*xi1>
+// CHECK: return %[[I8_TO_BOOL]]
+func.func @test_unranked_fp32_to_bool(%arg0: tensor<*xf32>) -> tensor<*xi1> {
+  %0 = tosa.cast %arg0 : (tensor<*xf32>) -> tensor<*xi1>
+  return %0 : tensor<*xi1>
+}
+
+// -----
+
+// Casts that are not between bool and f32 (here bool -> i8) are left as-is.
+
+// CHECK-LABEL: @test_preserve_bool_to_i8
+// CHECK: %[[CAST:.+]] = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+// CHECK: return %[[CAST]]
+func.func @test_preserve_bool_to_i8(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi8> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// A bool gather with i32 indices runs on an i8 copy of the values, and its
+// result is cast back to bool. Other value/index type combinations are left
+// unchanged.
+
+// CHECK-LABEL: @test_gather_bool_i32
+// CHECK: %[[VALUES_TO_I8:.+]] = tosa.cast %arg0 : (tensor<13x21x3xi1>) -> tensor<13x21x3xi8>
+// CHECK: %[[GATHER_I8:.+]] = tosa.gather %[[VALUES_TO_I8]], %arg1 : (tensor<13x21x3xi8>, tensor<13x26xi32>) -> tensor<13x26x3xi8>
+// CHECK: %[[I8_TO_BOOL:.+]] = tosa.cast %[[GATHER_I8]] : (tensor<13x26x3xi8>) -> tensor<13x26x3xi1>
+// CHECK: return %[[I8_TO_BOOL]]
+func.func @test_gather_bool_i32(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xi1> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi32>) -> tensor<13x26x3xi1>
+  return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_preserve_gather_bool_i64
+// CHECK: %[[GATHER:.+]] = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi64>) -> tensor<13x26x3xi1>
+// CHECK: return %[[GATHER]]
+func.func @test_preserve_gather_bool_i64(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xi1> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi1>, tensor<13x26xi64>) -> tensor<13x26x3xi1>
+  return %0 : tensor<13x26x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_preserve_gather_i8_i32
+// CHECK: %[[GATHER:.+]] = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi8>, tensor<13x26xi32>) -> tensor<13x26x3xi8>
+// CHECK: return %[[GATHER]]
+func.func @test_preserve_gather_i8_i32(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xi8> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xi8>, tensor<13x26xi32>) -> tensor<13x26x3xi8>
+  return %0 : tensor<13x26x3xi8>
+}
+
+// -----
+
+// A bool scatter with i32 indices runs on i8 copies of values_in and input,
+// and its result is cast back to bool. Other value/index type combinations
+// are left unchanged.
+
+// CHECK-LABEL: @test_scatter_bool_i32
+// CHECK: %[[VALUES_IN_TO_I8:.+]] = tosa.cast %arg0 : (tensor<13x52x3xi1>) -> tensor<13x52x3xi8>
+// CHECK: %[[INPUT_TO_I8:.+]] = tosa.cast %arg2 : (tensor<13x26x3xi1>) -> tensor<13x26x3xi8>
+// CHECK: %[[SCATTER_I8:.+]] = tosa.scatter %[[VALUES_IN_TO_I8]], %arg1, %[[INPUT_TO_I8]] : (tensor<13x52x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x52x3xi8>
+// CHECK: %[[I8_TO_BOOL:.+]] = tosa.cast %[[SCATTER_I8]] : (tensor<13x52x3xi8>) -> tensor<13x52x3xi1>
+// CHECK: return %[[I8_TO_BOOL]]
+func.func @test_scatter_bool_i32(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi32>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+  return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_preserve_scatter_bool_i64
+// CHECK: %[[SCATTER:.+]] = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+// CHECK: return %[[SCATTER]]
+func.func @test_preserve_scatter_bool_i64(%arg0: tensor<13x52x3xi1>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xi1>) -> tensor<13x52x3xi1> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi1>, tensor<13x26xi64>, tensor<13x26x3xi1>) -> tensor<13x52x3xi1>
+  return %0 : tensor<13x52x3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @test_preserve_scatter_i8_i32
+// CHECK: %[[SCATTER:.+]] = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x52x3xi8>
+// CHECK: return %[[SCATTER]]
+func.func @test_preserve_scatter_i8_i32(%arg0: tensor<13x52x3xi8>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi8>) -> tensor<13x52x3xi8> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xi8>, tensor<13x26xi32>, tensor<13x26x3xi8>) -> tensor<13x52x3xi8>
+  return %0 : tensor<13x52x3xi8>
+}
+
+// -----
+
+// Adaptive pooling is left unchanged when any of kernel/stride/pad is not
+// defined directly by tosa.const_shape (here the kernel comes from
+// tosa.concat_shape).
+
+// CHECK-LABEL: @test_preserve_non_const_adaptive
+// CHECK: tosa.max_pool2d_adaptive
+func.func @test_preserve_non_const_adaptive(%arg0: tensor<1x?x?x8xf32>) -> tensor<1x?x?x8xf32> {
+  %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad {nan_mode = IGNORE} :
+          (tensor<1x?x?x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+  return %0 : tensor<1x?x?x8xf32>
+}
+
+// -----
+
+// Same as above for the average-pooling variant.
+
+// CHECK-LABEL: @test_preserve_non_const_avg_adaptive
+// CHECK: tosa.avg_pool2d_adaptive
+func.func @test_preserve_non_const_avg_adaptive(%arg0: tensor<1x?x?x8xf32>, %input_zp: tensor<1xf32>, %output_zp: tensor<1xf32>) -> tensor<1x?x?x8xf32> {
+  %dim1 = tosa.dim %arg0 {axis = 1 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %dim2 = tosa.dim %arg0 {axis = 2 : i32} : (tensor<1x?x?x8xf32>) -> !tosa.shape<1>
+  %kernel = tosa.concat_shape %dim1, %dim2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.avg_pool2d_adaptive %arg0, %input_zp, %output_zp, %kernel, %stride, %pad {acc_type = f32} :
+          (tensor<1x?x?x8xf32>, tensor<1xf32>, tensor<1xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x?x?x8xf32>
+  return %0 : tensor<1x?x?x8xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/194971


More information about the Mlir-commits mailing list