[llvm] [mlir] [TOSA] Add Tosa_Shape type and ConstShapeOp (PR #122547)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 10 15:19:30 PST 2025


https://github.com/Jerry-Ge updated https://github.com/llvm/llvm-project/pull/122547

>From 39d8e4437ac2fcef4713c4872e2178d57865cf04 Mon Sep 17 00:00:00 2001
From: Jerry Ge <Jerry.Ge at arm.com>
Date: Wed, 8 Jan 2025 16:23:25 -0800
Subject: [PATCH] [TOSA] Add Tosa_Shape type and ConstShapeOp

Adds:
 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
    rank-4 shape values (size-4 array of index values)
 2. const_shape operator
 3. trait TosaShapeOperator, added to tosa shape operators, and a
    verifier that all operands and results of operator are tosa shapes
 4. trait TosaResolvableShapeOperands, added to all tosa operators, and
    a verifier that every tosa shape operand is produced by a tosa shape
    operator (indicated by trait TosaShapeOperator)
 5. trait TosaShapeOperatorWithSameRanks, added to
    Tosa_ElementwiseShapeOp and a verifier that all operands and result
    shapes have the same rank
 6. changed TileOp's multiples from an attribute to an operand of
    !tosa.shape type.
 7. add folder for tosa ConstShape operator

Signed-off-by: Jerry Ge <Jerry.Ge at arm.com>
Signed-off-by: Tai Ly <tai.ly at arm.com>

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
---
 .../mlir/Dialect/Tosa/IR/CMakeLists.txt       |   3 +-
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   1 +
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  41 +++++
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |   9 +-
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      |  79 ++++++++++
 .../include/mlir/Dialect/Tosa/IR/TosaTypes.td |  87 +++++++++++
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |   4 +-
 .../TosaToLinalg/TosaToLinalgPass.cpp         |   1 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  19 ++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 142 ++++++++++++++++--
 .../Tosa/Transforms/TosaValidation.cpp        |   2 +
 .../TosaToLinalg/tosa-to-linalg.mlir          |  15 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |   9 +-
 mlir/test/Dialect/Tosa/invalid.mlir           |  33 +++-
 mlir/test/Dialect/Tosa/level_check.mlir       |   4 +-
 mlir/test/Dialect/Tosa/ops.mlir               |  10 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |   6 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   8 +
 18 files changed, 441 insertions(+), 32 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 1ee105f0ceb98b..81c0f2ef159e82 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,6 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
 add_mlir_interface(TosaInterfaces)
 
 set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
+mlir_tablegen(TosaOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
 mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
 mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
 add_public_tablegen_target(MLIRTosaAttributesIncGen)
@@ -10,4 +12,3 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
 mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
 add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
-
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index d3f12c34421b06..a66c8b1975e5ba 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..9bc95099243e49 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -90,14 +90,55 @@ template <typename ConcreteType>
 class TosaElementwiseOperator
     : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
 
+LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
+/// Verifies that each tosa shape operand is defined by a TosaShapeOperator op
+template <typename ConcreteType>
+class TosaResolvableShapeOperands
+    : public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaResolvableShapeOperands(op);
+  }
+};
+
+LogicalResult verifyTosaShapeOperator(Operation *op);
+/// Verifies that all operands and results of the op have tosa shape type
+template <typename ConcreteType>
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperator(op);
+  }
+};
+
+LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
+/// Verifies that all tosa shape operands and results have the same rank
+template <typename ConcreteType>
+class TosaShapeOperatorWithSameRanks
+    : public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperatorWithSameRanks(op);
+  }
+};
+
 } // namespace tosa
 } // namespace OpTrait
 
+namespace tosa {
+
+/// Returns true if \p t is a tosa shape type (!tosa.shape<N>).
+bool isa_tosa_shape_type(mlir::Type t);
+} // namespace tosa
+
 } // namespace mlir
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..718ca361c05469 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -23,6 +23,7 @@ include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
 
 include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
 include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaTypes.td"
 
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.2
@@ -1689,12 +1690,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$multiples);
+    Tosa_Shape:$multiples);
 
   let results = (outs
     Tosa_Tensor:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
+  }];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -2106,4 +2111,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
 
+include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
+
 #endif // TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
new file mode 100644
index 00000000000000..aacb04f77ce0e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -0,0 +1,79 @@
+//===-- TosaShapeOps.td - TOSA dialect shape operations ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines shape operators for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_SHAPE_OPS
+#define TOSA_SHAPE_OPS
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
+include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaTypes.td"
+
+// Op trait: operator has operands and results with TOSA shape type
+def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+
+  let hasFolder = 1;
+}
+
+// Op trait: shape operator has same ranks for operands and results
+def TosaShapeOperatorWithSameRanks : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_ShapeOp<mnemonic, !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: ConstShape
+//===----------------------------------------------------------------------===//
+def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike]> {
+  let summary = "Constant Shape op.";
+
+  let description = [{
+    A node containing constant data for use as the input to a shape operation. May
+    hold data only in index data type.
+
+    Example:
+
+    ```mlir
+    // Generic form
+    %out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+    ```
+  }];
+
+  let arguments = (ins
+    IndexElementsAttr:$value
+  );
+
+  let results = (outs
+    Tosa_Shape:$output
+  );
+
+  let hasVerifier = 1;
+}
+
+#endif // TOSA_SHAPE_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
new file mode 100644
index 00000000000000..480248a3216af7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
@@ -0,0 +1,87 @@
+//===-- TosaTypes.td - TOSA type definitions ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the type definitions for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_TYPES
+#define TOSA_TYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Tosa Type Definitions.
+//===----------------------------------------------------------------------===//
+
+// The base class for Tosa dialect types.
+class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Tosa_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeType
+//===----------------------------------------------------------------------===//
+def Tosa_Shape : Tosa_Type<"shape", "shape"> {
+  let summary = "Shape with static rank and Index element type";
+  let description = [{
+    Syntax:
+
+    ```
+    shape-type ::= `shape` `<` rank `>`
+    ```
+    Values with shape type represent a shape with a fixed rank and a list of dimensions.
+    Rank must be zero or a positive integer.
+    Each dimension is represented by the builtin Index type.
+
+    Examples:
+
+    ```mlir
+    // Shape with rank of four, for example, [1, 1, 8, 16]:
+    !tosa.shape<4>
+
+    // Shape with rank of one, for example, [16]:
+    !tosa.shape<1>
+
+    // Shape with rank zero, for example, [] (i.e., shape of scalar values):
+    !tosa.shape<0>
+    ```
+  }];
+  let parameters = (ins
+    "int":$rank
+  );
+  let builders = [
+   TypeBuilder<(ins "int":$rank)>
+  ];
+  let assemblyFormat = "`<` $rank `>`";
+
+  let genVerifyDecl = 1;
+}
+
+def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
+
+// Whether a Tosa Shape type has a rank equal to the specified rank.
+class IsTosaShapeOfRankPred<int rank> : And<[
+    IsTosaShapeType,
+    CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
+]>;
+
+class TosaShapeOfRank<int rank> :
+    Type<IsTosaShapeOfRankPred<rank>,
+         "Tosa shape type of rank " # rank
+         >;
+
+def Rank1TosaShape : TosaShapeOfRank<1>;
+def Rank2TosaShape : TosaShapeOfRank<2>;
+def Rank4TosaShape : TosaShapeOfRank<4>;
+
+#endif // TOSA_TYPES
\ No newline at end of file
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1d7ead16e8b631..9295afd36e3ab1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     auto elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    ArrayRef<int64_t> multiples = op.getMultiples();
+    SmallVector<int64_t> multiples;
+    if (failed(op.getConstantMultiples(multiples)))
+      return failure();
 
     // Broadcast the newly added dimensions to their appropriate multiple.
     SmallVector<int64_t, 2> genericShape;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 06a7262c467421..8dfa55bef74fc4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
     target.addLegalOp<tosa::ApplyScaleOp>();
     target.addLegalOp<tosa::IfOp>();
     target.addLegalOp<tosa::ConstOp>();
+    target.addLegalOp<tosa::ConstShapeOp>();
     target.addLegalOp<tosa::WhileOp>();
     target.addLegalOp<tosa::ConcatOp>();
     target.addLegalOp<tosa::SliceOp>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f51c3dbce6eefe..f7a596f1ccb192 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
+OpFoldResult ConstShapeOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
+
 #define REDUCE_FOLDER(OP)                                                      \
   OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
-  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
-  if (allOnes && getInput1().getType() == getType())
-    return getInput1();
+  if (getInput1().getType() == getType()) {
+    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
+            adaptor.getMultiples())) {
+      if (multiples.isSplat() &&
+          multiples.getSplatValue<APInt>().getSExtValue() == 1)
+        return getInput1();
+      if (auto int_array_attr =
+              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
+        if (llvm::all_of(int_array_attr.getValues<APInt>(),
+                         [](APInt v) { return v.getSExtValue() == 1; }))
+          return getInput1();
+      }
+    }
+  }
   return {};
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..46400d74953e0d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -130,6 +130,10 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
 //===----------------------------------------------------------------------===//
 
 void TosaDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
@@ -153,6 +157,10 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
   // Tosa dialect constants only support ElementsAttr unlike standard dialect
   // constant which supports all attributes.
+  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
+    return builder.create<tosa::ConstShapeOp>(
+        loc, type, llvm::cast<DenseIntElementsAttr>(value));
+  }
   if (llvm::isa<ElementsAttr>(value))
     return builder.create<tosa::ConstOp>(loc, type,
                                          llvm::cast<ElementsAttr>(value));
@@ -962,11 +970,31 @@ LogicalResult tosa::TableOp::verify() {
   return success();
 }
 
+LogicalResult
+tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
+  // Resolvable only when `multiples` folds to a constant integer attribute.
+  DenseIntElementsAttr constMultiples;
+  if (!matchPattern(getMultiples(), m_Constant(&constMultiples)))
+    return failure();
+  multiples.clear();
+  for (const APInt &value : constMultiples.getValues<APInt>())
+    multiples.push_back(value.getSExtValue());
+  return success();
+}
+
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ArrayRef<int64_t> multiples = adaptor.getMultiples();
+  DenseIntElementsAttr multiplesAttr;
+  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
+    return failure();
+
+  // ArrayRef<int64_t> multiples = adaptor.getMultiples();
+  SmallVector<int64_t> multiples = llvm::to_vector(
+      llvm::map_range(multiplesAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
@@ -992,20 +1020,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 LogicalResult tosa::TileOp::verify() {
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
-  auto multiples = getMultiples();
+
+  shapeType multiplesType =
+      llvm::cast<tosa::shapeType>(getMultiples().getType());
+
+  auto multiplesRank = multiplesType.getRank();
 
   if (inputType.hasRank()) {
-    if (static_cast<size_t>(inputType.getRank()) != multiples.size())
-      return emitOpError("expect 'multiples' array to have length ")
-             << inputType.getRank() << " but got " << multiples.size() << ".";
+    if (inputType.getRank() != multiplesRank)
+      return emitOpError("expect 'multiples' to have rank ")
+             << inputType.getRank() << " but got " << multiplesRank << ".";
     if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
       return emitOpError("expect same input and output tensor rank.");
-  } else if (outputType.hasRank() &&
-             static_cast<size_t>(outputType.getRank()) != multiples.size())
+  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
     return emitOpError("expect 'multiples' array to have length ")
-           << outputType.getRank() << " but got " << multiples.size() << ".";
+           << outputType.getRank() << " but got " << multiplesRank << ".";
 
-  if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
+  SmallVector<int64_t> multiples;
+  if (getConstantMultiples(multiples).succeeded() &&
+      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
     return emitOpError(
         "expect element of 'multiples' to be positive integer or -1.");
 
@@ -2146,6 +2179,91 @@ void WhileOp::print(OpAsmPrinter &parser) {
   parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA Shape and Shape Operators Helper functions.
+//===----------------------------------------------------------------------===//
+
+bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
+  return llvm::isa<mlir::tosa::shapeType>(t);
+}
+
+LogicalResult
+mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
+                              int rank) {
+  if (rank >= 0) // rank 0 (shape of a scalar) is accepted
+    return success();
+  return emitError() << "invalid rank (must be >= 0): " << rank;
+}
+
+LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
+  for (Value operand : op->getOperands()) {
+    if (!mlir::isa<::mlir::tosa::shapeType>(operand.getType()))
+      continue; // only !tosa.shape operands need a resolvable producer
+    // Block arguments have no defining op and are therefore rejected too.
+    Operation *producer = operand.getDefiningOp();
+    if (!producer || !producer->hasTrait<TosaShapeOperator>())
+      return op->emitOpError("shape operand is not compile time resolvable");
+  }
+  return success();
+}
+
+LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
+  // Ops with this trait consume and produce only !tosa.shape values.
+  auto isShape = [](Type type) {
+    return mlir::isa<mlir::tosa::shapeType>(type);
+  };
+
+  // Operands are checked before results, so the operand diagnostic wins.
+  if (!llvm::all_of(op->getOperandTypes(), isShape))
+    return op->emitOpError("must have operands with tosa shape type");
+  if (!llvm::all_of(op->getResultTypes(), isShape))
+    return op->emitOpError("must have result with tosa shape type");
+  return success();
+}
+
+LogicalResult
+OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
+  // A reference rank needs at least one operand, and every type must be a
+  // !tosa.shape before the casts below are safe.
+  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyTosaShapeOperator(op)))
+    return failure();
+
+  // Extracts the rank carried by a !tosa.shape type.
+  auto rankOf = [](Type type) {
+    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
+  };
+
+  // All operands are compared against the rank of the first operand.
+  const auto expectedRank = rankOf(op->getOperand(0).getType());
+  auto hasOtherRank = [&](Type type) { return rankOf(type) != expectedRank; };
+
+  if (llvm::any_of(op->getOperandTypes(), hasOtherRank))
+    return op->emitOpError("operands don't have matching ranks");
+
+  // Every result must carry that same rank as well.
+  if (llvm::any_of(op->getResultTypes(), hasOtherRank))
+    return op->emitOpError("result shape has different rank than operands");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Shape Operators verify functions.
+//===----------------------------------------------------------------------===//
+
+LogicalResult tosa::ConstShapeOp::verify() {
+  if (getValue().getType().getRank() != 1) // one entry per shape dimension
+    return emitOpError("expect elements in attribute value with rank 1");
+  // The number of entries must equal the rank of the result shape type.
+  auto count = getValue().getNumElements();
+  auto rank = cast<tosa::shapeType>(getResult().getType()).getRank();
+  if (count != rank && !(count == 1 && rank == 0))
+    return emitOpError("expect number of elements in attribute value (")
+           << count << ") to be equal to the rank (" << rank
+           << ") for the result shape type";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Attribute Definitions.
 //===----------------------------------------------------------------------===//
@@ -2153,6 +2271,12 @@ void WhileOp::print(OpAsmPrinter &parser) {
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// TOSA Type Definitions.
+//===----------------------------------------------------------------------===//
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 8588c878bfe4f8..a49870687fdc60 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -536,6 +536,8 @@ bool TosaValidation::isValidElementType(Type type) {
         return true;
       }
     }
+  } else if (mlir::isa<tosa::shapeType>(type)) {
+    return true;
   }
   return false;
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c840fb8648d7b7..1d235092b71d55 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1378,21 +1378,24 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 4, 3>}
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>} : (tensor<2x3xi8>) -> tensor<4x3xi8>
+  %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 2, 6>}
-  %1 = tosa.tile %arg0 {multiples = array<i64: 1, 2>} : (tensor<2x3xi8>) -> tensor<2x6xi8>
+  %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 10, 21>}
-  %2 = tosa.tile %arg0 {multiples = array<i64: 5, 7>} : (tensor<2x3xi8>)  -> tensor<10x21xi8>
+  %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>)  -> tensor<10x21xi8>
 
   return
 }
@@ -1412,7 +1415,8 @@ func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
   // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: -9223372036854775808, 3>}
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>} : (tensor<?x3xi8>)  -> tensor<?x3xi8>
+  %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst21: (tensor<?x3xi8>, !tosa.shape<2>)  -> tensor<?x3xi8>
 
   return
 }
@@ -1432,7 +1436,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: 2, -9223372036854775808>}
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, -1>} : (tensor<2x3xi8>)  -> tensor<2x?xi8>
+  %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>)  -> tensor<2x?xi8>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 60121bb0ea2f12..889e2eda9e5b84 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -588,7 +588,8 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
 // CHECK-LABEL: @tile_fold
 func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
-  %0 = tosa.tile %arg0 { multiples = array<i64: 1, 1> }: (tensor<3x4xf32>) -> tensor<3x4xf32>
+  %cst = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32>
   return %0 : tensor<3x4xf32>
 }
 
@@ -597,7 +598,8 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
 // CHECK-LABEL: @tile_nofold
 func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   // CHECK: tosa.tile
-  %0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
+  %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32>
   return %0 : tensor<3x8xf32>
 }
 
@@ -763,7 +765,8 @@ func.func @fold_reduce_rank_zero() {
 func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   // CHECK-NOT: tosa.tile
   %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+  %cst = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+  %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
   return %1 : tensor<i32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a6d57f8a2f61f3..cc7fd009f01fa6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -621,8 +621,9 @@ func.func @test_slice_invalid_size() {
 
 func.func @test_tile_invalid_multiples() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
-  // expected-error at +1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
-  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+  %cst = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1>
+  // expected-error at +1 {{'tosa.tile' op expect 'multiples' to have rank 3 but got 1.}}
+  %1 = tosa.tile %0, %cst: (tensor<4x31x31xf32>, !tosa.shape<1>) -> tensor<4x31x31xf32>
   return
 }
 
@@ -630,8 +631,9 @@ func.func @test_tile_invalid_multiples() {
 
 func.func @test_tile_invalid_multiples_value() {
   %0 = tensor.empty() : tensor<4x31xf32>
+  %multiples = tosa.const_shape { value = dense<[2, -2]> : tensor<2xindex> } : () -> !tosa.shape<2>
   // expected-error at +1 {{'tosa.tile' op expect element of 'multiples' to be positive integer or -1.}}
-  %1 = tosa.tile %0 {multiples = array<i64: 2, -2>} : (tensor<4x31xf32>) -> tensor<4x31xf32>
+  %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31xf32>
   return
 }
 
@@ -639,8 +641,9 @@ func.func @test_tile_invalid_multiples_value() {
 
 func.func @test_tile_io_rank_mismatch() {
   %0 = tensor.empty() : tensor<4x31xf32>
+  %multiples = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
   // expected-error at +1 {{'tosa.tile' op expect same input and output tensor rank.}}
-  %1 = tosa.tile %0 {multiples = array<i64: 2, 2>} : (tensor<4x31xf32>) -> tensor<4x31x31xf32>
+  %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31x31xf32>
   return
 }
 
@@ -993,3 +996,25 @@ func.func @test_non_tosa_ops() {
   %2 = tensor.empty(%0) : tensor<?x27xi64>
   return
 }
+
+// -----
+
+// expected-error@+1 {{invalid rank (must be >= 0): -1}}
+func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> {
+  return %arg0 : !tosa.shape<-1>
+}
+
+// -----
+func.func @test_const_shape() -> !tosa.shape<4> {
+  // expected-error@+1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}}
+  %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4>
+  return %cst : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_const_shape_value() -> !tosa.shape<5> {
+  // expected-error@+1 {{'tosa.const_shape' op expect number of elements in attribute value (4) to be equal to the rank (5) for the result shape type}}
+  %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5>
+  return %cst : !tosa.shape<5>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index ba8ed8a1e5f50f..0fe35d88f0e73a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -95,8 +95,9 @@ func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11
 // -----
 // CHECK-LABEL: tile
 func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> {
+  %cst = tosa.const_shape { value = dense<[1, 1, 1, 1, 3, 1, 2]> : tensor<7xindex> } : () -> !tosa.shape<7>
   // expected-error@+1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.tile"(%arg0) {multiples = array<i64: 1, 1, 1, 1, 3, 1, 2>} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32>
+  %0 = tosa.tile %arg0, %cst : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x39x21x6xf32>
   return %0 : tensor<1x1x1x1x39x21x6xf32>
 }
 
@@ -740,4 +741,3 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
           (tensor<*xf32>) -> tensor<*xf32>
   return
 }
-
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f2e1cff72ab281..690e208af1e5f9 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -562,7 +562,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
 // -----
 // CHECK-LABEL: tile
 func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
-  %0 = tosa.tile %arg0 {multiples = array<i64: 3, 1, 2>} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32>
+  %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
   return %0 : tensor<39x21x6xf32>
 }
 
@@ -692,3 +693,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> {
   %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>)
   return %0 : tensor<10xi32>
 }
+
+// -----
+// CHECK-LABEL: const_shape
+func.func @test_const_shape() -> !tosa.shape<4> {
+  %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
+  return %cst : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 82f3e22a387221..f4da66ef561b26 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -543,8 +543,10 @@ func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
 
 // CHECK-LABEL: @test_tile
 func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
-  // CHECK: tosa.tile %arg0 {multiples = array<i64: 2, 1, 5>} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32>
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1, 5>} : (tensor<2x3x?xi32>) -> tensor<?x?x?xi32>
+  // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x3x?xi32>
+  %cst = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
   return
 }
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5c2a77ca67fd4c..1211759a86c5eb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12115,6 +12115,14 @@ gentbl_cc_library(
             ["-gen-dialect-defs"],
             "include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc",
         ),
+        (
+            ["-gen-typedef-decls"],
+            "include/mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc",
+        ),
+        (
+            ["-gen-typedef-defs"],
+            "include/mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc",
+        ),
         (
             ["-gen-attrdef-decls"],
             "include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc",



More information about the llvm-commits mailing list