[llvm] [mlir] [TOSA] Add Tosa_Shape type and ConstShapeOp (PR #122547)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 14 08:26:48 PST 2025
https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/122547
>From f8a63cef9761bd3bca55a8294d019b0acf9e2f71 Mon Sep 17 00:00:00 2001
From: Jerry Ge <Jerry.Ge at arm.com>
Date: Wed, 8 Jan 2025 16:23:25 -0800
Subject: [PATCH] [TOSA] Add Tosa_Shape type and ConstShapeOp
Adds:
1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
rank-4 shape values (size-4 array of index values)
2. const_shape operator
3. trait TosaShapeOperator, added to tosa shape operators, and a
verifier that all operands and results of operator are tosa shapes
4. trait TosaResolvableShapeOperands, added to all tosa operators, and
a verifier that every tosa shape operand is produced by a tosa shape
operator (indicated by trait TosaShapeOperator)
5. trait TosaShapeOperatorWithSameRanks, added to
Tosa_ElementwiseShapeOp and a verifier that all operands and result
shapes have same ranks
6. changed TileOp's multiples from an attribute to an input of !tosa.shape
type.
7. add folder for tosa ConstShape operator
Signed-off-by: Jerry Ge <Jerry.Ge at arm.com>
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
---
.../mlir/Dialect/Tosa/IR/CMakeLists.txt | 3 +-
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 12 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 41 +++++
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 8 +-
.../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 77 ++++++++++
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 65 ++++++++
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 4 +-
.../TosaToLinalg/TosaToLinalgPass.cpp | 1 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 19 ++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 141 ++++++++++++++++--
.../Tosa/Transforms/TosaValidation.cpp | 2 +
.../TosaToLinalg/tosa-to-linalg.mlir | 15 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 9 +-
mlir/test/Dialect/Tosa/invalid.mlir | 33 +++-
mlir/test/Dialect/Tosa/level_check.mlir | 4 +-
mlir/test/Dialect/Tosa/ops.mlir | 10 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 6 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 8 +
18 files changed, 425 insertions(+), 33 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 1ee105f0ceb98b..cc8d5ed9b00449 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,6 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
+mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)
@@ -10,4 +12,3 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
-
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index d3f12c34421b06..47cda3c9f481ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
+ let useDefaultTypePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -217,12 +218,21 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
let cppNamespace = "mlir::OpTrait::tosa";
}
+//===----------------------------------------------------------------------===//
+// TOSA Operator Trait.
+//===----------------------------------------------------------------------===//
+// Trait attached to every TOSA op: each operand of TOSA shape type must be
+// compile-time resolvable, i.e. produced by a TosaShapeOperator op.
+def TosaResolvableShapeOperands
+    : NativeOpTrait<"TosaResolvableShapeOperands"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//
+// Base class of all TOSA ops: appends TosaOpInterface and the
+// TosaResolvableShapeOperands trait to the op-specific trait list.
class Tosa_Op<string mnemonic, list<Trait> traits = []> :
- Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
+    Op<Tosa_Dialect, mnemonic,
+       !listconcat(traits, [TosaOpInterface, TosaResolvableShapeOperands])> {
}
class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..e4f5d09064cd75 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -90,14 +90,55 @@ template <typename ConcreteType>
class TosaElementwiseOperator
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};
+/// Verifies that every operand of tosa shape type of `op` has a defining
+/// operation carrying the TosaShapeOperator trait (block arguments of shape
+/// type are therefore rejected).
+LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
+/// This class verifies that tosa shape operands are compile time resolvable.
+template <typename ConcreteType>
+class TosaResolvableShapeOperands
+    : public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaResolvableShapeOperands(op);
+  }
+};
+
+/// Verifies that all operands and all results of `op` have tosa shape type.
+LogicalResult verifyTosaShapeOperator(Operation *op);
+/// This class indicates that op operates on tosa shape types.
+template <typename ConcreteType>
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperator(op);
+  }
+};
+
+/// Verifies that `op` has at least one operand, satisfies
+/// verifyTosaShapeOperator, and that every operand and result has the same
+/// shape rank as the first operand.
+LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
+/// This class indicates that op operates on tosa shape types and that all of
+/// its operands and results have the same rank.
+template <typename ConcreteType>
+class TosaShapeOperatorWithSameRanks
+    : public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperatorWithSameRanks(op);
+  }
+};
+
} // namespace tosa
} // namespace OpTrait
+namespace tosa {
+
+/// Returns true when `type` is the TOSA shape type (`!tosa.shape<rank>`).
+bool isa_tosa_shape_type(mlir::Type type);
+
+} // namespace tosa
+
} // namespace mlir
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..e1efa7a3001b9f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1689,12 +1689,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
let arguments = (ins
Tosa_Tensor:$input1,
- DenseI64ArrayAttr:$multiples);
+ Tosa_Shape:$multiples);
let results = (outs
Tosa_Tensor:$output
);
+ let extraClassDeclaration = [{
+ LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
+ }];
+
let hasFolder = 1;
let hasVerifier = 1;
}
@@ -2106,4 +2110,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
+include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
+
#endif // TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
new file mode 100644
index 00000000000000..597dc32e84402f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -0,0 +1,77 @@
+//===-- TosaShapeOps.td - TOSA dialect shape operations ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines shape operators for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_SHAPE_OPS
+#define TOSA_SHAPE_OPS
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
+include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
+// Trait: every operand and every result of the operator has TOSA shape type.
+def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+// Base class for TOSA shape operators. Each one is Pure, carries the
+// TosaShapeOperator trait, uses the generic functional-type assembly format
+// and must provide a folder.
+class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+  let hasFolder = 1;
+}
+
+// Trait: a shape operator whose operands and results all share one rank.
+def TosaShapeOperatorWithSameRanks
+    : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+// Base class for element-wise shape operators: a shape op that additionally
+// requires operands and results of equal rank.
+class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_ShapeOp<mnemonic,
+                   !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: ConstShape
+//===----------------------------------------------------------------------===//
+// Note: `Pure` is not listed here because Tosa_ShapeOp already appends it.
+def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike]> {
+  let summary = "Constant Shape op.";
+
+  let description = [{
+    A node containing constant data for use as the input to a shape operation.
+    May hold data only in index data type. The number of elements in `value`
+    must match the rank of the result shape type (checked by the verifier).
+
+    Example:
+
+    ```mlir
+    // Generic form
+    %out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+    ```
+  }];
+
+  let arguments = (ins IndexElementsAttr:$value);
+
+  let results = (outs Tosa_Shape:$output);
+
+  let hasVerifier = 1;
+}
+
+#endif // TOSA_SHAPE_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index d3cc6e92bac227..13325fb0ab9a20 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -13,8 +13,11 @@
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE
+include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
@@ -215,4 +218,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
+//===----------------------------------------------------------------------===//
+// Tosa Dialect Custom Types.
+//===----------------------------------------------------------------------===//
+
+// Common base for every type defined by the Tosa dialect. `typeMnemonic` is
+// the keyword used in the textual form, e.g. `shape` in `!tosa.shape<4>`.
+class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Tosa_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeType
+//===----------------------------------------------------------------------===//
+// Negative ranks are rejected by the generated verifier (genVerifyDecl),
+// implemented as tosa::shapeType::verify.
+def Tosa_Shape : Tosa_Type<"shape", "shape"> {
+  let summary = "Shape with static rank and Index element type";
+  let description = [{
+    Syntax:
+
+    ```
+    shape-type ::= `shape` `<` rank `>`
+    ```
+
+    Values with shape type represent a shape with a fixed rank and a list of
+    dimensions. Rank must be zero or a positive integer. Each dimension is
+    represented by the builtin Index type.
+
+    Examples:
+
+    ```mlir
+    // Shape with rank of four, for example, [1, 1, 8, 16]:
+    !tosa.shape<4>
+
+    // Shape with rank of one, for example, [16]:
+    !tosa.shape<1>
+
+    // Shape with rank zero, for example, [] (i.e., shape of scalar values):
+    !tosa.shape<0>
+    ```
+  }];
+  let parameters = (ins "int" : $rank);
+  let builders = [TypeBuilder<(ins "int" : $rank)>];
+  let assemblyFormat = "`<` $rank `>`";
+
+  let genVerifyDecl = 1;
+}
+
+// Predicate: the type is a TOSA shape type (`!tosa.shape<N>` of any rank).
+def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
+
+// Whether a Tosa Shape type has a rank equal to the specified rank.
+// The IsTosaShapeType check is listed before the cast to shapeType.
+class IsTosaShapeOfRankPred<int rank> : And<[
+ IsTosaShapeType,
+ CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
+]>;
+
+// Type constraint accepting only TOSA shape types whose rank equals `rank`.
+class TosaShapeOfRank<int rank>
+ : Type<IsTosaShapeOfRankPred<rank>, "Tosa shape type of rank " #rank>;
+
+// Commonly used fixed-rank TOSA shape constraints.
+def Rank1TosaShape : TosaShapeOfRank<1>;
+def Rank2TosaShape : TosaShapeOfRank<2>;
+def Rank4TosaShape : TosaShapeOfRank<4>;
+
#endif // TOSA_TYPES_BASE
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1d7ead16e8b631..9295afd36e3ab1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
- ArrayRef<int64_t> multiples = op.getMultiples();
+ SmallVector<int64_t> multiples;
+ if (failed(op.getConstantMultiples(multiples)))
+ return failure();
// Broadcast the newly added dimensions to their appropriate multiple.
SmallVector<int64_t, 2> genericShape;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 06a7262c467421..8dfa55bef74fc4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::ApplyScaleOp>();
target.addLegalOp<tosa::IfOp>();
target.addLegalOp<tosa::ConstOp>();
+ target.addLegalOp<tosa::ConstShapeOp>();
target.addLegalOp<tosa::WhileOp>();
target.addLegalOp<tosa::ConcatOp>();
target.addLegalOp<tosa::SliceOp>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f51c3dbce6eefe..f7a596f1ccb192 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
+OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) {
+  // A constant shape always folds to its `value` attribute.
+  return getValueAttr();
+}
+
#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
- bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
- if (allOnes && getInput1().getType() == getType())
- return getInput1();
+  // Tiling is the identity only when the result type equals the input type
+  // and every multiple is 1.
+  if (getInput1().getType() != getType())
+    return {};
+
+  // Constant multiples normally come from tosa.const_shape, whose value is an
+  // index-typed elements attribute. Only integer/index element attributes are
+  // inspected. getValues<APInt>() also iterates splat attributes, so no
+  // separate splat path is needed; an empty (rank-0) attribute trivially
+  // satisfies all_of.
+  auto multiples =
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getMultiples());
+  if (multiples &&
+      llvm::all_of(multiples.getValues<APInt>(),
+                   [](const APInt &v) { return v.getSExtValue() == 1; }))
+    return getInput1();
return {};
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..83cf4a9415fe68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -130,6 +130,10 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
//===----------------------------------------------------------------------===//
void TosaDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
+ >();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
@@ -153,6 +157,10 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
// Tosa dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
+ if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
+ return builder.create<tosa::ConstShapeOp>(
+ loc, type, llvm::cast<DenseIntElementsAttr>(value));
+ }
if (llvm::isa<ElementsAttr>(value))
return builder.create<tosa::ConstOp>(loc, type,
llvm::cast<ElementsAttr>(value));
@@ -962,11 +970,30 @@ LogicalResult tosa::TableOp::verify() {
return success();
}
+LogicalResult
+tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
+  // Only multiples that match a constant integer elements attribute can be
+  // extracted; on failure `multiples` is left untouched.
+  DenseIntElementsAttr constMultiples;
+  if (!matchPattern(getMultiples(), m_Constant(&constMultiples)))
+    return failure();
+
+  // Replace the caller's contents with the sign-extended constant values.
+  multiples.clear();
+  multiples.reserve(constMultiples.getNumElements());
+  for (const APInt &value : constMultiples.getValues<APInt>())
+    multiples.push_back(value.getSExtValue());
+  return success();
+}
+
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ArrayRef<int64_t> multiples = adaptor.getMultiples();
+ DenseIntElementsAttr multiplesAttr;
+ if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
+ return failure();
+
+ SmallVector<int64_t> multiples = llvm::to_vector(
+ llvm::map_range(multiplesAttr.getValues<APInt>(),
+ [](const APInt &val) { return val.getSExtValue(); }));
+
ShapeAdaptor inputShape(adaptor.getInput1().getType());
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
@@ -992,20 +1019,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
LogicalResult tosa::TileOp::verify() {
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
- auto multiples = getMultiples();
+
+ shapeType multiplesType =
+ llvm::cast<tosa::shapeType>(getMultiples().getType());
+
+ auto multiplesRank = multiplesType.getRank();
if (inputType.hasRank()) {
- if (static_cast<size_t>(inputType.getRank()) != multiples.size())
- return emitOpError("expect 'multiples' array to have length ")
- << inputType.getRank() << " but got " << multiples.size() << ".";
+ if (inputType.getRank() != multiplesRank)
+ return emitOpError("expect 'multiples' to have rank ")
+ << inputType.getRank() << " but got " << multiplesRank << ".";
if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
return emitOpError("expect same input and output tensor rank.");
- } else if (outputType.hasRank() &&
- static_cast<size_t>(outputType.getRank()) != multiples.size())
+ } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
return emitOpError("expect 'multiples' array to have length ")
- << outputType.getRank() << " but got " << multiples.size() << ".";
+ << outputType.getRank() << " but got " << multiplesRank << ".";
- if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
+ SmallVector<int64_t> multiples;
+ if (getConstantMultiples(multiples).succeeded() &&
+ llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
return emitOpError(
"expect element of 'multiples' to be positive integer or -1.");
@@ -2146,6 +2178,91 @@ void WhileOp::print(OpAsmPrinter &parser) {
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
+//===----------------------------------------------------------------------===//
+// TOSA Shape and Shape Operators Helper functions.
+//===----------------------------------------------------------------------===//
+
+bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
+  // Thin wrapper so TableGen predicates can test for the shape type.
+  const bool isShape = llvm::isa<tosa::shapeType>(t);
+  return isShape;
+}
+
+LogicalResult
+mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
+                              int rank) {
+  // Rank 0 (the shape of a scalar) is allowed; only negative ranks are
+  // rejected.
+  if (rank >= 0)
+    return success();
+  return emitError() << "invalid rank (must be >= 0): " << rank;
+}
+
+LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
+  for (Value operand : op->getOperands()) {
+    // Only operands of tosa shape type are constrained.
+    if (!mlir::isa<::mlir::tosa::shapeType>(operand.getType()))
+      continue;
+
+    // The operand must be defined by a shape operator; block arguments have
+    // no defining op and are therefore rejected.
+    Operation *producer = operand.getDefiningOp();
+    if (producer && producer->hasTrait<TosaShapeOperator>())
+      continue;
+
+    return op->emitOpError("shape operand is not compile time resolvable");
+  }
+  return success();
+}
+
+LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
+  // Operands are checked before results, so an op violating both reports
+  // the operand diagnostic.
+  auto isShape = [](Type type) {
+    return mlir::isa<mlir::tosa::shapeType>(type);
+  };
+  if (!llvm::all_of(op->getOperandTypes(), isShape))
+    return op->emitOpError("must have operands with tosa shape type");
+  if (!llvm::all_of(op->getResultTypes(), isShape))
+    return op->emitOpError("must have result with tosa shape type");
+  return success();
+}
+
+LogicalResult
+OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
+  // Guarantees at least one operand and that every operand/result is a
+  // shape type, so the casts and the first-operand lookup below are valid.
+  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyTosaShapeOperator(op)))
+    return failure();
+
+  auto rankOf = [](const Type type) {
+    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
+  };
+  // The first operand's rank is the reference every other shape must match.
+  auto expectedRank = rankOf(op->getOperand(0).getType());
+
+  for (Type operandType : op->getOperandTypes()) {
+    if (rankOf(operandType) != expectedRank)
+      return op->emitOpError("operands don't have matching ranks");
+  }
+  for (Type resultType : op->getResultTypes()) {
+    if (rankOf(resultType) != expectedRank)
+      return op->emitOpError("result shape has different rank than operands");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Shape Operators verify functions.
+//===----------------------------------------------------------------------===//
+
+LogicalResult tosa::ConstShapeOp::verify() {
+  // The element count of `value` must equal the rank of the result shape.
+  // A single element is additionally accepted for a rank-0 shape.
+  // NOTE(review): presumably this admits a 0-d tensor attribute; confirm.
+  auto numElements = getValue().getNumElements();
+  auto rank = llvm::cast<tosa::shapeType>(getResult().getType()).getRank();
+  const bool countMatchesRank = numElements == rank;
+  const bool singleValueForRankZero = numElements == 1 && rank == 0;
+  if (countMatchesRank || singleValueForRankZero)
+    return success();
+
+  return emitOpError("expect number of elements in attribute value (")
+         << numElements << ") to be equal to the rank (" << rank
+         << ") for the result shape type";
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//
@@ -2153,6 +2270,12 @@ void WhileOp::print(OpAsmPrinter &parser) {
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
+//===----------------------------------------------------------------------===//
+// TOSA Type Definitions.
+//===----------------------------------------------------------------------===//
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
+
//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 8588c878bfe4f8..a49870687fdc60 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -536,6 +536,8 @@ bool TosaValidation::isValidElementType(Type type) {
return true;
}
}
+ } else if (mlir::isa<tosa::shapeType>(type)) {
+ return true;
}
return false;
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c840fb8648d7b7..1d235092b71d55 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1378,21 +1378,24 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
// CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 4, 3>}
- %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>} : (tensor<2x3xi8>) -> tensor<4x3xi8>
+ %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8>
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
// CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 2, 6>}
- %1 = tosa.tile %arg0 {multiples = array<i64: 1, 2>} : (tensor<2x3xi8>) -> tensor<2x6xi8>
+ %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8>
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
// CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
// CHECK: linalg.yield %[[ARG1]] : i8
// CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 10, 21>}
- %2 = tosa.tile %arg0 {multiples = array<i64: 5, 7>} : (tensor<2x3xi8>) -> tensor<10x21xi8>
+ %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<10x21xi8>
return
}
@@ -1412,7 +1415,8 @@ func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK: linalg.yield %[[ARG1]] : i8
// CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: -9223372036854775808, 3>}
- %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>} : (tensor<?x3xi8>) -> tensor<?x3xi8>
+ %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = tosa.tile %arg0, %cst21: (tensor<?x3xi8>, !tosa.shape<2>) -> tensor<?x3xi8>
return
}
@@ -1432,7 +1436,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK: linalg.yield %[[ARG1]] : i8
// CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: 2, -9223372036854775808>}
- %0 = tosa.tile %arg0 {multiples = array<i64: 2, -1>} : (tensor<2x3xi8>) -> tensor<2x?xi8>
+ %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x?xi8>
return
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 60121bb0ea2f12..889e2eda9e5b84 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -588,7 +588,8 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
// CHECK-LABEL: @tile_fold
+// All-ones multiples with identical input/result types: the tile folds away.
func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
- %0 = tosa.tile %arg0 { multiples = array<i64: 1, 1> }: (tensor<3x4xf32>) -> tensor<3x4xf32>
+ %cst = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
@@ -597,7 +598,8 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK-LABEL: @tile_nofold
+// A multiple of 2 changes the result type, so the tile must not fold.
func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
// CHECK: tosa.tile
- %0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
+ %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32>
return %0 : tensor<3x8xf32>
}
@@ -763,7 +765,8 @@ func.func @fold_reduce_rank_zero() {
+// Rank-0 tile with an empty constant shape folds away.
func.func nested @fold_tile_rank_zero() -> tensor<i32> {
// CHECK-NOT: tosa.tile
%0 = tensor.empty() : tensor<i32>
- %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+ %cst = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+ %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
return %1 : tensor<i32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a6d57f8a2f61f3..cc7fd009f01fa6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -621,8 +621,9 @@ func.func @test_slice_invalid_size() {
+// The multiples shape rank (1) must equal the input tensor rank (3).
func.func @test_tile_invalid_multiples() {
%0 = tensor.empty() : tensor<4x31x31xf32>
- // expected-error at +1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
- %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+ %cst = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1>
+ // expected-error at +1 {{'tosa.tile' op expect 'multiples' to have rank 3 but got 1.}}
+ %1 = tosa.tile %0, %cst: (tensor<4x31x31xf32>, !tosa.shape<1>) -> tensor<4x31x31xf32>
return
}
@@ -630,8 +631,9 @@ func.func @test_tile_invalid_multiples() {
+// Constant multiples must be positive or -1; a value of -2 is rejected.
func.func @test_tile_invalid_multiples_value() {
%0 = tensor.empty() : tensor<4x31xf32>
+ %multiples = tosa.const_shape { value = dense<[2, -2]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.tile' op expect element of 'multiples' to be positive integer or -1.}}
- %1 = tosa.tile %0 {multiples = array<i64: 2, -2>} : (tensor<4x31xf32>) -> tensor<4x31xf32>
+ %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31xf32>
return
}
@@ -639,8 +641,9 @@ func.func @test_tile_invalid_multiples_value() {
+// Input and output tensors of tosa.tile must have the same rank.
func.func @test_tile_io_rank_mismatch() {
%0 = tensor.empty() : tensor<4x31xf32>
+ %multiples = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error at +1 {{'tosa.tile' op expect same input and output tensor rank.}}
- %1 = tosa.tile %0 {multiples = array<i64: 2, 2>} : (tensor<4x31xf32>) -> tensor<4x31x31xf32>
+ %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31x31xf32>
return
}
@@ -993,3 +996,25 @@ func.func @test_non_tosa_ops() {
%2 = tensor.empty(%0) : tensor<?x27xi64>
return
}
+
+// -----
+
+// A negative rank is rejected by the !tosa.shape type verifier.
+// expected-error at +1 {{invalid rank (must be >= 0): -1}}
+func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> {
+ return %arg0 : !tosa.shape<-1>
+}
+
+// -----
+// tosa.const_shape only accepts an index-typed elements attribute.
+func.func @test_const_shape() -> !tosa.shape<4> {
+ // expected-error at +1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}}
+ %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4>
+ return %cst : !tosa.shape<4>
+}
+
+// -----
+
+// The element count of the value attribute must match the result shape rank.
+func.func @test_const_shape_value() -> !tosa.shape<5> {
+ // expected-error at +1 {{'tosa.const_shape' op expect number of elements in attribute value (4) to be equal to the rank (5) for the result shape type}}
+ %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5>
+ return %cst : !tosa.shape<5>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index ba8ed8a1e5f50f..0fe35d88f0e73a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -95,8 +95,9 @@ func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11
// -----
// CHECK-LABEL: tile
+// Tiling a rank-7 tensor fails the level check on operand rank (MAX_RANK).
func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> {
+ %cst = tosa.const_shape { value = dense<[1, 1, 1, 1, 3, 1, 2]> : tensor<7xindex> } : () -> !tosa.shape<7>
// expected-error at +1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = "tosa.tile"(%arg0) {multiples = array<i64: 1, 1, 1, 1, 3, 1, 2>} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32>
+ %0 = tosa.tile %arg0, %cst : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x39x21x6xf32>
return %0 : tensor<1x1x1x1x39x21x6xf32>
}
@@ -740,4 +741,3 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
(tensor<*xf32>) -> tensor<*xf32>
return
}
-
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f2e1cff72ab281..690e208af1e5f9 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -562,7 +562,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
// -----
// CHECK-LABEL: tile
+// tosa.tile whose multiples are supplied by a tosa.const_shape operand.
func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
- %0 = tosa.tile %arg0 {multiples = array<i64: 3, 1, 2>} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32>
+ %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
return %0 : tensor<39x21x6xf32>
}
@@ -692,3 +693,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> {
%0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>)
return %0 : tensor<10xi32>
}
+
+// -----
+// CHECK-LABEL: const_shape
+// A rank-4 tosa.const_shape holding a splat index value is valid.
+func.func @test_const_shape() -> !tosa.shape<4> {
+ %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+ return %cst : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 82f3e22a387221..f4da66ef561b26 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -543,8 +543,10 @@ func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK-LABEL: @test_tile
+// Shape inference derives static result dims from constant multiples; the
+// dynamic input dim stays dynamic.
func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
- // CHECK: tosa.tile %arg0 {multiples = array<i64: 2, 1, 5>} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32>
- %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1, 5>} : (tensor<2x3x?xi32>) -> tensor<?x?x?xi32>
+ // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x3x?xi32>
+ %cst = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
return
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5c2a77ca67fd4c..d3f3697903d722 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12115,6 +12115,14 @@ gentbl_cc_library(
["-gen-dialect-defs"],
"include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc",
),
+ (
+ ["-gen-typedef-decls"],
+ "include/mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc",
+ ),
+ (
+ ["-gen-typedef-defs"],
+ "include/mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc",
+ ),
(
["-gen-attrdef-decls"],
"include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc",
More information about the llvm-commits
mailing list