[llvm] [mlir] [TOSA] Add Tosa_Shape type and ConstShapeOp (PR #122547)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 10 15:08:29 PST 2025


https://github.com/Jerry-Ge created https://github.com/llvm/llvm-project/pull/122547

Adds:
 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values)
 2. const_shape operator
 3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes
 4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator)
 5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks
 6. changed TileOp's multiples from an attribute to an input of !tosa.shape type.
 7. add a folder for the tosa ConstShape operator

This patch was originally authored by Tai Ly <tai.ly at arm.com>

Signed-off-by: Jerry Ge <Jerry.Ge at arm.com>
Signed-off-by: Tai Ly <tai.ly at arm.com>

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8

>From 4360edfc421f6588ae50e2437330917b205f5517 Mon Sep 17 00:00:00 2001
From: Jerry Ge <Jerry.Ge at arm.com>
Date: Wed, 8 Jan 2025 16:23:25 -0800
Subject: [PATCH] [TOSA] Add Tosa_Shape type and ConstShapeOp

Adds:
 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
    rank-4 shape values (size-4 array of index values)
 2. const_shape operator
 3. trait TosaShapeOperator, added to tosa shape operators, and a
    verifier that all operands and results of operator are tosa shapes
 4. trait TosaResolvableShapeOperands, added to all tosa operators, and
    a verifier that every tosa shape operand is produced by a tosa shape
    operator (indicated by trait TosaShapeOperator)
 5. trait TosaShapeOperatorWithSameRanks, added to
    Tosa_ElementwiseShapeOp and a verifier that all operands and result
    shapes have same ranks
 6. changed TileOp's multiples from an attribute to an input of
    !tosa.shape type.
 7. add a folder for the tosa ConstShape operator

Signed-off-by: Jerry Ge <Jerry.Ge at arm.com>
Signed-off-by: Tai Ly <tai.ly at arm.com>

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
---
 .../mlir/Dialect/Tosa/IR/CMakeLists.txt       |   3 +-
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   1 +
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  43 ++++++
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |   9 +-
 .../mlir/Dialect/Tosa/IR/TosaShapeOps.td      |  79 ++++++++++
 .../include/mlir/Dialect/Tosa/IR/TosaTypes.td |  87 +++++++++++
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |   4 +-
 .../TosaToLinalg/TosaToLinalgPass.cpp         |   1 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  19 ++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 143 ++++++++++++++++--
 .../Tosa/Transforms/TosaValidation.cpp        |   2 +
 .../TosaToLinalg/tosa-to-linalg.mlir          |  15 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |   9 +-
 mlir/test/Dialect/Tosa/invalid.mlir           |  33 +++-
 mlir/test/Dialect/Tosa/level_check.mlir       |   4 +-
 mlir/test/Dialect/Tosa/ops.mlir               |  10 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |   6 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   8 +
 18 files changed, 444 insertions(+), 32 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
 create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 1ee105f0ceb98b..81c0f2ef159e82 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,6 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
 add_mlir_interface(TosaInterfaces)
 
 set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
+mlir_tablegen(TosaOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
 mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
 mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
 add_public_tablegen_target(MLIRTosaAttributesIncGen)
@@ -10,4 +12,3 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
 mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
 add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
-
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index d3f12c34421b06..a66c8b1975e5ba 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..f00cb4c282db88 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -90,14 +90,57 @@ template <typename ConcreteType>
 class TosaElementwiseOperator
     : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
 
+LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
+/// This class verifies that tosa shape operands are compile time resolvable
+template <typename ConcreteType>
+class TosaResolvableShapeOperands
+    : public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaResolvableShapeOperands(op);
+  }
+};
+
+
+LogicalResult verifyTosaShapeOperator(Operation *op);
+/// This class indicates that op operates on tosa shape types
+template <typename ConcreteType>
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperator(op);
+  }
+};
+
+
+LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
+/// This class verifies that all shape operands and results have the same rank.
+template <typename ConcreteType>
+class TosaShapeOperatorWithSameRanks
+    : public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperatorWithSameRanks(op);
+  }
+};
+
 } // namespace tosa
 } // namespace OpTrait
 
+namespace tosa {
+
+bool isa_tosa_shape_type(mlir::Type t);
+
+} // namespace tosa
+
 } // namespace mlir
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..718ca361c05469 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -23,6 +23,7 @@ include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
 
 include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
 include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaTypes.td"
 
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.2
@@ -1689,12 +1690,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$multiples);
+    Tosa_Shape:$multiples);
 
   let results = (outs
     Tosa_Tensor:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
+  }];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -2106,4 +2111,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
 
+include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
+
 #endif // TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
new file mode 100644
index 00000000000000..aacb04f77ce0e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -0,0 +1,79 @@
+//===-- TosaShapeOps.td - TOSA dialect shape operations ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines shape operators for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_SHAPE_OPS
+#define TOSA_SHAPE_OPS
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
+include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaTypes.td"
+
+// Op trait: operator has operands and results with TOSA shape type
+def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+
+  let hasFolder = 1;
+}
+
+// op trait: shape operator has same ranks for operands and results
+def TosaShapeOperatorWithSameRanks : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_ShapeOp<mnemonic, !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: ConstShape
+//===----------------------------------------------------------------------===//
+def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike]> {
+  let summary = "Constant Shape op.";
+
+  let description = [{
+    A node containing constant data for use as the input to a shape operation.
+    May hold data only of the index element type.
+
+    Example:
+
+    ```mlir
+    // Generic form
+    %out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+    ```
+  }];
+
+  let arguments = (ins
+    IndexElementsAttr:$value
+  );
+
+  let results = (outs
+    Tosa_Shape:$output
+  );
+
+  let hasVerifier = 1;
+}
+
+#endif // TOSA_SHAPE_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
new file mode 100644
index 00000000000000..480248a3216af7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
@@ -0,0 +1,87 @@
+//===-- TosaTypes.td - TOSA type definitions ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the type definitions for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_TYPES
+#define TOSA_TYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Tosa Type Definitions.
+//===----------------------------------------------------------------------===//
+
+// The base class for Tosa dialect types.
+class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Tosa_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeType
+//===----------------------------------------------------------------------===//
+def Tosa_Shape : Tosa_Type<"shape", "shape"> {
+  let summary = "Shape with static rank and Index element type";
+  let description = [{
+    Syntax:
+
+    ```
+    shape-type ::= `shape` `<` rank `>`
+    ```
+    Values with shape type represent a shape with a fixed rank and a list of dimensions.
+    Rank must be zero or a positive integer (enforced by the type verifier).
+    Each dimension is represented by the builtin Index type.
+
+    Examples:
+
+    ```mlir
+    // Shape with rank of four, for example, [1, 1, 8, 16]:
+    !tosa.shape<4>
+
+    // Shape with rank of one, for example, [16]:
+    !tosa.shape<1>
+
+    // Shape with rank zero, for example, [] (i.e., shape of scalar values):
+    !tosa.shape<0>
+    ```
+  }];
+  let parameters = (ins
+    "int":$rank
+  );
+  let builders = [
+    TypeBuilder<(ins "int":$rank)>
+  ];
+  let assemblyFormat = "`<` $rank `>`";
+
+  let genVerifyDecl = 1;
+}
+
+def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
+
+// Whether a Tosa Shape type has a rank equal to the specified rank.
+class IsTosaShapeOfRankPred<int rank> : And<[
+    IsTosaShapeType,
+    CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
+]>;
+
+class TosaShapeOfRank<int rank> :
+    Type<IsTosaShapeOfRankPred<rank>,
+         "Tosa shape type of rank " # rank
+         >;
+
+def Rank1TosaShape : TosaShapeOfRank<1>;
+def Rank2TosaShape : TosaShapeOfRank<2>;
+def Rank4TosaShape : TosaShapeOfRank<4>;
+
+#endif // TOSA_TYPES
\ No newline at end of file
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1d7ead16e8b631..9295afd36e3ab1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     auto elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    ArrayRef<int64_t> multiples = op.getMultiples();
+    SmallVector<int64_t> multiples;
+    if (failed(op.getConstantMultiples(multiples)))
+      return failure();
 
     // Broadcast the newly added dimensions to their appropriate multiple.
     SmallVector<int64_t, 2> genericShape;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 06a7262c467421..8dfa55bef74fc4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
     target.addLegalOp<tosa::ApplyScaleOp>();
     target.addLegalOp<tosa::IfOp>();
     target.addLegalOp<tosa::ConstOp>();
+    target.addLegalOp<tosa::ConstShapeOp>();
     target.addLegalOp<tosa::WhileOp>();
     target.addLegalOp<tosa::ConcatOp>();
     target.addLegalOp<tosa::SliceOp>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f51c3dbce6eefe..f7a596f1ccb192 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
+OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
+
 #define REDUCE_FOLDER(OP)                                                      \
   OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
-  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
-  if (allOnes && getInput1().getType() == getType())
-    return getInput1();
+  // Folding requires the input and result types to match exactly.
+  if (getInput1().getType() != getType())
+    return {};
+  // Constant multiples (e.g. folded from tosa.const_shape) arrive as a
+  // DenseIntElementsAttr; non-constant multiples cannot be folded.
+  auto multiples =
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getMultiples());
+  if (!multiples)
+    return {};
+  // getValues<APInt>() expands splat attributes element-wise, so a single
+  // all_of check covers both splat and non-splat multiples.
+  if (llvm::all_of(multiples.getValues<APInt>(),
+                   [](const APInt &v) { return v.getSExtValue() == 1; }))
+    return getInput1();
   return {};
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..3354f6908c55cd 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -130,6 +130,10 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
 //===----------------------------------------------------------------------===//
 
 void TosaDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
@@ -153,6 +157,10 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
   // Tosa dialect constants only support ElementsAttr unlike standard dialect
   // constant which supports all attributes.
+  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
+    return builder.create<tosa::ConstShapeOp>(
+        loc, type, llvm::cast<DenseIntElementsAttr>(value));
+  }
   if (llvm::isa<ElementsAttr>(value))
     return builder.create<tosa::ConstOp>(loc, type,
                                          llvm::cast<ElementsAttr>(value));
@@ -962,11 +970,32 @@ LogicalResult tosa::TableOp::verify() {
   return success();
 }
 
+LogicalResult
+tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
+  // Fails unless the multiples operand is a compile-time constant.
+  DenseIntElementsAttr multiplesAttr;
+  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
+    return failure();
+  multiples = llvm::to_vector(
+      llvm::map_range(multiplesAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+  return success();
+}
+
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ArrayRef<int64_t> multiples = adaptor.getMultiples();
+  // Output shape inference requires compile-time constant multiples.
+  DenseIntElementsAttr multiplesAttr;
+  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
+    return failure();
+
+  // Sign-extend so negative entries (e.g. -1) keep their value.
+  SmallVector<int64_t> multiples = llvm::to_vector(
+      llvm::map_range(multiplesAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
@@ -992,20 +1021,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 LogicalResult tosa::TileOp::verify() {
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
-  auto multiples = getMultiples();
+  // The rank of the 'multiples' shape must match the input/output tensor rank.
+  shapeType multiplesType =
+      llvm::cast<tosa::shapeType>(getMultiples().getType());
+
+  auto multiplesRank = multiplesType.getRank();
 
   if (inputType.hasRank()) {
-    if (static_cast<size_t>(inputType.getRank()) != multiples.size())
-      return emitOpError("expect 'multiples' array to have length ")
-             << inputType.getRank() << " but got " << multiples.size() << ".";
+    if (inputType.getRank() != multiplesRank)
+      return emitOpError("expect 'multiples' to have rank ")
+             << inputType.getRank() << " but got " << multiplesRank << ".";
     if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
       return emitOpError("expect same input and output tensor rank.");
-  } else if (outputType.hasRank() &&
-             static_cast<size_t>(outputType.getRank()) != multiples.size())
-    return emitOpError("expect 'multiples' array to have length ")
-           << outputType.getRank() << " but got " << multiples.size() << ".";
+  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
+    return emitOpError("expect 'multiples' to have rank ")
+           << outputType.getRank() << " but got " << multiplesRank << ".";
 
-  if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
+  SmallVector<int64_t> multiples;
+  if (getConstantMultiples(multiples).succeeded() &&
+      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
     return emitOpError(
         "expect element of 'multiples' to be positive integer or -1.");
 
@@ -2146,6 +2180,91 @@ void WhileOp::print(OpAsmPrinter &parser) {
   parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA Shape and Shape Operators Helper functions.
+//===----------------------------------------------------------------------===//
+
+bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
+  return mlir::isa<tosa::shapeType>(t);
+}
+
+LogicalResult
+mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
+                              int rank) {
+  if (rank < 0)
+    return emitError() << "invalid rank (must be >= 0): " << rank;
+  return success();
+}
+
+LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
+  for (auto v : op->getOperands()) {
+    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
+      Operation *definingOp = v.getDefiningOp();
+      if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
+        return op->emitOpError("shape operand is not compile time resolvable");
+      }
+    }
+  }
+  return success();
+}
+
+LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
+  for (auto type : op->getOperandTypes()) {
+    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
+      return op->emitOpError("must have operands with tosa shape type");
+    }
+  }
+  for (auto type : op->getResultTypes()) {
+    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
+      return op->emitOpError("must have result with tosa shape type");
+    }
+  }
+  return success();
+}
+
+LogicalResult
+OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
+  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyTosaShapeOperator(op)))
+    return failure();
+
+  // delegate function that returns rank of shape type
+  auto getRank = [](const Type type) {
+    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
+  };
+  auto operandTypes = op->getOperandTypes();
+  auto resultTypes = op->getResultTypes();
+
+  auto rank = getRank(*op->getOperandTypes().begin());
+  for (auto type : operandTypes) {
+    if (getRank(type) != rank) {
+      return op->emitOpError("operands don't have matching ranks");
+    }
+  }
+  for (auto type : resultTypes) {
+    if (getRank(type) != rank) {
+      return op->emitOpError("result shape has different rank than operands");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Shape Operators verify functions.
+//===----------------------------------------------------------------------===//
+
+LogicalResult tosa::ConstShapeOp::verify() {
+  // The element count of 'value' must equal the rank of the result shape; a
+  // rank-0 shape is additionally allowed to carry exactly one element.
+  const int64_t count = getValue().getNumElements();
+  const int rank = llvm::cast<tosa::shapeType>(getResult().getType()).getRank();
+  if (count != rank && !(count == 1 && rank == 0))
+    return emitOpError("expect number of elements in attribute value (")
+           << count << ") to be equal to the rank (" << rank
+           << ") for the result shape type";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Attribute Definitions.
 //===----------------------------------------------------------------------===//
@@ -2153,6 +2272,12 @@ void WhileOp::print(OpAsmPrinter &parser) {
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// TOSA Type Definitions.
+//===----------------------------------------------------------------------===//
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 8588c878bfe4f8..7c834cbff0000e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -536,6 +536,8 @@ bool TosaValidation::isValidElementType(Type type) {
         return true;
       }
     }
+  } else if (mlir::isa<tosa::shapeType>(type)) {
+    return true;
   }
   return false;
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c840fb8648d7b7..1d235092b71d55 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1378,21 +1378,24 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 4, 3>}
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>} : (tensor<2x3xi8>) -> tensor<4x3xi8>
+  %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 2, 6>}
-  %1 = tosa.tile %arg0 {multiples = array<i64: 1, 2>} : (tensor<2x3xi8>) -> tensor<2x6xi8>
+  %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
   // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape [[GENERIC]] {new_shape = array<i64: 10, 21>}
-  %2 = tosa.tile %arg0 {multiples = array<i64: 5, 7>} : (tensor<2x3xi8>)  -> tensor<10x21xi8>
+  %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>)  -> tensor<10x21xi8>
 
   return
 }
@@ -1412,7 +1415,8 @@ func.func @tile_dyn_input(%arg0 : tensor<?x3xi8>) -> () {
   // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: -9223372036854775808, 3>}
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1>} : (tensor<?x3xi8>)  -> tensor<?x3xi8>
+  %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst21: (tensor<?x3xi8>, !tosa.shape<2>)  -> tensor<?x3xi8>
 
   return
 }
@@ -1432,7 +1436,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () {
   // CHECK: ^bb0(%[[ARG1:.+]]: i8,
   // CHECK:   linalg.yield %[[ARG1]] : i8
   // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array<i64: 2, -9223372036854775808>}
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, -1>} : (tensor<2x3xi8>)  -> tensor<2x?xi8>
+  %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>)  -> tensor<2x?xi8>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 60121bb0ea2f12..889e2eda9e5b84 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -588,7 +588,8 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
 // CHECK-LABEL: @tile_fold
 func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
-  %0 = tosa.tile %arg0 { multiples = array<i64: 1, 1> }: (tensor<3x4xf32>) -> tensor<3x4xf32>
+  %cst = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32>
   return %0 : tensor<3x4xf32>
 }
 
@@ -597,7 +598,8 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
 // CHECK-LABEL: @tile_nofold
 func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   // CHECK: tosa.tile
-  %0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
+  %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32>
   return %0 : tensor<3x8xf32>
 }
 
@@ -763,7 +765,8 @@ func.func @fold_reduce_rank_zero() {
 func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   // CHECK-NOT: tosa.tile
   %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+  %cst = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
+  %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
   return %1 : tensor<i32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a6d57f8a2f61f3..cc7fd009f01fa6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -621,8 +621,9 @@ func.func @test_slice_invalid_size() {
 
 func.func @test_tile_invalid_multiples() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
-  // expected-error at +1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
-  %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
+  %cst = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1>
+  // expected-error at +1 {{'tosa.tile' op expect 'multiples' to have rank 3 but got 1.}}
+  %1 = tosa.tile %0, %cst: (tensor<4x31x31xf32>, !tosa.shape<1>) -> tensor<4x31x31xf32>
   return
 }
 
@@ -630,8 +631,9 @@ func.func @test_tile_invalid_multiples() {
 
 func.func @test_tile_invalid_multiples_value() {
   %0 = tensor.empty() : tensor<4x31xf32>
+  %multiples = tosa.const_shape {value = dense<[2, -2]> : tensor<2xindex>} : () -> !tosa.shape<2>  // -2 is neither positive nor -1
   // expected-error@+1 {{'tosa.tile' op expect element of 'multiples' to be positive integer or -1.}}
-  %1 = tosa.tile %0 {multiples = array<i64: 2, -2>} : (tensor<4x31xf32>) -> tensor<4x31xf32>
+  %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31xf32>
   return
 }
 
@@ -639,8 +641,9 @@ func.func @test_tile_invalid_multiples_value() {
 
 func.func @test_tile_io_rank_mismatch() {
   %0 = tensor.empty() : tensor<4x31xf32>
+  %multiples = tosa.const_shape {value = dense<[2, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>  // multiples match the rank-2 input; only the result rank differs
   // expected-error@+1 {{'tosa.tile' op expect same input and output tensor rank.}}
-  %1 = tosa.tile %0 {multiples = array<i64: 2, 2>} : (tensor<4x31xf32>) -> tensor<4x31x31xf32>
+  %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31x31xf32>
   return
 }
 
@@ -993,3 +996,25 @@ func.func @test_non_tosa_ops() {
   %2 = tensor.empty(%0) : tensor<?x27xi64>
   return
 }
+
+// -----
+
+// expected-error@+1 {{invalid rank (must be >= 0): -1}}
+func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> {  // negative shape rank must be rejected
+  return %arg0 : !tosa.shape<-1>
+}
+
+// -----
+func.func @test_const_shape() -> !tosa.shape<4> {
+  // expected-error@+1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}}
+  %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4>  // i32 elements instead of index
+  return %cst : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_const_shape_value() -> !tosa.shape<5> {
+  // expected-error@+1 {{'tosa.const_shape' op expect number of elements in attribute value (4) to be equal to the rank (5) for the result shape type}}
+  %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5>  // 4 values for a rank-5 shape
+  return %cst : !tosa.shape<5>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index ba8ed8a1e5f50f..0fe35d88f0e73a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -95,8 +95,9 @@ func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11
 // -----
 // CHECK-LABEL: tile
 func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> {
+  %cst = tosa.const_shape {value = dense<[1, 1, 1, 1, 3, 1, 2]> : tensor<7xindex>} : () -> !tosa.shape<7>  // rank-7 operand trips the MAX_RANK level check
   // expected-error@+1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.tile"(%arg0) {multiples = array<i64: 1, 1, 1, 1, 3, 1, 2>} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32>
+  %0 = tosa.tile %arg0, %cst : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x39x21x6xf32>
   return %0 : tensor<1x1x1x1x39x21x6xf32>
 }
 
@@ -740,4 +741,3 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) {
           (tensor<*xf32>) -> tensor<*xf32>
   return
 }
-
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f2e1cff72ab281..690e208af1e5f9 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -562,7 +562,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
 // -----
 // CHECK-LABEL: tile
 func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> {
-  %0 = tosa.tile %arg0 {multiples = array<i64: 3, 1, 2>} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32>
+  %cst = tosa.const_shape {value = dense<[3, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>  // 13x21x3 tiled by [3, 1, 2] -> 39x21x6
+  %0 = tosa.tile %arg0, %cst : (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32>
   return %0 : tensor<39x21x6xf32>
 }
 
@@ -692,3 +693,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> {
   %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>)
   return %0 : tensor<10xi32>
 }
+
+// -----
+// CHECK-LABEL: const_shape
+func.func @test_const_shape() -> !tosa.shape<4> {
+  %shape = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>  // splat: all four entries are 1
+  return %shape : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 82f3e22a387221..f4da66ef561b26 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -543,8 +543,10 @@ func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
 
 // CHECK-LABEL: @test_tile
 func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
-  // CHECK: tosa.tile %arg0 {multiples = array<i64: 2, 1, 5>} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32>
-  %0 = tosa.tile %arg0 {multiples = array<i64: 2, 1, 5>} : (tensor<2x3x?xi32>) -> tensor<?x?x?xi32>
+  // CHECK: %[[MULTIPLES:.*]] = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: tosa.tile %arg0, %[[MULTIPLES]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x3x?xi32>
+  %multiples = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %multiples : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>  // expected to be refined to tensor<4x3x?xi32>
   return
 }
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5c2a77ca67fd4c..1211759a86c5eb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -12115,6 +12115,14 @@ gentbl_cc_library(
             ["-gen-dialect-defs"],
             "include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc",
         ),
+        (
+            ["-gen-typedef-decls"],  # TableGen TypeDef declarations (the new !tosa.shape type)
+            "include/mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc",
+        ),
+        (
+            ["-gen-typedef-defs"],  # matching TypeDef definitions for the same types
+            "include/mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc",
+        ),
         (
             ["-gen-attrdef-decls"],
             "include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc",



More information about the llvm-commits mailing list