[Mlir-commits] [mlir] b281211 - TOSA MLIR Dialect
Stella Laurenzo
llvmlistbot at llvm.org
Sat Nov 7 08:56:40 PST 2020
Author: Suraj Sudhir
Date: 2020-11-07T08:38:09-08:00
New Revision: b28121133d8cc4c57cc086b94e1248e7a2555465
URL: https://github.com/llvm/llvm-project/commit/b28121133d8cc4c57cc086b94e1248e7a2555465
DIFF: https://github.com/llvm/llvm-project/commit/b28121133d8cc4c57cc086b94e1248e7a2555465.diff
LOG: TOSA MLIR Dialect
This is the TOSA MLIR Dialect described in the following MLIR RFC: https://llvm.discourse.group/t/rfc-tosa-dialect-in-mlir/1971/24
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D90411
Added:
mlir/docs/Dialects/TOSA.md
mlir/include/mlir/Dialect/Tosa/CMakeLists.txt
mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/Tosa/Transforms/PassDetail.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
mlir/lib/Dialect/Tosa/CMakeLists.txt
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
mlir/test/Dialect/Tosa/broadcast.mlir
mlir/test/Dialect/Tosa/constrained_shapes.mlir
mlir/test/Dialect/Tosa/inlining.mlir
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/quant-test.mlir
mlir/test/lib/Dialect/Tosa/CMakeLists.txt
mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
Modified:
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/CMakeLists.txt
mlir/test/lib/Dialect/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/TOSA.md b/mlir/docs/Dialects/TOSA.md
new file mode 100644
index 000000000000..2d165d13261b
--- /dev/null
+++ b/mlir/docs/Dialects/TOSA.md
@@ -0,0 +1,146 @@
+# TOSA Dialect
+
+[TOC]
+
+## Rationale
+
+The MLIR TOSA dialect implements the [TOSA
+specification](https://developer.mlplatform.org/w/tosa/). This document
+describes the decision process for how TOSA expresses operators in
+high level dialects.
+
+TOSA was developed after parallel efforts to rationalize the top-down picture
+from multiple high-level frameworks, as well as a bottom-up view of different
+hardware target concerns (CPU, GPU and NPU), and reflects a set of choices
+that attempt to manage both sets of requirements.
+
+## TOSA and Tensor Level Expressiveness
+
+TOSA endeavors to provide an operator set that fulfils the following
+expressiveness goals at the *tensor level of abstraction*:
+
+### Complete
+
+This is driven by the top-down perspective: the need to express as much of
+multiple high-level frameworks in TOSA as possible. This was originally
+done through an operator frequency analysis of dozens of high-level
+networks in different frameworks, to select the most frequently occurring ones
+and establish a common set of tensor-level operators that could express them.
+
+TOSA categorizes its operator set into classes and attempts to address major
+functional operations at the tensor level, including compute, reduction,
+elementwise transformations, comparison and control flow.
+
+### Minimal
+
+This takes the bottom-up approach - keep the TOSA operator set minimal in
+order to bound the design of hardware, operator kernels, code generation
+strategies and associated considerations that affect the executability of TOSA
+content.
+
+In this regard TOSA seeks to avoid creating compound operators, instead
+leaving it to the compiler backend to fuse multiple TOSA ops if required. This
+choice also benefits the numerical precision goal, since it is easier to fuse
+the numerical functionality of successive operators than to split the numerical
+functionality of a compound operator.
+
+### Numerical Precision
+
+TOSA began as a means to address operator-level numerical precision for
+code generation and hardware development. It therefore incorporates precision
+detail into the operator set.
+
+In this regard, TOSA operators are best understood as the combination of the
+quantization information embedded within an operation and the functional
+description of how that information is used, as given in the specification of
+the operation.
+
+## TOSA Operator Rationale
+
+The general basis of selection of the operator set that constitutes TOSA is
+described in the TOSA specification document under Section 1.3 Operator
+Selection. Explanation of the thinking behind some operators is listed here:
+
+### IDENTITYN
+
+tosa.IDENTITYN is used to form a list of Operator results during
+lowering of operations such as tf.Split from a sequence of tosa.SLICE
+ops. If there are alternate ways to express this lowering without the
+tosa.IDENTITYN op, the tosa.IDENTITYN op could be removed from TOSA.
+
+```
+Value lower_split_op(Value %value, size_t axis, size_t num_split) {
+  Value %output[]
+
+ size_t slice_size = %value.shape[axis] / num_split
+
+ for (int i = 0; i < num_split; i++) {
+ vector <size_t> begin_vals, size_vals
+
+ for (int j = 0; j < %value.rank; j++) {
+ if (j == axis) {
+ begin_vals.push_back(slice_size * i)
+ size_vals.push_back(slice_size)
+ } else {
+ begin_vals.push_back(0)
+        size_vals.push_back(%value.shape[j])
+      }
+    }
+
+    %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals} (tensor<%value.type>) -> tensor<size_vals, %value.dtype>
+  }
+
+ %output_list = tosa.IDENTITYN(%output) (tensor<%output:*.type>) -> tensor<%output_list:*.type>
+ return %output_list
+}
+```
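The begin/size bookkeeping in the pseudocode above can be sketched in plain Python (an illustrative, hypothetical helper; not part of the diff above):

```python
def split_slice_params(shape, axis, num_split):
    """Return (begin_vals, size_vals) for each tosa.SLICE of a split lowering."""
    slice_size = shape[axis] // num_split
    params = []
    for i in range(num_split):
        # Offset along the split axis only; full extent elsewhere.
        begin_vals = [slice_size * i if j == axis else 0 for j in range(len(shape))]
        size_vals = [slice_size if j == axis else shape[j] for j in range(len(shape))]
        params.append((begin_vals, size_vals))
    return params
```

For a tensor of shape [2, 6, 4] split 3 ways along axis 1, the second slice begins at [0, 2, 0] with size [2, 2, 4].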
+
+### COND\_IF and WHILE\_LOOP
+
+Several neural networks express conditional control flow at the tensor level.
+A survey of multiple high level frameworks indicated that conditional if and
+a loop construct are common in all major frameworks, with some variation.
+Since TOSA endeavors to be complete in expressing tensor level functionality
+including control flow, it implements these constructs.
+
+The COND\_IF and WHILE\_LOOP operators implement such structured control
+flow forms and should be lowerable to corresponding ops in the scf dialect.
+Since the dialect seeks to remain isomorphic with an external, serialized form,
+the decision was to keep these ops in the dialect (as opposed to deferring
+completely to scf); this may be re-evaluated if it turns out not to yield
+the expected value.
+
+## Using TOSA In A Compiler
+
+The TOSA specification describes each operator in functional detail. It is
+expected that compilers that use TOSA will use its builders to construct the
+operators so that the quantization information for the operator is correctly
+generated.
+
+The functional steps described in the pseudocode of the specification enable
+the construction of code generation for that operation, or decisions on the
+design of underlying hardware. The functional pseudocode also describes
+how the quantization parameters are utilized within the operation.
+
+### Quantization Parameters in Ops vs Tensors
+
+TOSA uses the quantization parameters embedded in the input and output
+tensors to construct the quantization attributes that sit within the operator.
+Once these attributes are constructed, the quantization information within
+the tensors is no longer necessary for code generation.
+
+This enables the tensors to be subsequently interpreted simply as contiguous
+buffers containing raw data, with no 'meta information' in the form of the
+quantization_type. Precision-related manipulation of the input or output is
+instead described by the operator itself, for example when the zero point is
+applied or when the scale multiplication is done.
+
+However, TOSA does *not* eliminate the existing MLIR QuantOps quantization
+type information within the tensors; this leaves the choice of how to handle
+quantization information to later backend code generation steps.
+
+Maintaining the ability to overlap these different representations of
+quantization parameters (i.e. tensor-carried vs op-carried) is an important
+capability when considering progressive lowering between uses that expect one
+scheme vs the other.
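The padding case makes the op-carried scheme concrete. A minimal sketch (plain Python with a hypothetical helper name, not this commit's API), assuming the usual affine scheme real = scale * (q - zero_point): filling the padding with the input zero point makes the pad dequantize to exactly 0.0.

```python
def quantized_pad_1d(data, pad_before, pad_after, input_zp):
    """Pad a quantized 1-D buffer so the padding dequantizes to 0.0.

    With real = scale * (q - input_zp), a stored value of input_zp
    represents exactly 0.0; this is the correction the pad op's
    quantization attribute carries.
    """
    return [input_zp] * pad_before + list(data) + [input_zp] * pad_after
```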
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 103225948238..09c6ae569c18 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -13,4 +13,5 @@ add_subdirectory(SCF)
add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(StandardOps)
+add_subdirectory(Tosa)
add_subdirectory(Vector)
diff --git a/mlir/include/mlir/Dialect/Tosa/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/CMakeLists.txt
new file mode 100644
index 000000000000..9f57627c321f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
new file mode 100644
index 000000000000..6b49fab778d7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -0,0 +1,18 @@
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaDialect.h.inc -gen-dialect-decls)
+mlir_tablegen(TosaOps.h.inc -gen-op-decls)
+mlir_tablegen(TosaOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTosaOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls)
+mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs)
+add_public_tablegen_target(MLIRTosaStructsIncGen)
+
+
+set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td)
+mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TosaInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTosaInterfaceIncGen)
+
+add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
new file mode 100644
index 000000000000..df4aa70427ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
@@ -0,0 +1,25 @@
+//===-- TosaInterfaces.td - TOSA dialect interfaces --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dialect op interfaces for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_OP_INTERFACES
+#define TOSA_OP_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def TosaOpInterface : OpInterface<"TosaOp"> {
+ let description = [{
+ Implements interfaces implemented by ops that correspond to the Tosa
+ specification.
+ }];
+}
+
+#endif
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
new file mode 100644
index 000000000000..74c60b7f6a12
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -0,0 +1,192 @@
+//===-- TosaOpBase.td - TOSA dialect op builders -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the common definitions for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef TOSA_OP_BASE
+#define TOSA_OP_BASE
+
+//===----------------------------------------------------------------------===//
+// The TOSA Dialect.
+//===----------------------------------------------------------------------===//
+def Tosa_Dialect : Dialect {
+ let name = "tosa";
+
+ let description = [{
+ The Tensor Operator Set Architecture (TOSA) dialect.
+
+ This dialect implements the TOSA standard described at
+ https://developer.mlplatform.org/w/tosa/ .
+
+ Tensor Operator Set Architecture (TOSA) provides a set of whole-tensor
+ operations commonly employed by Deep Neural Networks. The intent is to
+ enable a variety of implementations running on a diverse range of
+ processors, with the results at the TOSA level consistent across those
+ implementations. Applications or frameworks which target TOSA can therefore
+    be deployed on a wide range of different processors, such as CPUs or GPUs,
+ with defined accuracy and compatibility constraints. Most operators from the
+ common ML frameworks should be expressible in TOSA. It is expected that
+ there will be tools to lower from the ML frameworks into TOSA.
+ }];
+
+ let cppNamespace = "mlir::tosa";
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Quantization Attributes.
+//===----------------------------------------------------------------------===//
+
+// Quantization attributes used across TOSA operators. Quantization attributes
+// feed numerical precision parameters to the functional implementation of TOSA
+// operators.
+// The functional behavior is defined in the TOSA specification maintained at
+// https://developer.mlplatform.org/w/tosa/ . TOSA leverages MLIR's built in
+// quantization support: https://mlir.llvm.org/docs/Quantization/, and supports
+// uniform quantization. Depending on datatype, asymmetric and symmetric
+// quantization are supported. The types themselves are described in
+// TosaTypesBase.td .
+
+// This quantization attribute expresses numerical behavior of operators where
+// the operator has a numerical relationship between a single input and output.
+// For example: tosa.negate.
+def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr",
+ Tosa_Dialect, [
+ StructFieldAttr<"input_zp", I32Attr>,
+ StructFieldAttr<"output_zp", I32Attr>
+ ]> {
+ let description = "Attribute for UnaryOp quantization information.";
+}
+
+// There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In
+// this case, a tosa.rescale is used to bring the inputs to the same scale.
+// TODO: Upload WIP legalization document describing this construction by
+// example.
+
+// This quantization attribute holds input and weight zero point. Both the
+// ConvOp and MatMulOp QuantizationAttrs follow a common design semantic where
+// their own quantization attribute only expresses the numerical behavior at
+// the inputs.
+// The scaling of their accumulator output is done using an explicit
+// tosa.rescale operator that scales the accumulator result to output scale.
+def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr",
+ Tosa_Dialect, [
+ StructFieldAttr<"input_zp", I32Attr>,
+ StructFieldAttr<"weight_zp", I32Attr>
+ ]> {
+ let description = "Attribute for Conv type op quantization information.";
+}
+
+def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr",
+ Tosa_Dialect, [
+ StructFieldAttr<"a_zp", I32Attr>,
+ StructFieldAttr<"b_zp", I32Attr>
+ ]> {
+ let description = "Attribute for MatMulOp quantization information.";
+}
+
+// This attribute holds input zero point correction applied to the padding
+// zeros to ensure numerical accuracy in the subsequent TOSA operations.
+// Its functional application is described in the tosa.pad() operator
+// description in the specification.
+def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr",
+ Tosa_Dialect, [
+ StructFieldAttr<"input_zp", I32Attr>
+ ]> {
+ let description = "Attribute for PadOp quantization information.";
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Quantization Builders.
+//===----------------------------------------------------------------------===//
+
+// This builder is called on all convolution operators except for TransposeConv,
+// which has specialized output shape semantics. The builder also defines the
+// bitwidth of the output given the bitwidth of the input and weight content.
+def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias,
+ "ArrayAttr":$pad, "ArrayAttr":$stride, "ArrayAttr":$dilation),
+ [{
+ ::buildConvOpWithQuantInfo($_builder, $_state, outputType,
+ input, weight, bias,
+ pad, stride, dilation);
+ }]>;
+
+// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
+def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias,
+ "ArrayAttr":$outpad, "ArrayAttr":$stride, "ArrayAttr":$dilation,
+ "ArrayAttr":$outputShape),
+ [{
+ ::buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
+ input, weight, bias,
+ outpad, stride, dilation,
+ outputShape);
+ }]>;
+
+// The tosa.fully_connected op has its own builder as it does not have
+// strides/dilation/padding.
+def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
+ [{
+ ::buildFCOpWithQuantInfo($_builder, $_state, outputType,
+ input, weight, bias);
+ }]>;
+
+// The tosa.matmul op is also intended to be generated where a fully_connected
+// op would be constructed but the weight is not a constant. In this case,
+// the fully_connected op must be expressed using matmul.
+// TODO: Add link to the legalization document explaining this.
+def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$a, "Value":$b),
+ [{
+ ::buildMatMulOpWithQuantInfo($_builder, $_state, outputType,
+ a, b);
+ }]>;
+
+// Both the tosa.avg_pool2d and unary ops use the same
+// UnaryOpQuantizationAttr but the avg_pool operator has its own builder as it
+// has additional parameters not part of the unary ops.
+def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$input, "ArrayAttr":$kernel,
+ "ArrayAttr":$stride, "ArrayAttr":$pad),
+ [{
+ ::buildAvgPool2dOpWithQuantInfo($_builder, $_state, outputType,
+ input, kernel, stride, pad);
+ }]>;
+
+// This builder is called on single-parameter unary operators that have a scale
+// relationship between their input and output, expressed by the
+// UnaryOpQuantizationAttr.
+def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$input),
+ [{
+ ::buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
+ }]>;
+
+// This builder is called on the TOSA pad operator that needs to create its own
+// OptionalAttr quantization_attr parameter to scale the padding values
+// correctly.
+def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG<
+ (ins "Type":$outputType, "Value":$input, "Value":$paddings),
+ [{
+ ::buildPadOpWithQuantInfo($_builder, $_state, outputType,
+ input, paddings);
+ }]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator.
+//===----------------------------------------------------------------------===//
+
+class Tosa_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
+}
+
+#endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
new file mode 100644
index 000000000000..7af0b6dbb704
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -0,0 +1,38 @@
+//===-- TosaOps.h - TOSA dialect operation definitions ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the TOSA Dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_IR_TOSA_OPS_H
+#define DIALECT_TOSA_IR_TOSA_OPS_H
+
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// TOSA dialect and structs includes.
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Tosa/IR/TosaDialect.h.inc"
+#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc"
+
+namespace mlir {
+namespace tosa {
+
+#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
+
+} // end namespace tosa
+} // end namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
+
+#endif // DIALECT_TOSA_IR_TOSA_OPS_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
new file mode 100644
index 000000000000..e9dc5eb6180b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -0,0 +1,1701 @@
+//===-- TosaOps.td - TOSA dialect operation definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the operation set for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_OPS
+#define TOSA_OPS
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
+include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.2
+// Operator Class: Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: argmax
+//===----------------------------------------------------------------------===//
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> {
+ let summary = "Perform argmax on the input.";
+
+ let description = [{
+ This returns the index with the largest value across the given axis of the
+ input tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D: $input,
+ I64Attr: $axis
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D: $output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: avg_pool2d
+//===----------------------------------------------------------------------===//
+def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> {
+  let summary = "Performs average pooling on the input.";
+
+ let description = [{
+ This performs an average pooling over the given input tensor. A sliding
+ window of size given by <kernel size> is passed over the input tensor, with
+ the mean value being placed in the output tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor4D:$input,
+
+ Tosa_IntArrayAttr2:$kernel,
+ Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr4:$pad,
+ OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+
+ let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: conv2d
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DOp : Tosa_Op<"conv2d", [NoSideEffect]> {
+ let summary = "2D Convolution Operator";
+
+ let description = [{
+ Performs a 2D convolution over the given tensor input, using the weight
+ tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor4D:$input,
+ Tosa_Tensor4D:$weight,
+ Tosa_Tensor1D:$bias,
+
+ Tosa_IntArrayAttr4:$pad,
+ Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr2:$dilation,
+ OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+
+ let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+ let verifier = [{ return ::verifyConvOp(*this); }];
+}
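As a sketch of the shape bookkeeping a verifier or lowering for the conv ops above must do, the standard dilated-convolution output size along one spatial axis is (illustrative Python; the authoritative shape rules are in the TOSA specification):

```python
def conv2d_out_dim(in_size, kernel, stride, pad_before, pad_after, dilation):
    """Output size along one spatial axis of a strided, dilated convolution."""
    # A dilated kernel covers dilation*(kernel-1)+1 input elements.
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_before + pad_after - effective_kernel) // stride + 1
```

For example, a 3x3 kernel with stride 1 and padding 1 on each side preserves a 32-wide axis.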
+
+//===----------------------------------------------------------------------===//
+// Operator: conv3d
+//===----------------------------------------------------------------------===//
+def Tosa_Conv3DOp : Tosa_Op<"conv3d", [NoSideEffect]> {
+ let summary = "3D Convolution operator";
+
+ let description = [{
+ Performs a 3D convolution over the given input tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor5D:$input,
+ Tosa_Tensor5D:$weight,
+ Tosa_Tensor1D:$bias,
+
+ Tosa_IntArrayAttr6:$pad,
+ Tosa_IntArrayAttr3:$stride,
+ Tosa_IntArrayAttr3:$dilation,
+ OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor5D:$output
+ );
+
+ let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+ let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: depthwise_conv2d
+//===----------------------------------------------------------------------===//
+def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> {
+ let summary = "Depthwise 2D Convolution operator";
+
+ let description = [{
+ Performs 2D convolutions separately over each channel of the given tensor
+ input, using the weight tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor4D:$input,
+ Tosa_Tensor4D:$weight,
+ Tosa_Tensor1D:$bias,
+
+ Tosa_IntArrayAttr4:$pad,
+ Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr2:$dilation,
+ OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+
+ let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+ let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: fully_connected
+//===----------------------------------------------------------------------===//
+def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
+ let summary = "Fully Connected operator";
+
+ let description = [{
+ Performs a fully connected network.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor2D:$input,
+ Tosa_Tensor2D:$weight,
+ Tosa_Tensor1D:$bias,
+ OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor2D:$output
+ );
+
+ let builders = [Tosa_FCOpQuantInfoBuilder];
+
+ let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: matmul
+//===----------------------------------------------------------------------===//
+def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> {
+  let summary = "Matrix multiplication operator";
+
+ let description = [{
+ Performs a two dimensional matrix multiplication. This allows both inputs to
+ be activations, rather than reserving weights as an attribute in the
+ FULLY_CONNECTED operator.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor2D:$a,
+ Tosa_Tensor2D:$b,
+ OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor2D:$c
+ );
+
+ let builders = [Tosa_MatMulOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: max_pool2d
+//===----------------------------------------------------------------------===//
+def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> {
+ let summary = "Performs max pooling on the input.";
+
+ let description = [{
+    This performs a max pooling over the given input tensor. A sliding window of
+    size given by <kernel size> is passed over the input tensor, with the
+    maximum value being placed in the output tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor4D:$input,
+
+ Tosa_IntArrayAttr2:$kernel,
+ Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr4:$pad
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: transpose_conv2d
+//===----------------------------------------------------------------------===//
+def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> {
+ let summary = "Transpose 2D Convolution operator.";
+
+ let description = [{
+ Performs a 2D transposed convolution over the given tensor input, using the
+ weights tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor4D:$input,
+ Tosa_Tensor4D:$filter,
+ Tosa_Tensor1D:$bias,
+
+ Tosa_IntArrayAttr2:$out_pad,
+ Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr2:$dilation,
+ Tosa_IntArrayAttrUpto4:$out_shape,
+ OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+
+ let builders = [Tosa_TransConvOpQuantInfoBuilder];
+}
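TransposeConv2D carries an explicit out_shape attribute because the inverse shape relation is ambiguous. Using the conventional transposed-convolution formula from common ML frameworks (an illustrative sketch, not taken from this diff):

```python
def transpose_conv2d_out_dim(in_size, kernel, stride, pad, dilation, out_pad):
    """Conventional transposed-convolution output size along one spatial axis."""
    return (in_size - 1) * stride - 2 * pad + dilation * (kernel - 1) + out_pad + 1
```

A stride-2 forward convolution maps both 31- and 32-wide inputs to 15, so an out_pad term (or TOSA's explicit out_shape) is needed to pick one inverse.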
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.3
+// Operator Class: Activation Functions.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: clamp
+//===----------------------------------------------------------------------===//
+def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Computes clamp(features, min, max).";
+
+ let description = [{
+ Clamp to an arbitrary minimum and maximum value. Note that the maximum and
+    minimum values are specified as signed quantized values; no scaling happens
+ before or after this operation.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$min_int,
+ I64Attr:$max_int,
+ F32Attr:$min_fp,
+ F32Attr:$max_fp
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reluN
+//===----------------------------------------------------------------------===//
+def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Computes rectified linear: `min(max(features, 0), N)`.";
+
+ let description = [{
+ ReLU with a scalar maximum value.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$max_int,
+ F32Attr:$max_fp
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: sigmoid
+//===----------------------------------------------------------------------===//
+def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect,
+ SameOperandsAndResultType]> {
+ let summary = "Computes elementwise sigmoid of input.";
+
+ let description = [{
+ Sigmoid function: output = 1 / (1 + exp(-input))
+ For quantized integer data types, the TABLE operator should be used instead
+ with the following definition. The sigmoid table has 513 entries each of
+ 16-bit precision and covering the input range -16.0 to +16.0
+ in steps of 1/16.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
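The table description above can be checked with a short sketch (hypothetical generation; the authoritative table definition is in the TOSA specification). Sampling -16.0 to +16.0 in steps of 1/16 gives 32 * 16 + 1 = 513 entries:

```python
import math

def sigmoid_table():
    """Sample sigmoid over [-16.0, +16.0] in steps of 1/16: 513 entries.

    The 16-bit scaling (y * 32768) is an assumption for illustration;
    the exact fixed-point format is defined by the specification.
    """
    return [round(32768 / (1.0 + math.exp(-k / 16.0))) for k in range(-256, 257)]
```

The analogous tanh table samples -8.0 to +8.0 in steps of 1/32, which again yields 513 entries.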
+
+//===----------------------------------------------------------------------===//
+// Operator: tanh
+//===----------------------------------------------------------------------===//
+def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Computes elementwise hyperbolic tangent of input.";
+
+ let description = [{
+ Parameterized hyperbolic tangent.
+ For quantized integer data types, the TABLE operator should be used instead
+ with the following definition. The tanh_table has 513 entries each of
+ 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.4
+// Operator Class: Elementwise unary/binary/ternary operators.
+// Operator Subclass: Elementwise binary ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: add
+//===----------------------------------------------------------------------===//
+def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect,
+ Commutative]> {
+ let summary = "Elementwise addition operator";
+
+ let description = [{
+ Elementwise addition of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
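
The broadcast rule these elementwise binary ops share (matching ranks, with size-1 axes stretched to the other operand's size) can be sketched in plain Python. This is illustrative only, not part of the commit:

```python
def broadcast_shape(shape1, shape2):
    # Ranks must match; an axis of size 1 broadcasts to the other
    # operand's size, and any other mismatch is an error.
    if len(shape1) != len(shape2):
        raise ValueError("rank of input tensors must match")
    out = []
    for a, b in zip(shape1, shape2):
        if a == b or b == 1:
            out.append(a)
        elif a == 1:
            out.append(b)
        else:
            raise ValueError(f"incompatible axis sizes {a} and {b}")
    return out
```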
+
+//===----------------------------------------------------------------------===//
+// Operator: arithmetic_right_shift
+//===----------------------------------------------------------------------===//
+def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift",
+ [ResultsBroadcastableShape,
+ NoSideEffect]> {
+ let summary = "Elementwise Arithmetic Right Shift";
+
+ let description = [{
+ Elementwise arithmetic right shift of input1 by the amount specified in
+ input2. Axes of size 1 will be broadcast as necessary. Rank of input
+ tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2,
+ BoolAttr:$round
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
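
One plausible reading of the `round` attribute, sketched below, is round-to-nearest on the shifted-out bits. This interpretation is not spelled out in the description above and should be checked against the TOSA specification:

```python
def arithmetic_right_shift(value, shift, round_mode):
    # Python's >> on ints is already an arithmetic (sign-preserving) shift.
    result = value >> shift
    # When rounding, add one if the most significant discarded bit is
    # set, which rounds the quotient to the nearest integer.
    if round_mode and shift > 0 and (value >> (shift - 1)) & 1:
        result += 1
    return result
```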
+
+//===----------------------------------------------------------------------===//
+// Operator: bitwise_and
+//===----------------------------------------------------------------------===//
+def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape,
+ NoSideEffect, Commutative]> {
+ let summary = "Bitwise AND operator";
+
+ let description = [{
+ Elementwise bitwise AND of input1 and input2. Axes of size 1
+ will be broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: bitwise_or
+//===----------------------------------------------------------------------===//
+def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape,
+ NoSideEffect, Commutative]> {
+ let summary = "Bitwise OR operator";
+
+ let description = [{
+ Elementwise bitwise OR of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: bitwise_xor
+//===----------------------------------------------------------------------===//
+def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape,
+ NoSideEffect, Commutative]> {
+ let summary = "Bitwise XOR operator";
+
+ let description = [{
+ Elementwise bitwise XOR of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_and
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape,
+ Commutative, NoSideEffect]> {
+ let summary = "Returns the truth value of x AND y element-wise.";
+
+ let description = [{
+ Elementwise logical AND of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ I1Tensor:$input1,
+ I1Tensor:$input2
+ );
+
+ let results = (outs
+ I1Tensor:$z
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_left_shift
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift",
+ [ResultsBroadcastableShape,
+ NoSideEffect]> {
+ let summary = "Elementwise Logical Left Shift";
+
+ let description = [{
+ Elementwise logical left shift of input1 by the amount specified in
+ input2. Axes of size 1 will be broadcast as necessary. Rank of input
+ tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_right_shift
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift",
+ [ResultsBroadcastableShape,
+ NoSideEffect]> {
+ let summary = "Elementwise Logical Right Shift";
+
+ let description = [{
+ Elementwise logical right shift of input1 by the amount specified in input2.
+ Axes of size 1 will be broadcast as necessary.
+ Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_or
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape,
+ Commutative, NoSideEffect]> {
+ let summary = "Returns the truth value of x OR y element-wise.";
+
+ let description = [{
+ Elementwise logical OR of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ I1Tensor:$input1,
+ I1Tensor:$input2
+ );
+
+ let results = (outs
+ I1Tensor:$z
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_xor
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape,
+ Commutative, NoSideEffect]> {
+ let summary = "Returns the truth value of x XOR y element-wise.";
+
+ let description = [{
+ Elementwise logical XOR of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ I1Tensor:$input1,
+ I1Tensor:$input2
+ );
+
+ let results = (outs
+ I1Tensor:$z
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: maximum
+//===----------------------------------------------------------------------===//
+def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape,
+ NoSideEffect, Commutative]> {
+ let summary = "Elementwise Maximum";
+
+ let description = [{
+ Elementwise max of input1 and input2. Axes of size 1 will be broadcast
+ as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: minimum
+//===----------------------------------------------------------------------===//
+def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape,
+ NoSideEffect, Commutative]> {
+ let summary = "Elementwise Minimum";
+
+ let description = [{
+ Elementwise minimum of input1 and input2. Axes of size 1
+ will be broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: mul
+//===----------------------------------------------------------------------===//
+def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect,
+ Commutative]> {
+ let summary = "Multiplication operator";
+
+ let description = [{
+ Elementwise multiplication (Hadamard product) of input1 and input2.
+ Axes of size 1 will be broadcast as necessary.
+ Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2,
+ I32Attr:$shift
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: pow
+//===----------------------------------------------------------------------===//
+def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> {
+ let summary = "Computes the power of one value to another.";
+
+ let description = [{
+ Elementwise input1 raised to the power of input2.
+ Axes of size 1 will be broadcast as necessary.
+ Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$z
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: sub
+//===----------------------------------------------------------------------===//
+def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> {
+ let summary = "Elementwise subtraction operator";
+
+ let description = [{
+ Elementwise subtraction of input1 and input2. Axes of size 1 will be
+ broadcast as necessary. Rank of input tensors must match.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: table
+//===----------------------------------------------------------------------===//
+def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> {
+ let summary = "Table lookup op";
+
+ let description = [{
+ Interpolated table lookup operation. Input values are scaled to create a
+ fixed-point 9.7 value. The high 9 bits are used to index into the table.
+ The fractional bits are used to interpolate based on the looked up value and
+ the index+1 value in the table. The TABLE operator then returns a 16.7
+ interpolated value. Note that there must be 513 values to handle the full
+ range of inputs.
+
+ The TABLE operator is expected to be used as follows:
+ * A RESCALE node is expected before the TABLE operator to scale the input
+ to a full int16_t range for the table lookup
+ * If an int16_t result is required then follow the TABLE operator with a
+ RESCALE with a right shift of 7
+ * If an int8_t result is required then follow the TABLE operator with a
+ RESCALE with a right shift of 15
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input,
+ Tosa_Tensor1D:$table
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+
+}
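
The 9.7 indexing and interpolation scheme above can be sketched as follows. The +256 bias mapping the signed int16 input onto table indices 0..512 is an assumption consistent with the 513-entry requirement, not something this description states explicitly:

```python
def table_lookup(value, table):
    # value: int16 input already rescaled to full range (see the
    # RESCALE note above). The high 9 bits (biased onto 0..512) pick
    # the entry; the low 7 bits interpolate toward the next entry.
    assert len(table) == 513
    index = (value >> 7) + 256
    frac = value & 0x7F
    base = table[index]
    delta = table[index + 1] - base
    # Linear interpolation; the result is a 16.7 fixed-point value.
    return (base << 7) + delta * frac
```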
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.5
+// Operator Class: Elementwise unary/binary/ternary operators.
+// Operator Subclass: Elementwise unary ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: abs
+//===----------------------------------------------------------------------===//
+def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise abs op";
+
+ let description = [{
+ Elementwise absolute value operation
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: bitwise_not
+//===----------------------------------------------------------------------===//
+def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape,
+ NoSideEffect]> {
+ let summary = "Bitwise NOT operator";
+
+ let description = [{
+ Elementwise bitwise NOT of input tensor.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: ceil
+//===----------------------------------------------------------------------===//
+def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise ceil op";
+
+ let description = [{
+ Elementwise ceiling operation
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: clz
+//===----------------------------------------------------------------------===//
+def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise count leading zero op";
+
+ let description = [{
+ Elementwise count leading zeros operation
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: exp
+//===----------------------------------------------------------------------===//
+def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise exp op";
+
+ let description = [{
+ Elementwise exponential (e^x) operation.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: floor
+//===----------------------------------------------------------------------===//
+def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise floor op";
+
+ let description = [{
+ Elementwise floor operation
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: log
+//===----------------------------------------------------------------------===//
+def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise log op";
+
+ let description = [{
+ Elementwise natural logarithm operation
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_not
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect,
+ SameOperandsAndResultType]> {
+ let summary = "Returns the truth value of NOT x element-wise.";
+
+ let description = [{
+ Elementwise logical NOT of input.
+ }];
+
+ let arguments = (ins
+ I1Tensor:$input1
+ );
+
+ let results = (outs
+ I1Tensor:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: negate
+//===----------------------------------------------------------------------===//
+def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect,
+ SameOperandsAndResultType]> {
+ let summary = "Elementwise negate op";
+
+ let description = [{
+ Elementwise negation operation
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+
+ let builders = [Tosa_UnaryOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reciprocal
+//===----------------------------------------------------------------------===//
+def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect,
+ SameOperandsAndResultType]> {
+ let summary = "Elementwise reciprocal op";
+
+ let description = [{
+ Elementwise reciprocal operation. For integer operation, a TABLE should be
+ used with the appropriate ranges.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: rsqrt
+//===----------------------------------------------------------------------===//
+def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "Elementwise 1/sqrt op";
+
+ let description = [{
+ Elementwise reciprocal square root operation. For integer operation, a TABLE
+ should be used with the appropriate ranges.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.6
+// Operator Class: Elementwise unary/binary/ternary operators.
+// Operator Subclass: Elementwise ternary ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: select
+//===----------------------------------------------------------------------===//
+def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> {
+ let summary = "Elementwise select operator";
+
+ let description = [{
+ Elementwise select of the output based on a condition.
+ }];
+
+ let arguments = (ins
+ I1Tensor:$input1,
+ Tosa_TensorUpto4D:$input2,
+ Tosa_TensorUpto4D:$input3
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.7
+// Operator Class: Logical Operations.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: equal
+//===----------------------------------------------------------------------===//
+def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative,
+ NoSideEffect]> {
+ let summary = "Returns the truth value of (x == y) element-wise.";
+
+ let description = [{
+ Elementwise equality comparison operation.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ I1Tensor:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: greater
+//===----------------------------------------------------------------------===//
+def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape,
+ NoSideEffect]> {
+ let summary = "Returns the truth value of (x > y) element-wise.";
+
+ let description = [{
+ Elementwise greater-than comparison operation.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ I1Tensor:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: greater_equal
+//===----------------------------------------------------------------------===//
+def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape,
+ NoSideEffect]> {
+ let summary = "Returns the truth value of (x >= y) element-wise.";
+
+ let description = [{
+ Elementwise greater-than-or-equal comparison operation.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input1,
+ Tosa_TensorUpto4D:$input2
+ );
+
+ let results = (outs
+ I1Tensor:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.8
+// Operator Class: Reduction Ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: reduce_all
+//===----------------------------------------------------------------------===//
+def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> {
+ let summary = "Reduce All operator";
+
+ let description = [{
+ Reduce a tensor along the given axis with a logical AND operation
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reduce_any
+//===----------------------------------------------------------------------===//
+def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> {
+ let summary = "Reduce Any operator";
+
+ let description = [{
+ Reduce a tensor along the given axis with a logical OR operation
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reduce_max
+//===----------------------------------------------------------------------===//
+def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> {
+ let summary = "Reduce Max operator";
+
+ let description = [{
+ Reduce a tensor along the given axis with a maximum operation
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reduce_min
+//===----------------------------------------------------------------------===//
+def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> {
+ let summary = "Reduce Min operator";
+
+ let description = [{
+ Reduce a tensor along the given axis with a minimum operation
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reduce_prod
+//===----------------------------------------------------------------------===//
+def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> {
+ let summary = "Reduce Prod operator";
+
+ let description = [{
+ Reduce a tensor along the given axis by computing the product of the axis.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reduce_sum
+//===----------------------------------------------------------------------===//
+def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> {
+ let summary = "Reduce Sum operator";
+
+ let description = [{
+ Reduce a tensor along the given axis by computing the sum of the axis.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.9
+// Operator Class: Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: concat
+//===----------------------------------------------------------------------===//
+def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
+ let summary = "Concatenates tensors along one dimension.";
+
+ let description = [{
+ Concatenate two tensors along a given axis. No data conversion happens
+ during a concat operation.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input1,
+ Tosa_Tensor1Dto4D:$input2,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: pad
+//===----------------------------------------------------------------------===//
+def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> {
+ let summary = "Pads a tensor with zeros.";
+
+ let description = [{
+ Zero-pads a tensor along borders of each dimension.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input1,
+ Tosa_Int32Or64Tensor:$padding,
+ OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+
+ let builders = [Tosa_PadOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reshape
+//===----------------------------------------------------------------------===//
+def Tosa_ReshapeOp: Tosa_Op<"reshape", [
+ NoSideEffect]> {
+ let summary = "Reshape operator";
+
+ let description = [{
+ Returns a tensor with the same type/values as the input, with a new shape
+ specified by the shape argument. Reshape may operate on tensors of any rank.
+ No data conversion happens during a reshape operation.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto6D:$input1,
+ I64ArrayAttr:$new_shape
+ );
+
+ let results = (outs
+ Tosa_TensorUpto6D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reverse
+//===----------------------------------------------------------------------===//
+def Tosa_ReverseOp: Tosa_Op<"reverse", [NoSideEffect]> {
+ let summary = "Reverse operator";
+
+ let description = [{
+ Returns a tensor with the same type/values as the input, with the data
+ reversed along the given axis. No data conversion happens during a reverse
+ operation.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input,
+ I64Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: slice
+//===----------------------------------------------------------------------===//
+def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
+ let summary = "Slice operator";
+
+ let description = [{
+ Extracts a slice of input1, beginning at the start coordinates and
+ extending for size elements in each direction. No data conversion
+ happens during a slice operation.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto6D:$input,
+ I64ArrayAttr:$start,
+ I64ArrayAttr:$size
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto6D:$output
+ );
+}
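
For a concrete 2-D illustration of the start/size semantics (plain Python lists, not the op itself):

```python
def slice_2d(values, start, size):
    # Take size[d] elements beginning at start[d] along each axis;
    # the 2-D case is shown for brevity.
    return [row[start[1]:start[1] + size[1]]
            for row in values[start[0]:start[0] + size[0]]]
```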
+
+//===----------------------------------------------------------------------===//
+// Operator: tile
+//===----------------------------------------------------------------------===//
+def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
+ let summary = "Tile operator";
+
+ let description = [{
+ Replicates input1 multiples times along each dimension.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto4D:$input1,
+ I64ArrayAttr:$multiples);
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: transpose
+//===----------------------------------------------------------------------===//
+def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
+ let summary = "Transpose operator";
+
+ let description = [{
+ Permutes the dimensions of input1 based on perms.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor1Dto6D:$input1,
+ Tosa_Int32Or64Tensor:$perms
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto6D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.10
+// Operator Class: Scatter/gather Operations.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: gather
+//===----------------------------------------------------------------------===//
+def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
+ let summary = "Gather operation";
+
+ let description = [{
+ Generate a tensor for which each element in the output is a subtensor of the
+ values tensor along the given axis, based on the value of indices.
+ }];
+
+ let arguments = (ins
+ Tosa_Int32Or64Tensor:$indices,
+ Tosa_Tensor1Dto4D:$values,
+ I32Attr:$axis
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.11
+// Operator Class: Image Frontend Functions.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: resize
+//===----------------------------------------------------------------------===//
+def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
+
+ let summary = "Resize operation; supports various resize/upsample modes";
+
+ let description = [{
+ Resizes a tensor. Resize is only allowed in the H and W dimensions. In
+ expected use, stride_y is approximately (IH<<shift)/OH and stride_x is
+ approximately (IW<<shift)/OW. OH and OW are also supplied as inputs since
+ there may be off-by-one errors when calculating OH and OW from the strides.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor4D:$input,
+ Tosa_IntArrayAttr2:$output_size,
+ Tosa_IntArrayAttr2:$stride,
+ Tosa_IntArrayAttr2:$offset,
+ I32Attr:$shift,
+ Tosa_ResizeTypeAttr:$mode
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+}
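
The stride relationship stated above can be made concrete with a small helper. This is a sketch; truncating integer division is an assumption about how the approximation is computed:

```python
def resize_strides(in_h, in_w, out_h, out_w, shift):
    # stride_y ~ (IH << shift) / OH and stride_x ~ (IW << shift) / OW,
    # as described above, using truncating integer division.
    return (in_h << shift) // out_h, (in_w << shift) // out_w
```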
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.12
+// Operator Class: Type Conversion.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: cast
+//===----------------------------------------------------------------------===//
+def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
+
+ let summary = "Cast operation";
+
+ let description = [{
+ Performs a set of permissible cast operations:
+
+ Mode                     Input     Output
+ ------------------------------------------
+ signed 8 to bool         int8      Boolean
+ signed 16 to bool        int16     Boolean
+ signed 32 to bool        int32     Boolean
+ bool to 8                Boolean   int8
+ bool to 16               Boolean   int16
+ bool to 32               Boolean   int32
+ signed 8 to signed 16    int8      int16
+ signed 8 to signed 32    int8      int32
+ signed 16 to signed 8    int16     int8
+ signed 16 to signed 32   int16     int32
+ signed 32 to signed 8    int32     int8
+ signed 32 to signed 16   int32     int16
+ float to signed 8        float     int8
+ float to signed 16       float     int16
+ signed 8 to float        int8      float
+ signed 16 to float       int16     float
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: rescale
+//===----------------------------------------------------------------------===//
+def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
+ let summary = "Tosa rescale operator";
+
+ let description = [{
+ Rescale quantized values into a new domain. Supported rescalings are:
+ Mode                     Input   Output
+ ---------------------------------------
+ signed 8 to 8            aint8   aint8
+ signed 8 to 16           aint8   int16
+ signed 8 to 32           aint8   int32
+ signed 16 to 8           int16   aint8
+ signed 16 to 16          int16   int16
+ signed 16 to 32          int16   int32
+ signed 32 to 8           int32   aint8
+ signed 32 to 16          int32   int16
+ signed 32 to 32          int32   int32
+ signed 48 to 8           int48   aint8
+ signed 48 to 16          int48   int16
+ signed 48 to 32          int48   int32
+ unsigned 8 to signed 8   uint8   aint8
+ signed 8 to unsigned 8   aint8   uint8
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto4D:$input,
+ I32Attr:$input_zp,
+ I32Attr:$output_zp,
+ I32ArrayAttr:$multiplier,
+ I32ArrayAttr:$shift,
+ BoolAttr:$scale32,
+ BoolAttr:$double_round,
+ BoolAttr:$per_channel
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.13
+// Operator Class: Data Node Ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: const
+//===----------------------------------------------------------------------===//
+def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect,
+ FirstAttrDerivedResultType]> {
+ let summary = "Constant op.";
+
+ let description = [{
+ A node containing constant data for use as the input to an operation. May
+ hold data in any of the supported data formats.
+ }];
+
+ let arguments = (ins
+ AnyAttr:$value
+ );
+
+ let results = (outs
+ Tosa_TensorUpto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: identity
+//===----------------------------------------------------------------------===//
+def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> {
+ let summary = "Identity operator";
+ let description = [{
+ Returns a tensor with the same shape, size, type
+ and content as the input.
+ }];
+
+ let arguments = (ins
+ Tosa_TensorUpto6D:$input1
+ );
+
+ let results = (outs
+ Tosa_TensorUpto6D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: identityn
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Further described in docs/Rationale/RationaleTOSADialect.md.
+//===----------------------------------------------------------------------===//
+def Tosa_IdentityNOp: Tosa_Op<"identityn", [NoSideEffect]> {
+ let summary = "IdentityN operator";
+ let description = [{
+ Returns a list of tensors with the same shape, type, and contents as the
+ input list of tensors.
+ }];
+
+ let arguments = (ins
+ Variadic<Tosa_TensorUpto6D>:$input1
+ );
+
+ let results = (outs
+ Variadic<Tosa_TensorUpto6D>:$output
+ );
+}
+
+
+//===----------------------------------------------------------------------===//
+// Operator: placeholder
+//===----------------------------------------------------------------------===//
+def Tosa_PlaceholderOp : Tosa_Op<"placeholder", [NoSideEffect]> {
+ let summary = "Placeholder op";
+
+ let description = [{
+ A node where data will be inserted into the network at runtime. Generally
+ used for inputs to the network.
+ }];
+
+ let arguments = (ins
+ );
+
+ let results = (outs
+ Tosa_Tensor1Dto4D:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.14
+// Operator Class: Custom Operators.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: custom
+//===----------------------------------------------------------------------===//
+def Tosa_CustomOp : Tosa_Op<"custom"> {
+
+ let summary = "Custom operator wrapper for Tosa";
+
+ let description = [{
+ Hardware implementing TOSA may choose to add additional custom operators
+ that are not expressed in the existing TOSA operations. These operators are
+ not expected to be portable across TOSA implementations. The input and
+ output signatures must be expressed in the corresponding TOSA node.
+ }];
+
+ let arguments = (ins
+ StrAttr:$identifier,
+ Variadic<Tosa_Tensor>:$inputs
+ );
+
+ let results = (outs
+ Variadic<Tosa_Tensor>:$outputs
+ );
+}
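As an illustrative sketch (the op name and its StrAttr identifier come from the definition above; the operand types and the identifier string are assumptions), a custom operator could be written in the generic MLIR form:

```mlir
// Hypothetical vendor extension; not expected to be portable across
// TOSA implementations, per the description above.
%0 = "tosa.custom"(%arg0) {identifier = "vendor.fft2d"}
       : (tensor<8x8xf32>) -> tensor<8x8xf32>
```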
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.15
+// Operator Class: Control Flow Operators.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: cond_if
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Further described in docs/Rationale/RationaleTOSADialect.md.
+//===----------------------------------------------------------------------===//
+def Tosa_IfOp : Tosa_Op<"cond_if", [
+ SingleBlockImplicitTerminator<"YieldOp">,
+ RecursiveSideEffects]> {
+ let summary = "Conditional if operator";
+
+ let description = [{
+ Evaluates a Boolean condition and then takes one of two distinct execution
+ paths. This implements the semantic If-then-else structure.
+ }];
+
+ let arguments = (ins
+ I1Tensor:$cond,
+ Variadic<Tosa_Tensor>:$inputs
+ );
+
+ let results = (outs
+ Variadic<Tosa_Tensor>:$output
+ );
+
+ let regions = (region
+ SizedRegion<1>:$then_branch,
+ SizedRegion<1>:$else_branch
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: while_loop
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Further described in docs/Rationale/RationaleTOSADialect.md.
+//===----------------------------------------------------------------------===//
+def Tosa_WhileOp : Tosa_Op<"while_loop", [
+ DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+ SingleBlockImplicitTerminator<"YieldOp">,
+ RecursiveSideEffects]> {
+ let summary = "output = input; While (Cond(output)) {output = Body(output)}";
+
+ let description = [{
+ Generates and evaluates a Boolean condition and either executes a loop body or
+ exits to another control point. This action is performed repeatedly after
+ updating and re-evaluating the Boolean condition every iteration. This
+ implements the semantic foreach or while iterative loop structure.
+ }];
+
+ let arguments = (ins
+ Variadic<Tosa_Tensor>:$inputs
+ );
+
+ let results = (outs
+ Variadic<Tosa_Tensor>:$output
+ );
+
+ let regions = (region
+ SizedRegion<1>:$cond,
+ SizedRegion<1>:$body
+ );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: yield
+//===----------------------------------------------------------------------===//
+def Tosa_YieldOp : Tosa_Op<"yield", [
+ Terminator,
+ NoSideEffect]> {
+ let summary = "yield operator";
+
+ let description = [{
+ Terminates the conditional and body regions of
+ structured control flow. The operation takes variadic operands
+ but produces no results of its own.
+ }];
+
+ let arguments = (ins
+ Variadic<Tosa_Tensor>:$inputs
+ );
+}
+
+#endif // TOSA_OPS
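To make the type-conversion and control-flow operators above concrete, here is a hedged usage sketch in generic MLIR form (the shapes are assumptions, and these ops define no custom assembly format in this commit, hence the quoted generic syntax):

```mlir
// tosa.cast: signed 8 to signed 16, one of the permissible modes listed above.
%c = "tosa.cast"(%arg0) : (tensor<1x8x8x3xi8>) -> tensor<1x8x8x3xi16>

// tosa.cond_if: two single-block regions, each implicitly terminated by
// tosa.yield, selecting one of two execution paths.
%r = "tosa.cond_if"(%cond, %a, %b) ({
  ^bb0(%x: tensor<f32>, %y: tensor<f32>):
    "tosa.yield"(%x) : (tensor<f32>) -> ()
}, {
  ^bb0(%x: tensor<f32>, %y: tensor<f32>):
    "tosa.yield"(%y) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
```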
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
new file mode 100644
index 000000000000..0ceef0ce70a4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -0,0 +1,159 @@
+//===-- TosaTypesBase.td - TOSA type definitions -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the type definitions for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_TYPES_BASE
+#define TOSA_TYPES_BASE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Tosa Type Definitions.
+//===----------------------------------------------------------------------===//
+
+// The base class of a quantized type.
+// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
+// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
+// the 8-bit case.
+class Tosa_QuantizedType<string n, list<int> params, bit signed>
+ : Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
+ CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
+ ".getStorageTypeIntegralWidth() == " # !head(params)>]>,
+ "Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
+ string name = n;
+ string asTraitArgsStr =
+ StrJoinInt<params>.result # !if(signed, ", true", ", false");
+}
+
+//===----------------------------------------------------------------------===//
+// Non-Quantized Signed Integer Types.
+// Used to express accumulator results or compare results.
+//===----------------------------------------------------------------------===//
+
+def Tosa_Int32 : I<32>;
+def Tosa_Int48 : I<48>;
+def Tosa_Int64 : I<64>;
+
+def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32,
+ Tosa_Int48,
+ Tosa_Int64]>;
+
+def Tosa_Bool : I<1>;
+
+// No unsigned unquantized int types.
+def Tosa_Int : AnyTypeOf<[Tosa_Bool,
+ Tosa_SignedInt]>;
+
+def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
+ Tosa_Int64]>;
+
+//===----------------------------------------------------------------------===//
+// Quantized Integer Types.
+// Datatype for network feature map or weight content.
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Name Symmetry Grouping Sign
+//===----------------------------------------------------------------------===//
+// aint8 : asymmetric per tensor, signed
+// uint8 : asymmetric per tensor, unsigned
+// int4 : symmetric per channel, signed
+// int8 : symmetric per tensor/per channel, signed
+// int16 : symmetric per tensor, signed
+//===----------------------------------------------------------------------===//
+def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>,
+ Tosa_QuantizedType<"uint8", [8], 0>,
+ Tosa_QuantizedType<"int4", [4, 0], 1>,
+ Tosa_QuantizedType<"int8", [8, 0], 1>,
+ Tosa_QuantizedType<"int16", [16, 0], 1>]>;
+
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def Tosa_Float : AnyTypeOf<[
+ F32,
+ F16,
+ BF16]>;
+
+//===----------------------------------------------------------------------===//
+// Multi-category types.
+//===----------------------------------------------------------------------===//
+def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
+ "number">;
+
+//===----------------------------------------------------------------------===//
+// Tensor types
+//===----------------------------------------------------------------------===//
+
+def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+
+def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
+
+// Any tensor element type allowed in Tosa ops.
+def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
+ Tosa_Float.predicate]>, "tosa.dtype">;
+
+class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
+ AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
+
+//===----------------------------------------------------------------------===//
+// Tensor types with constrained ranks.
+//===----------------------------------------------------------------------===//
+
+// Must be listed rank.
+def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>;
+def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>;
+
+// Ranked tensors up to given rank.
+def Tosa_Tensor1Dto2D : TensorRankOf<[Tosa_AnyNumber], [1,2]>;
+def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>;
+def Tosa_Tensor1Dto5D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5]>;
+def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>;
+
+def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>;
+def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4,5,6]>;
+
+//===----------------------------------------------------------------------===//
+// Attribute predicates and classes.
+//===----------------------------------------------------------------------===//
+class ArrayMaxCt<int n> : AttrConstraint<
+ CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>,
+ "with at most " # n # " elements">;
+
+def Tosa_IntArrayAttr2 : Confined<I64ArrayAttr, [ArrayCount<2>]>;
+def Tosa_IntArrayAttr3 : Confined<I64ArrayAttr, [ArrayCount<3>]>;
+def Tosa_IntArrayAttr4 : Confined<I64ArrayAttr, [ArrayCount<4>]>;
+def Tosa_IntArrayAttr5 : Confined<I64ArrayAttr, [ArrayCount<5>]>;
+def Tosa_IntArrayAttr6 : Confined<I64ArrayAttr, [ArrayCount<6>]>;
+
+def Tosa_IntArrayAttrUpto2 : Confined<I64ArrayAttr, [ArrayMaxCt<2>]>;
+def Tosa_IntArrayAttrUpto4 : Confined<I64ArrayAttr, [ArrayMaxCt<4>]>;
+def Tosa_IntArrayAttrUpto5 : Confined<I64ArrayAttr, [ArrayMaxCt<5>]>;
+
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Supported regimes for tosa.resize.
+def Tosa_ResizeTypeAttr : StringBasedAttr<
+ CPred<"$_self.cast<StringAttr>().getValue() == \"BILINEAR\" || " #
+ "$_self.cast<StringAttr>().getValue() == \"NEAREST_NEIGHBOR\"">,
+ "Supported resize/upsampling strategies">;
+
+def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
+
+// Tensor to buffer types.
+def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
+def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
+def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
+
+#endif // TOSA_TYPES_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..d9b5375188b8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
+add_public_tablegen_target(MLIRTosaPassIncGen)
+add_dependencies(mlir-headers MLIRTosaPassIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc TosaPasses ./)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/PassDetail.h b/mlir/include/mlir/Dialect/Tosa/Transforms/PassDetail.h
new file mode 100644
index 000000000000..2b5822466d15
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/PassDetail.h
@@ -0,0 +1,21 @@
+//===- PassDetail.h - TOSA Pass class details -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H
+#define DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+} // end namespace mlir
+
+#endif // DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
new file mode 100644
index 000000000000..7742281568c1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -0,0 +1,31 @@
+//===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the optimization passes for the TOSA Dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
+std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
new file mode 100644
index 000000000000..25a4400ca267
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -0,0 +1,37 @@
+//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the optimization passes for the TOSA Dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Pass/PassBase.td"
+
+def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
+ let summary = "TOSA rank Reshape to enable Broadcasting";
+ let description = [{
+ Pass that enables broadcasting by making all input arrays have the same
+ number of dimensions. It inserts RESHAPE operations to prepend dimensions
+ of size one until the number of dimensions is equal, implementing an
+ approach similar to step 1 of NumPy 4-step broadcasting:
+ https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting
+ }];
+
+ let constructor = "createTosaMakeBroadcastablePass()";
+}
+
+// TOSA Test Passes
+
+def TosaTestQuantUtils : FunctionPass<"tosa-test-quant-utils"> {
+ let summary = "TOSA Test: Exercise the APIs in QuantUtils.cpp";
+ let description = [{
+ Exercises the API that builds a quantized type from min/max quantized range.
+ }];
+
+ let constructor = "createTosaTestQuantUtilAPIPass()";
+}
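As a hedged illustration of what the tosa-make-broadcastable rewrite produces (op names come from the dialect above; the shapes and the `new_shape` attribute spelling are assumptions for this sketch):

```mlir
// Before: operand ranks differ (2 vs. 1), so elementwise broadcast
// cannot be expressed directly.
%0 = "tosa.add"(%a, %b)
       : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>

// After --tosa-make-broadcastable: the lower-rank operand is reshaped
// with a prepended unit dimension so both inputs have equal rank.
%r = "tosa.reshape"(%b) {new_shape = [1, 3]}
       : (tensor<3xf32>) -> tensor<1x3xf32>
%1 = "tosa.add"(%a, %r)
       : (tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32>
```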
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
new file mode 100644
index 000000000000..d4e2016112eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -0,0 +1,68 @@
+//===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Function declarations for TOSA numerical support functions and quantization
+// attribute builders
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_UTILS_QUANT_UTILS_H
+#define DIALECT_TOSA_UTILS_QUANT_UTILS_H
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// Utility functions to support quantization handling in Tosa.
+//===----------------------------------------------------------------------===//
+
+/// From a scale value, computes multiplier and shift values
+/// for 16- or 32-bit scale widths.
+void computeMultiplierAndShift(double scale, int32_t &multiplier,
+ int32_t &shift, int32_t scaleWidth);
+
+/// Builds ConvOpQuantizationAttr from input and weight.
+ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
+ Value input, Value weight);
+
+/// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
+MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
+ Value a, Value b);
+
+/// Builds UnaryOpQuantizationAttr for unary operations from input values.
+UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder,
+ Value input,
+ Type outputRawType);
+
+/// Builds PadOpQuantizationAttr for pad operations from input values.
+PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder,
+ Value input);
+
+/// Constructs ConvOp output type with the correct bitwidth based on
+/// input/weight width.
+Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input,
+ Value weight);
+
+/// Builds Tosa quantization attributes from min/max values.
+Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr,
+ Attribute maxAttr, IntegerAttr quantBits,
+ int filterQuantDim, bool isSigned,
+ BoolAttr narrowRange);
+
+/// Builds Tosa quantization attributes from min/max values.
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType,
+ Attribute minAttr, Attribute maxAttr,
+ IntegerAttr quantBits, int filterQuantDim,
+ bool isSigned, BoolAttr narrowRange);
+
+#endif // DIALECT_TOSA_UTILS_QUANT_UTILS_H
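To illustrate the idea behind `computeMultiplierAndShift`, here is a minimal C++ sketch for the 32-bit scale width. This is not the committed implementation, only an assumed decomposition of `scale ≈ multiplier * 2^(-shift)` with the multiplier normalized into [2^30, 2^31):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hedged sketch (not the MLIR implementation): decompose a positive
// floating-point scale into a 32-bit fixed-point multiplier and a
// right-shift amount such that scale ~= multiplier * 2^(-shift).
void computeMultiplierAndShift32(double scale, int32_t &multiplier,
                                 int32_t &shift) {
  assert(scale > 0.0 && "expected a positive scale");
  int exponent;
  // scale = mantissa * 2^exponent, with mantissa in [0.5, 1).
  double mantissa = std::frexp(scale, &exponent);
  // Map the mantissa onto a 31-bit fixed-point integer.
  auto m = static_cast<int64_t>(std::round(mantissa * (1LL << 31)));
  if (m == (1LL << 31)) { // rounding carried out of range; renormalize
    m >>= 1;
    ++exponent;
  }
  multiplier = static_cast<int32_t>(m); // in [2^30, 2^31)
  shift = 31 - exponent;                // scale ~= multiplier * 2^(-shift)
}
```

For example, a scale of 0.5 yields multiplier 2^30 with shift 31, since 2^30 / 2^31 = 0.5.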
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e32d877946e5..f0adcec2e664 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -33,6 +33,7 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Dialect.h"
@@ -60,7 +61,8 @@ inline void registerAllDialects(DialectRegistry ®istry) {
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,
SDBMDialect,
- shape::ShapeDialect>();
+ shape::ShapeDialect,
+ tosa::TosaDialect>();
// clang-format on
}
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 7d0a7726ea6c..11cf6ca5f8fb 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -24,6 +24,7 @@
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
#include <cstdlib>
@@ -54,6 +55,7 @@ inline void registerAllPasses() {
registerShapePasses();
spirv::registerSPIRVPasses();
registerStandardPasses();
+ tosa::registerTosaOptPasses();
}
} // namespace mlir
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 24ffb192338a..bc44049e2ef6 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -14,6 +14,7 @@ add_subdirectory(SDBM)
add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(StandardOps)
+add_subdirectory(Tosa)
add_subdirectory(Vector)
set(LLVM_OPTIONAL_SOURCES
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
new file mode 100644
index 000000000000..ddfe25f80f17
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_dialect_library(MLIRTosa
+ Utils/QuantUtils.cpp
+ IR/TosaOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+
+ DEPENDS
+ MLIRStandardOpsIncGen
+ MLIRTosaOpsIncGen
+ MLIRTosaStructsIncGen
+ MLIRTosaInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRStandard
+ MLIRCallInterfaces
+ MLIRControlFlowInterfaces
+ MLIRSideEffectInterfaces
+ MLIRViewLikeInterface
+ )
+
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
new file mode 100644
index 000000000000..9e27cbe73714
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -0,0 +1,273 @@
+//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This file implements the TOSA Specification:
+// https://developer.mlplatform.org/w/tosa/
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// Tosa dialect structs and interface includes.
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
+#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Dialect Function Inliner Interface.
+//===----------------------------------------------------------------------===//
+struct TosaInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+
+ //===--------------------------------------------------------------------===//
+ // Analysis Hooks.
+ //===--------------------------------------------------------------------===//
+
+ /// All operations can be inlined by default.
+ bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
+ BlockAndValueMapping &map) const final {
+ return true;
+ }
+
+ /// All regions with If and While parent operators can be inlined.
+ bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+ BlockAndValueMapping &map) const final {
+ return (isa<tosa::IfOp>(dest->getParentOp()) ||
+ isa<tosa::WhileOp>(dest->getParentOp()));
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// TOSA control flow support.
+//===----------------------------------------------------------------------===//
+
+/// Returns the while loop body.
+Region &tosa::WhileOp::getLoopBody() { return body(); }
+
+bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
+ return !body().isAncestor(value.getParentRegion());
+}
+
+LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
+ if (ops.empty())
+ return success();
+
+ Operation *tosaWhileOp = this->getOperation();
+ for (auto *op : ops)
+ op->moveBefore(tosaWhileOp);
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Tosa dialect initialization.
+//===----------------------------------------------------------------------===//
+
+void TosaDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
+ >();
+ addInterfaces<TosaInlinerInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+template <typename T> static LogicalResult verifyConvOp(T op) {
+ // All TOSA conv ops have an input() and weight().
+ auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
+ auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
+
+ // Must be ranked tensor types
+ if (!inputType || !weightType)
+ return failure();
+
+ auto inputQType =
+ inputType.getElementType().template isa<mlir::quant::QuantizedType>();
+ auto weightQType =
+ weightType.getElementType().template isa<mlir::quant::QuantizedType>();
+
+ // Either both must be quantized or both unquantized.
+ if (inputQType != weightQType)
+ return failure();
+
+ // Quantized types must have a constructed quantization attr, and
+ // unquantized types should not have a quantization attr.
+ if ((inputQType && !op.quantization_info()) ||
+ (!inputQType && op.quantization_info()))
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Quantization Builders.
+//===----------------------------------------------------------------------===//
+
+/// This builder is called on all convolution operators except TransposeConv,
+/// which has specialized output shape semantics. The builder also defines the
+/// bitwidth of the output given the bit width of the input & weight content.
+void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value weight,
+ Value bias, ArrayAttr pad, ArrayAttr stride,
+ ArrayAttr dilation) {
+
+ result.addOperands({input, weight, bias});
+ result.addAttribute("pad", pad);
+ result.addAttribute("stride", stride);
+ result.addAttribute("dilation", dilation);
+
+ auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
+ if (quantAttr) {
+ result.addAttribute("quantization_info", quantAttr);
+ result.addTypes(
+ buildConvOpResultTypeInfo(builder, outputType, input, weight));
+ } else {
+ result.addTypes(outputType);
+ }
+}
+
+/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
+void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value weight,
+ Value bias, ArrayAttr outpad,
+ ArrayAttr stride, ArrayAttr dilation,
+ ArrayAttr outputShape) {
+ result.addOperands({input, weight, bias});
+ result.addAttribute("out_pad", outpad);
+ result.addAttribute("stride", stride);
+ result.addAttribute("dilation", dilation);
+ result.addAttribute("out_shape", outputShape);
+ auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
+
+ if (quantAttr) {
+ result.addAttribute("quantization_info", quantAttr);
+ result.addTypes(
+ buildConvOpResultTypeInfo(builder, outputType, input, weight));
+ } else {
+ result.addTypes(outputType);
+ }
+}
+
+/// The tosa.fully_connected op has its own builder as it does not have
+/// strides/dilation/padding.
+void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value weight,
+ Value bias) {
+
+ result.addOperands({input, weight, bias});
+ auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
+ if (quantAttr) {
+ result.addAttribute("quantization_info", quantAttr);
+ result.addTypes(
+ buildConvOpResultTypeInfo(builder, outputType, input, weight));
+ } else {
+ result.addTypes(outputType);
+ }
+}
+
+/// The tosa.matmul op is also intended to be generated when a fully_connected
+/// op must be constructed but the weight is not a constant. In this case,
+/// the fully_connected op must be expressed using matmul.
+/// TODO: Add link to the legalization document explaining this.
+void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value a, Value b) {
+ result.addOperands({a, b});
+ auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
+
+ if (quantAttr) {
+ result.addAttribute("quantization_info", quantAttr);
+
+ auto inputType = a.getType().dyn_cast<RankedTensorType>();
+ assert(inputType && "Input must be a ranked tensor type!");
+
+ auto inputQType = inputType.getElementType()
+ .dyn_cast<mlir::quant::UniformQuantizedType>();
+ assert(inputQType && "Tensor must have quantized datatype!");
+
+ unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+
+ auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
+ assert(outputShapedType && "Output must be a ranked tensor type");
+
+ auto outputShape = outputShapedType.getShape();
+
+ IntegerType accElementType;
+ if (inputBits == 16)
+ accElementType = builder.getIntegerType(48);
+ else
+ accElementType = builder.getI32Type();
+ auto accType = RankedTensorType::get(outputShape, accElementType);
+ result.addTypes(accType);
+ } else {
+ result.addTypes(outputType);
+ }
+}
+
+/// Both the tosa.avg_pool2d op and the unary ops use the same
+/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
+/// as it has additional parameters not present in the unary ops.
+void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input,
+ ArrayAttr kernel, ArrayAttr stride,
+ ArrayAttr pad) {
+ result.addOperands(input);
+ result.addAttribute("kernel", kernel);
+ result.addAttribute("stride", stride);
+ result.addAttribute("pad", pad);
+ auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
+ if (quantAttr)
+ result.addAttribute("quantization_info", quantAttr);
+ result.types.push_back(outputType);
+}
+
+/// This builder is called on single-parameter unary operators that have a
+/// scale relationship between their input and output, expressed by the
+/// UnaryOpQuantizationAttr.
+void buildUnaryOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input) {
+ result.addOperands(input);
+ auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
+ if (quantAttr)
+ result.addAttribute("quantization_info", quantAttr);
+ result.types.push_back(outputType);
+}
+
+/// This builder is called on the TOSA pad operator that needs to create its
+/// own OptionalAttr quantization_info parameter to scale the padding values
+/// correctly.
+void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+ Type outputType, Value input, Value paddings) {
+ result.addOperands({input, paddings});
+ auto quantAttr = buildPadOpQuantizationAttr(builder, input);
+ if (quantAttr)
+ result.addAttribute("quantization_info", quantAttr);
+ result.types.push_back(outputType);
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Definitions.
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
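The quantized builders above pick an accumulator element type from the operand storage widths: 16-bit inputs accumulate into i48 (conv additionally requires 8-bit weights), everything else into i32. A minimal standalone sketch of that rule, with hypothetical function names not part of the dialect:

```cpp
#include <cassert>

// Accumulator bit width for a quantized conv, mirroring
// buildConvOpResultTypeInfo: i48 only for 16-bit inputs with 8-bit weights.
unsigned convAccBits(unsigned inputBits, unsigned weightBits) {
  return (inputBits == 16 && weightBits == 8) ? 48u : 32u;
}

// Accumulator bit width for a quantized matmul, mirroring
// buildMatMulOpWithQuantInfo: i48 for any 16-bit input.
unsigned matmulAccBits(unsigned inputBits) {
  return inputBits == 16 ? 48u : 32u;
}
```

The wider i48 accumulator avoids overflow when summing many 16-bit products.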
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..04acbf6425b7
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRTosaTransforms
+ TosaMakeBroadcastable.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
+
+ DEPENDS
+ MLIRTosaPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRPass
+ MLIRTosa
+ )
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
new file mode 100644
index 000000000000..95076eb155a3
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -0,0 +1,272 @@
+//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Insert reshape to binary op's input if needed to match rank
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+/// There are two potential ways of implementing broadcast:
+/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
+/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
+/// TBD: option (a) is used for now.
+
+/// In this pass, we insert RESHAPE operators to increase the rank of the
+/// lower-rank operand as a first step in the broadcasting process. The TOSA
+/// operators that support broadcast require that the ranks of the operands
+/// be equal.
+
+// Examples:
+// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
+// TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into
+// [1, b, 1].
+// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
+// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
+// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
+// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
+// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
+// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
+// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+
+static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
+ ArrayRef<int64_t> lowerRankShape,
+ SmallVectorImpl<int64_t> &reshapeOutputShape) {
+ // Initialize the new shape with [1] * higherRank.
+ int64_t higherRank = higherRankShape.size();
+ int64_t lowerRank = lowerRankShape.size();
+
+ reshapeOutputShape.assign(higherRank, 1);
+
+ int64_t higherLeftIndex = 0;
+ int64_t higherRightIndex = higherRank;
+ int64_t lowerLeftIndex = 0;
+ int64_t lowerRightIndex = lowerRank;
+ int64_t higherRankDim, lowerRankDim;
+
+ if (lowerRightIndex != 0 && higherRightIndex != 0) {
+ // Match the lower-rank shape against the higher-rank shape from the
+ // rightmost dimension, until a mismatch or dimension 0 is reached.
+ while (true) {
+ higherRankDim = higherRankShape[higherRightIndex - 1];
+ lowerRankDim = lowerRankShape[lowerRightIndex - 1];
+ if (higherRankDim != lowerRankDim)
+ break;
+
+ reshapeOutputShape[higherRightIndex - 1] = higherRankDim;
+
+ if (higherRightIndex > 0)
+ higherRightIndex--;
+
+ if (lowerRightIndex > 0)
+ lowerRightIndex--;
+
+ if (higherRightIndex == 0 || lowerRightIndex == 0)
+ break;
+ }
+ if (lowerRightIndex != 0 && higherRightIndex != 0) {
+ // Then match the remaining dimensions from the left, until a mismatch
+ // or the right index is reached.
+ while (true) {
+ higherRankDim = higherRankShape[higherLeftIndex];
+ lowerRankDim = lowerRankShape[lowerLeftIndex];
+ if (higherRankDim != lowerRankDim)
+ break;
+
+ reshapeOutputShape[higherLeftIndex] = higherRankDim;
+
+ if (higherLeftIndex < higherRightIndex)
+ higherLeftIndex++;
+
+ if (lowerLeftIndex < lowerRightIndex)
+ lowerLeftIndex++;
+
+ if (higherLeftIndex == higherRightIndex ||
+ lowerLeftIndex == lowerRightIndex)
+ break;
+ }
+ }
+ }
+}
+
+/// Common code to create the reshape op where necessary to make the ranks of
+/// two operands equal. Returns the updated input1 and input2 for the original
+/// inputs. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
+static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
+ RankedTensorType outputType, Value input1,
+ Value input2, Value &outInput1,
+ Value &outInput2) {
+
+ int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
+ int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
+
+ Value higherTensorValue, lowerTensorValue;
+ // Return early if the ranks already match.
+ if (input1Rank == input2Rank) {
+ return 1;
+ } else if (input1Rank > input2Rank) {
+ higherTensorValue = input1;
+ lowerTensorValue = input2;
+ } else {
+ higherTensorValue = input2;
+ lowerTensorValue = input1;
+ }
+
+ ArrayRef<int64_t> outputRankShape = outputType.getShape();
+ ArrayRef<int64_t> higherRankShape =
+ higherTensorValue.getType().cast<RankedTensorType>().getShape();
+ ArrayRef<int64_t> lowerRankShape =
+ lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+
+ // outputRank == higherRank == max(input1Rank, input2Rank)
+ assert(higherRankShape.size() == outputRankShape.size());
+
+ SmallVector<int64_t, 4> reshapeOutputShape;
+
+ computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
+
+ auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeOutputType = RankedTensorType::get(
+ ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+
+ auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
+ loc, reshapeOutputType, lowerTensorValue,
+ rewriter.getI64ArrayAttr(reshapeOutputShape));
+
+ if (input1Rank > input2Rank) {
+ outInput1 = higherTensorValue;
+ outInput2 = reshapeLower.getResult();
+ } else {
+ outInput1 = reshapeLower.getResult();
+ outInput2 = higherTensorValue;
+ }
+
+ return 0;
+}
+
+namespace {
+template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
+ PatternRewriter &rewriter) const {
+
+ Value input1 = tosaBinaryOp.input1();
+ Value input2 = tosaBinaryOp.input2();
+ Value output = tosaBinaryOp.getResult();
+ auto outputType = output.getType().cast<RankedTensorType>();
+
+ Value outInput1, outInput2;
+ if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
+ input1, input2, outInput1, outInput2))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
+ outInput2);
+
+ return success();
+ }
+};
+
+// The MulOp has an extra parameter 'shift' not present in other elementwise
+// binary ops, which necessitates special handling of its builder.
+template <>
+struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
+ using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
+ PatternRewriter &rewriter) const {
+
+ Value input1 = tosaBinaryOp.input1();
+ Value input2 = tosaBinaryOp.input2();
+ int32_t shift = tosaBinaryOp.shift();
+ Value output = tosaBinaryOp.getResult();
+ auto outputType = output.getType().cast<RankedTensorType>();
+
+ Value outInput1, outInput2;
+ if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
+ input1, input2, outInput1, outInput2))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
+ outInput1, outInput2, shift);
+
+ return success();
+ }
+};
+
+// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
+// other elementwise binary ops, which necessitates special handling of its
+// builder.
+template <>
+struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
+ : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
+ using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
+ PatternRewriter &rewriter) const {
+
+ Value input1 = tosaBinaryOp.input1();
+ Value input2 = tosaBinaryOp.input2();
+ int32_t round = tosaBinaryOp.round();
+ Value output = tosaBinaryOp.getResult();
+ auto outputType = output.getType().cast<RankedTensorType>();
+
+ Value outInput1, outInput2;
+ if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
+ input1, input2, outInput1, outInput2))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
+ tosaBinaryOp, outputType, outInput1, outInput2, round);
+
+ return success();
+ }
+};
+} // end anonymous namespace
+
+namespace {
+/// Pass that enables broadcast by making all input arrays have the same
+/// number of dimensions. It inserts RESHAPE operators on the lower-rank
+/// operand.
+struct TosaMakeBroadcastable
+ : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
+public:
+ void runOnFunction() override {
+ auto func = getFunction();
+ OwningRewritePatternList patterns;
+ MLIRContext *ctx = func.getContext();
+ // Add the generated patterns to the list.
+ patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::SubOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::MulOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::MaximumOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::MinimumOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::EqualOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::GreaterOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
+ patterns.insert<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
+ }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
+ return std::make_unique<TosaMakeBroadcastable>();
+}
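The shape computed by computeReshapeOutput above (right-to-left matching first, then left-to-right) can be exercised without any MLIR dependencies. This is a hedged standalone sketch of the same matching logic, not the dialect's implementation; it reproduces the examples listed in the pass's header comment:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standalone sketch of computeReshapeOutput: build the reshape target shape
// for the lower-rank operand. Unmatched dimensions stay 1.
std::vector<int64_t> computeReshapeShape(const std::vector<int64_t> &higher,
                                         const std::vector<int64_t> &lower) {
  std::vector<int64_t> out(higher.size(), 1);
  int64_t hi = higher.size(), lo = lower.size();
  // Match from the rightmost dimension first.
  while (hi > 0 && lo > 0 && higher[hi - 1] == lower[lo - 1]) {
    out[hi - 1] = higher[hi - 1];
    --hi;
    --lo;
  }
  // Then match the remaining dimensions from the left.
  int64_t hl = 0, ll = 0;
  while (hl < hi && ll < lo && higher[hl] == lower[ll]) {
    out[hl] = higher[hl];
    ++hl;
    ++ll;
  }
  return out;
}
```

For example, lower=[a, c] against target=[a, b, c] yields [a, 1, c], and lower=[a] against target=[a, a] yields [1, a] because right-to-left matching wins.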
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
new file mode 100644
index 000000000000..16ddd9f7383a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -0,0 +1,350 @@
+//===- QuantUtils.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains TOSA numerical support functions and quantization
+// attribute builders.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+/// From a scale value, generates multiplier and shift values such that
+/// scale ~= multiplier * 2^-shift, where the multiplier is the mantissa
+/// (in [-1.0, -0.5] or [0.5, 1.0]) encoded as a 16-bit fixed-point value.
+void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier,
+ int32_t &shift) {
+
+ const double mantissa = std::frexp(scale, &shift);
+ auto shiftedM = std::round(mantissa * (int64_t(1) << 15));
+
+ // Can't be greater than 1.0.
+ assert(shiftedM <= (int64_t(1) << 15) &&
+ "Shifted mantissa exceeds 16 signed bits");
+
+ if (shiftedM == (int64_t(1) << 15)) {
+ shiftedM /= 2;
+ shift++;
+ }
+
+ // TOSA expects right shift to be positive and embed (1 << 15) into right
+ // shift bits.
+ shift = (-shift) + 15;
+
+ assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
+ "Shifted mantissa exceeds 32-bit signed output type");
+
+ multiplier = static_cast<int32_t>(shiftedM);
+}
+
+/// From a scale value, generates multiplier and shift values such that
+/// scale ~= multiplier * 2^-shift, where the multiplier is the mantissa
+/// (in [-1.0, -0.5] or [0.5, 1.0]) encoded as a 32-bit fixed-point value.
+void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier,
+ int32_t &shift) {
+
+ const double mantissa = std::frexp(scale, &shift);
+ auto shiftedM = std::round(mantissa * (int64_t(1) << 31));
+
+ // Can't be greater than 1.0.
+ assert(shiftedM <= (int64_t(1) << 31) &&
+ "Shifted mantissa exceeds 32 signed bits");
+ if (shiftedM == (int64_t(1) << 31)) {
+ shiftedM /= 2;
+ shift++;
+ }
+
+ // TOSA expects right shift to be positive, and embed (1 << 31) into right
+ // shift bits.
+ shift = (-shift) + 31;
+
+ assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
+ "Shifted mantissa exceeds 32-bit signed output type");
+
+ multiplier = static_cast<int32_t>(shiftedM);
+}
+
+/// Generates a quantized multiplier/shift from double.
+void computeMultiplierAndShift(double scale, int32_t &multiplier,
+ int32_t &shift, int32_t scaleWidth) {
+
+ switch (scaleWidth) {
+ case 16:
+ computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
+ return;
+ case 32:
+ computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
+ return;
+ default:
+ assert(0 && "Unsupported Tosa quantized_scale regime specified!");
+ }
+}
+
+#define GET_UQTYPE(input_type) \
+ ((input_type).getElementType().dyn_cast<quant::UniformQuantizedType>())
+#define GET_QTYPE(input_type) \
+ ((input_type).getElementType().dyn_cast<quant::QuantizedType>())
+
+/// Method to build ConvOpQuantizationAttr, called from
+/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
+/// input_zp: input zeropoint
+/// weight_zp: weight zeropoint.
+ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
+ Value input, Value weight) {
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto weightType = weight.getType().dyn_cast<RankedTensorType>();
+
+ if (!inputType || !weightType)
+ return nullptr;
+
+ auto inputQType = GET_UQTYPE(inputType);
+ auto weightPerTensorQType = GET_UQTYPE(weightType);
+ auto weightPerAxisQType = weightType.getElementType()
+ .dyn_cast<quant::UniformQuantizedPerAxisType>();
+
+ // Weights must be either per-tensor quantized or per-axis quantized.
+ assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
+ "Weights must be either per-tensor or per-axis quantized");
+
+ // Either all quantized or all not quantized.
+ assert(!((bool)inputQType ^
+ ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
+ "Inputs and weights must be all quantized or all not quantized");
+
+ if (inputQType) {
+
+ int64_t inputZp = inputQType.getZeroPoint();
+ int64_t weightZp = 0;
+
+ if (weightPerTensorQType) {
+ weightZp = weightPerTensorQType.getZeroPoint();
+ } else if (weightPerAxisQType) {
+ weightZp = weightPerAxisQType.getZeroPoints().front();
+ }
+
+ auto quantAttr = tosa::ConvOpQuantizationAttr::get(
+ builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp),
+ builder.getContext());
+
+ return quantAttr;
+ }
+
+ return nullptr;
+}
+
+/// Builds MatMulOpQuantizationAttr, called from
+/// MatMulOpQuantInfoBuilder:
+/// aZp: input a zeropoint
+/// bZp: input b zeropoint.
+MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
+ Value a, Value b) {
+
+ auto aType = a.getType().dyn_cast<RankedTensorType>();
+ auto bType = b.getType().dyn_cast<RankedTensorType>();
+
+ if (!aType || !bType)
+ return nullptr;
+
+ auto aQType = GET_UQTYPE(aType);
+ auto bQType = GET_UQTYPE(bType);
+
+ // A and B are either all quantized or all not quantized.
+ assert(!((bool)aQType ^ (bool)bQType) &&
+ "Matmul operands must be all quantized or all not quantized");
+
+ if (aQType) {
+
+ int64_t aZp = aQType.getZeroPoint();
+ int64_t bZp = bQType.getZeroPoint();
+
+ auto quantAttr = tosa::MatMulOpQuantizationAttr::get(
+ builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp),
+ builder.getContext());
+
+ return quantAttr;
+ }
+
+ return nullptr;
+}
+
+/// Builds UnaryOpQuantizationAttr, called from
+/// UnaryOpQuantInfoBuilder:
+/// inputZp: input zeropoint
+/// outputZp: output zeropoint.
+UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder,
+ Value input,
+ Type outputRawType) {
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto outputType = outputRawType.dyn_cast<RankedTensorType>();
+
+ if (!inputType || !outputType)
+ return nullptr;
+
+ auto inputQType = GET_UQTYPE(inputType);
+ auto outputQType = GET_UQTYPE(outputType);
+
+ // Either all quantized or all not quantized.
+ assert(!((bool)inputQType ^ (bool)outputQType) &&
+ "Unary inputs/outputs must be all quantized or all not quantized");
+
+ if (inputQType) {
+
+ int64_t inputZp = inputQType.getZeroPoint();
+ int64_t outputZp = outputQType.getZeroPoint();
+
+ auto quantAttr = tosa::UnaryOpQuantizationAttr::get(
+ builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp),
+ builder.getContext());
+
+ return quantAttr;
+ }
+
+ return nullptr;
+}
+
+/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
+/// inputZp: input zeropoint.
+PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder,
+ Value input) {
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+
+ if (!inputType)
+ return nullptr;
+
+ auto inputQType = GET_UQTYPE(inputType);
+
+ if (inputQType) {
+
+ int64_t inputZp = inputQType.getZeroPoint();
+
+ auto quantAttr = tosa::PadOpQuantizationAttr::get(
+ builder.getI32IntegerAttr(inputZp), builder.getContext());
+
+ return quantAttr;
+ }
+
+ return nullptr;
+}
+
+/// Builds output type for a quantized ConvOp with the right bitwidth.
+/// This is called by the builder when dealing with quantized content.
+Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input,
+ Value weight) {
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto weightType = weight.getType().dyn_cast<RankedTensorType>();
+
+ assert(inputType && weightType &&
+ "Could not extract input or weight tensors from Conv op");
+
+ auto inputQType = GET_QTYPE(inputType);
+ auto weightQType = GET_QTYPE(weightType);
+
+ assert(inputQType && weightQType &&
+ "Could not extract input or weight tensor types from Conv op");
+
+ unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+ unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
+
+ auto outputShapedType = outputType.dyn_cast<RankedTensorType>();
+ assert(outputShapedType &&
+ "Could not extract output shape type from Conv op");
+
+ auto outputShape = outputShapedType.getShape();
+
+ IntegerType accElementType;
+ if (inputBits == 16 && weightBits == 8)
+ accElementType = builder.getIntegerType(48);
+ else
+ accElementType = builder.getI32Type();
+ auto accType = RankedTensorType::get(outputShape, accElementType);
+ return accType;
+}
+
+/// Builds a TOSA quantized type from min/max values.
+Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr,
+ Attribute maxAttr, IntegerAttr quantBits,
+ int filterQuantDim, bool isSigned,
+ BoolAttr narrowRange) {
+
+ quant::QuantizedType retType;
+
+ auto convfunc =
+ quant::ExpressedToQuantizedConverter::forInputType(inputDType);
+
+ auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
+ auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
+
+ SmallVector<double, 2> min, max;
+
+ // Per-axis quantization: both min and max are per-axis ElementsAttrs.
+ if (minElems && maxElems) {
+ // Both must have the same number of elements.
+ if (minElems.getNumElements() != maxElems.getNumElements())
+ return {};
+ min.reserve(minElems.getNumElements());
+ max.reserve(maxElems.getNumElements());
+ for (auto i : minElems)
+ min.push_back(FloatAttr::getValueAsDouble(i));
+ for (auto i : maxElems)
+ max.push_back(FloatAttr::getValueAsDouble(i));
+ } else { // Just a single FP value.
+ auto minVal = minAttr.dyn_cast<FloatAttr>();
+ if (minVal)
+ min.push_back(minVal.getValueAsDouble());
+ else
+ return {};
+ auto maxVal = maxAttr.dyn_cast<FloatAttr>();
+ if (maxVal)
+ max.push_back(maxVal.getValueAsDouble());
+ else
+ return {};
+ }
+
+ if (min.size() == max.size()) {
+ if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
+ retType = quant::fakeQuantAttrsToType(
+ builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
+ narrowRange.getValue(), convfunc.expressedType, isSigned);
+ } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
+ auto shape = inputDType.dyn_cast<ShapedType>();
+ if (!shape)
+ return {};
+ if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
+ retType = quant::fakeQuantAttrsToType(
+ builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min,
+ max, narrowRange.getValue(), convfunc.expressedType, isSigned);
+ }
+ } else {
+ return {};
+ }
+ } else {
+ return {};
+ }
+
+ if (!retType)
+ return {};
+
+ return convfunc.convert(retType);
+}
+
+/// Wraps the TOSA quantized type built from min/max values in a TypeAttr.
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype,
+ Attribute minAttr, Attribute maxAttr,
+ IntegerAttr quantBits, int filterQuantDim,
+ bool isSigned, BoolAttr narrowRange) {
+
+ return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
+ maxAttr, quantBits, filterQuantDim,
+ isSigned, narrowRange));
+}
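The frexp-based scale decomposition in computeMultiplierAndShiftTosaScale32 can be checked standalone. This is a hedged sketch of the same arithmetic (not the dialect's implementation), producing scale ~= multiplier * 2^-shift with the multiplier a Q31 fixed-point mantissa:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decompose scale into (multiplier, shift) such that
// scale ~= multiplier * 2^-shift, multiplier in Q31 fixed point
// (mirrors the 32-bit regime above).
void computeMultiplierAndShift32(double scale, int32_t &multiplier,
                                 int32_t &shift) {
  int exp = 0;
  const double mantissa = std::frexp(scale, &exp); // scale = mantissa * 2^exp
  auto shiftedM =
      static_cast<int64_t>(std::round(mantissa * (int64_t(1) << 31)));
  if (shiftedM == (int64_t(1) << 31)) { // Rounding pushed the mantissa to 1.0.
    shiftedM /= 2;
    ++exp;
  }
  shift = 31 - exp; // TOSA encodes a positive right shift.
  multiplier = static_cast<int32_t>(shiftedM);
}
```

For scale = 0.5 this yields multiplier = 2^30 and shift = 31, since 2^30 * 2^-31 = 0.5.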
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
new file mode 100644
index 000000000000..98d2352b739a
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: broadcast0
+func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
+ // CHECK-NOT: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast1
+func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
+ return %0 : tensor<2x1xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast2
+func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+ return %0 : tensor<2x1xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast3
+func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
+ return %0 : tensor<2x1x1x1xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast4
+func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
+ return %0 : tensor<1x1x1x2xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast5
+func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
+ return %0 : tensor<1x1x2x1xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast6
+func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast7
+func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
+ return %0 : tensor<17x16x1x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast8
+func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast9
+func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast10
+func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast13
+func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast14
+func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
+ return %0 : tensor<17x16x1x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast15
+func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast16
+func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast17
+func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+ // CHECK: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast18
+func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
+ // CHECK: add
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
+ return %0 : tensor<14x15xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_mul
+func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
+ // CHECK: reshape
+ %0 = "tosa.mul"(%arg0, %arg1) {shift = 1 : i32 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+ return %0 : tensor<17x16x15x14xi32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_arithmetic_right_shift
+func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
+ // CHECK: reshape
+ %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+ return %0 : tensor<17x16x15x14xi32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_scalar
+func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
+ // CHECK-NEXT: reshape
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+ return %0 : tensor<17x16x15x14xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
new file mode 100644
index 000000000000..e6c1b8c6b450
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
+// CHECK-LABEL: argmax
+func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
+ %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<?xf32>) -> tensor<?xi32>
+ return %0 : tensor<?xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/inlining.mlir b/mlir/test/Dialect/Tosa/inlining.mlir
new file mode 100644
index 000000000000..363358b0781b
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/inlining.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -inline | FileCheck %s
+
+// These tests verify that regions containing operations from the TOSA
+// dialect can be inlined.
+
+// CHECK-LABEL: func @inlined_if_fn
+// Check that both the calls and the functions are eliminated after inlining:
+// CHECK-NOT: @add
+// CHECK-NOT: @sub
+func @inlined_if_fn(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
+ %1 = call @add(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "tosa.yield"(%1) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
+ %1 = call @sub(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "tosa.yield"(%1) : (tensor<f32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+func @add(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+func @sub(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
+ %0 = "tosa.sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @inlined_while_fn
+func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<10xi32> {
+ // Check that calls are inlined and functions eliminated:
+ // CHECK-NOT: @while
+ %1:4 = "tosa.while_loop"(%arg0, %arg1, %arg2, %arg3) ( {
+ ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<10xi32>): // no predecessors
+ %2 = call @while_cond_40(%arg4, %arg5, %arg6, %arg7) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> tensor<i1>
+ "tosa.yield"(%2) : (tensor<i1>) -> ()
+ }, {
+ ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<10xi32>): // no predecessors
+ %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>)
+ "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> ()
+ }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>)
+ return %1#3 : tensor<10xi32>
+}
+func @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) attributes {sym_visibility = "private"} {
+ %1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+ return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
+}
+func @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> attributes {sym_visibility = "private"} {
+ %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %1 = "tosa.logical_not"(%0) : (tensor<i1>) -> tensor<i1>
+ return %1 : tensor<i1>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
new file mode 100644
index 000000000000..f22e6fc7527d
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -0,0 +1,512 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// CHECK-LABEL: argmax
+func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
+ %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<14x19xf32>) -> tensor<14xi32>
+ return %0 : tensor<14xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d
+func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
+ %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
+ return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+// CHECK-LABEL: conv2d
+func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+ return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d
+func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+ %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+ return %2 : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: fully_connected
+func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
+ %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<14x19xf32>, tensor<19x28xf32>, tensor<28xf32>) -> tensor<14x28xf32>
+ return %0 : tensor<14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul
+func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> {
+ %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32>
+ return %0 : tensor<14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d
+func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d
+func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp
+func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.clamp"(%arg0) {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: relu
+func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.reluN"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+
+// -----
+// CHECK-LABEL: sigmoid
+func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: tanh
+func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.tanh"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: add
+func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: arithmetic_right_shift
+func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = false } : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+
+// -----
+// CHECK-LABEL: bitwise_and
+func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_or
+func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_xor
+func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_and
+func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> {
+ %0 = "tosa.logical_and"(%arg0, %arg1) : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: logical_left_shift
+func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_right_shift
+func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_or
+func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ %0 = "tosa.logical_or"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: logical_xor
+func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
+ %0 = "tosa.logical_xor"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: maximum
+func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: minimum
+func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: mul
+func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: pow
+func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: sub
+func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: table
+func @main(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform<i16:f32, 1.0:0>>) -> tensor<64x!quant.uniform<i16:f32, 1.0:0>> {
+ %0 = "tosa.table"(%arg0, %arg1) : (tensor<64xi32>, tensor<513x!quant.uniform<i16:f32, 1.0:0>>) -> tensor<64x!quant.uniform<i16:f32, 1.0:0>>
+ return %0 : tensor<64x!quant.uniform<i16:f32, 1.0:0>>
+}
+
+// -----
+// CHECK-LABEL: abs
+func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_not
+func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> {
+ %0 = "tosa.bitwise_not"(%arg0) : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
+ return %0 : tensor<13x21x1xi32>
+}
+
+// -----
+// CHECK-LABEL: ceil
+func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: clz
+func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.clz"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: exp
+func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: floor
+func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.floor"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: log
+func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: logical_not
+func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> {
+ %0 = "tosa.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
+ return %0 : tensor<1x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: negate
+func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.negate"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reciprocal
+func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.reciprocal"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: rsqrt
+func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: select
+func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+
+// -----
+// CHECK-LABEL: equal
+func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tosa.equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: greater
+func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tosa.greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: greater_equal
+func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> {
+ %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+ return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_all
+func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
+ %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+ %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
+ return %1 : tensor<21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_any
+func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> {
+ %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+ %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
+ return %1 : tensor<21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_max
+func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+ %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_min
+func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+ %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_product
+func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+ %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_sum
+func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
+ %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+ %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+ return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: concat
+func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> {
+ %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+ return %0 : tensor<26x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: pad
+func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape
+func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> {
+ %0 = "tosa.reshape"(%arg0) {new_shape = [1, 819]} : (tensor<13x21x3xf32>) -> tensor<1x819xf32>
+ return %0 : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reverse
+func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: slice
+func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
+ %0 = "tosa.slice"(%arg0) {start = [6, 8, 0], size = [4, 11, 1]} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32>
+ return %0 : tensor<4x11x1xf32>
+}
+
+// -----
+// CHECK-LABEL: tile
+func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
+ %0 = "tosa.tile"(%arg0) {multiples = [3, 1, 2]} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32>
+ return %0 : tensor<39x21x6xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose
+func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> {
+ %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+ return %1 : tensor<3x13x21xf32>
+}
+
+// -----
+// CHECK-LABEL: gather
+func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<26xi32>) -> tensor<26x21x3xi32> {
+ %0 = "tosa.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>) -> tensor<26x21x3xi32>
+ return %0 : tensor<26x21x3xi32>
+}
+
+// Test TBD
+// DISABLED-CHECK-LABEL: resize
+//func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
+// %0 = "tosa.const"() {value = dense<64> : tensor<2xi32>} : () -> tensor<2xi32>
+// %1 = "tosa.resize"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x32x32x8xf32>, tensor<2xi32>) -> tensor<1x64x64x8xf32>
+// return %1 : tensor<1x64x64x8xf32>
+//}
+
+// -----
+// CHECK-LABEL: cast
+func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+ %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: cast2
+func @test_cast2(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<u8:f32, 0.078431375324726104:128>> {
+ %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<u8:f32, 0.078431375324726104:128>>
+ return %0 : tensor<13x21x3x!quant.uniform<u8:f32, 0.078431375324726104:128>>
+}
+
+// -----
+// CHECK-LABEL: cast3
+func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>> {
+ %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+ return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
+}
+
+// -----
+// CHECK-LABEL: rescale
+func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+ %0 = "tosa.rescale"(%arg0) {double_round = false, input_zp = 127 : i32, multiplier = [1073741824 : i32], output_zp = -1 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+}
+
+// -----
+// CHECK-LABEL: const
+func @test_const(%arg0 : index) -> tensor<4xi32> {
+ %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
+
+// -----
+// CHECK-LABEL: identity
+func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ %0 = "tosa.identity"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: identityn
+func @test_identityn(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+ %0:2 = "tosa.identityn"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
+ return %0#0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: placeholder
+func @test_placeholder() -> tensor<1xi32> {
+ %0 = "tosa.placeholder"() : () -> tensor<1xi32>
+ return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: cond_if
+func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
+ %1 = "tosa.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "tosa.yield"(%1) : (tensor<f32>) -> ()
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
+ %1 = "tosa.sub"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ "tosa.yield"(%1) : (tensor<f32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+// CHECK-LABEL: while_loop
+func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+ %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %1:3 = "tosa.while_loop"(%0, %0, %arg0) ( {
+ ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>): // no predecessors
+ %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ "tosa.yield"(%3) : (tensor<i1>) -> ()
+ }, {
+ ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>): // no predecessors
+ %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ %4 = "tosa.reshape"(%2) {new_shape = [1]} : (tensor<i32>) -> tensor<1xi32>
+ %5 = "tosa.add"(%arg4, %4) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+ %6 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ "tosa.yield"(%6, %3, %5) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> ()
+ }) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>)
+ return
+}
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
new file mode 100644
index 000000000000..2124defb6b2a
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: test_build_qtype
+func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+ // CHECK: tosa.negate
+ %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+ return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+}
+
+// -----
+// CHECK-LABEL: test_build_mult_and_shift
+func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>> {
+ // CHECK: tosa.conv2d
+ %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 2, 2], dilation = [2, 1], stride = [1, 1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+ return %0 : tensor<1x32x32x16x!quant.uniform<i8:f32, 0.078431375324726104>>
+
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 9008b86314be..b220d0d81632 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(Affine)
add_subdirectory(SPIRV)
add_subdirectory(Test)
+add_subdirectory(Tosa)
diff --git a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt
new file mode 100644
index 000000000000..8d8a07418957
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRTosaTestPasses
+ TosaTestPasses.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
+
+ DEPENDS
+ MLIRTosaPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRPass
+ MLIRTosa
+ )
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
new file mode 100644
index 000000000000..b9728e236013
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -0,0 +1,197 @@
+//===- TosaTestPasses.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test passes to exercise TOSA helper functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+// This transformation converts a quantized uint8 negate to quantized int8.
+// Constructing the new type invokes buildQTypeFromMinMax. It is extracted
+// from the TOSA legalization infrastructure.
+struct ConvertTosaNegateOp : public RewritePattern {
+ explicit ConvertTosaNegateOp(MLIRContext *context)
+ : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
+LogicalResult
+ConvertTosaNegateOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+
+ auto tosaNegateOp = cast<tosa::NegateOp>(op);
+
+ auto inputType =
+ tosaNegateOp.input1().getType().dyn_cast<mlir::RankedTensorType>();
+ // Skip if the input is not a ranked tensor type.
+ if (!inputType)
+ return failure();
+
+ // Skip if the output is not a ranked tensor type.
+ auto outputType =
+ tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ if (!outputType)
+ return failure();
+
+ // Skip if the output is not a per-tensor quantized type.
+ auto outputElementType =
+ outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ if (!outputElementType)
+ return failure();
+
+ // Skip if the output is not uint8.
+ if (outputElementType.isSigned() ||
+ outputElementType.getStorageTypeIntegralWidth() != 8)
+ return failure();
+
+ double typeRangeMin = double(outputElementType.getStorageTypeMin() -
+ outputElementType.getZeroPoint()) *
+ outputElementType.getScale();
+ double typeRangeMax = double(outputElementType.getStorageTypeMax() -
+ outputElementType.getZeroPoint()) *
+ outputElementType.getScale();
+ bool narrow_range = outputElementType.getStorageTypeMin() == 1;
+
+ auto dstQConstType = RankedTensorType::get(
+ outputType.getShape(),
+ buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
+ rewriter.getF64FloatAttr(typeRangeMin),
+ rewriter.getF64FloatAttr(typeRangeMax),
+ rewriter.getI32IntegerAttr(
+ outputElementType.getStorageTypeIntegralWidth()),
+ 0, true /* signed */,
+ rewriter.getBoolAttr(narrow_range)));
+
+ ElementsAttr inputElems;
+ if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems)))
+ return failure();
+
+ auto newConstOp =
+ rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
+ auto newNegateOp = rewriter.create<tosa::NegateOp>(
+ op->getLoc(), dstQConstType, newConstOp.getResult());
+
+ rewriter.replaceOp(op, {newNegateOp.getResult()});
+ return success();
+}
+
+// This transformation rewrites the quantized output of a test conv2d to i32
+// and appends a TOSA rescale after it. Building the rescale op requires
+// invoking computeMultiplierAndShift. It is extracted from the TOSA
+// legalization infrastructure.
+struct ConvertTosaConv2DOp : public RewritePattern {
+ explicit ConvertTosaConv2DOp(MLIRContext *context)
+ : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
+LogicalResult
+ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+
+ auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
+
+ auto inputType =
+ tosaConv2DOp.input().getType().dyn_cast<mlir::RankedTensorType>();
+
+ // Skip if the input is not a ranked tensor type.
+ if (!inputType)
+ return failure();
+
+ auto weightType =
+ tosaConv2DOp.weight().getType().dyn_cast<mlir::RankedTensorType>();
+
+ // Skip if the weight is not a ranked tensor type.
+ if (!weightType)
+ return failure();
+
+ // Skip if the output is not a ranked tensor type.
+ auto outputType =
+ tosaConv2DOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ if (!outputType)
+ return failure();
+
+ auto inputQType =
+ inputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto weightQType =
+ weightType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto outputQType =
+ outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+
+ // Works on quantized type only.
+ if (!(inputQType && weightQType && outputQType))
+ return failure();
+
+ auto newTosaConv2DOpType =
+ RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
+
+ auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
+ op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(),
+ tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(),
+ tosaConv2DOp.stride(), tosaConv2DOp.dilation());
+
+ // Create a rescale back to the quantized output type.
+ double inputScale = inputQType.getScale();
+ double weightScale = weightQType.getScale();
+ double outputScale = outputQType.getScale();
+ int64_t outputZp = outputQType.getZeroPoint();
+
+ double opTensorScale = (inputScale * weightScale) / outputScale;
+
+ int32_t multiplier;
+ int32_t shift;
+
+ // Obtain the quantized scale as a multiplier and shift.
+ computeMultiplierAndShift(opTensorScale, multiplier, shift, 32);
+
+ auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
+ op->getLoc(), outputType, newTosaConv2DOp.getResult(),
+ rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
+ rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
+ rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
+ rewriter.getBoolAttr(false));
+
+ rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
+ return success();
+}
+
+namespace {
+
+struct TosaTestQuantUtilAPI
+ : public TosaTestQuantUtilsBase<TosaTestQuantUtilAPI> {
+ void runOnFunction() override;
+};
+
+void TosaTestQuantUtilAPI::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto *ctx = &getContext();
+ auto func = getFunction();
+
+ patterns.insert<ConvertTosaNegateOp>(ctx);
+ patterns.insert<ConvertTosaConv2DOp>(ctx);
+ applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+
+} // anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaTestQuantUtilAPIPass() {
+ return std::make_unique<TosaTestQuantUtilAPI>();
+}