[Mlir-commits] [mlir] [mlir][tosa] Add tosa.dim to the TOSA dialect (PR #77706)

Spenser Bauman llvmlistbot at llvm.org
Wed Jan 10 16:06:35 PST 2024


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/77706

The TOSA Spec defines a tosa.dim operation which is currently missing from the TOSA dialect definition.

This change adds:

1. tosa.dim to the dialect
2. a folder that replaces tosa.dim with a constant when the axis is constant and the selected dimension is static
3. a conversion pattern to tosa-to-tensor for tosa.dim

From 3b5d0adf88ba2e96075ad98d4039d14648a5ad21 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Sun, 7 Jan 2024 13:08:38 -0500
Subject: [PATCH] [mlir][tosa] Add tosa.dim to the TOSA dialect

The TOSA Spec defines a tosa.dim operation which is currently missing
from the TOSA dialect definition.

This change adds:

1. tosa.dim to the dialect
2. a folder that replaces tosa.dim with a constant when the axis is constant and the selected dimension is static
3. a conversion pattern to tosa-to-tensor for tosa.dim
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 31 +++++++++++++++++--
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     | 30 ++++++++++--------
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 27 ++++++++++++++++
 .../TosaToTensor/TosaToTensorPass.cpp         |  1 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 21 +++++++++++++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  9 ++++++
 .../TosaToTensor/tosa-to-tensor.mlir          | 17 ++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 22 ++++++++++++-
 mlir/test/Dialect/Tosa/ops.mlir               |  7 +++++
 9 files changed, 149 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9dde59f634d739..b17ad427bd515a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1253,7 +1253,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
 }
 
 //===----------------------------------------------------------------------===//
-// TOSA Spec Section 2.8
+// TOSA Spec Section 2.9
 // Operator Class: Reduction Ops.
 //===----------------------------------------------------------------------===//
 
@@ -1464,7 +1464,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
 }
 
 //===----------------------------------------------------------------------===//
-// TOSA Spec Section 2.9
+// TOSA Spec Section 2.10
 // Operator Class: Data Layout / Memory Reinterpretation.
 //===----------------------------------------------------------------------===//
 
@@ -1542,6 +1542,33 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: dim
+//===----------------------------------------------------------------------===//
+def Tosa_DimOp : Tosa_InferShapedTypeOp<"dim"> {
+  let summary = "The size of the specified dimension.";
+
+  let description = [{
+    Returns a rank-0 tensor holding the size of `input1` along the dimension
+    selected by the rank-0 `axis` operand (folded when statically known).
+
+    ```mlir
+    %size = tosa.dim %input1, %axis : (tensor<?x?xf32>, tensor<i32>) -> tensor<i32>
+    ```
+  }];
+
+  let arguments = (ins
+    Tosa_RankedTensor:$input1,
+    TensorRankOf<[Tosa_Int32], [0]>:$axis
+  );
+
+  let results = (outs
+    TensorRankOf<[Tosa_SignedInt], [0]>:$output
+  );
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: reshape
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index c55ddaafdda76e..de916f0e51411a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -52,7 +52,9 @@ def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
                                 Tosa_Int16,
                                 Tosa_Int32,
                                 Tosa_Int48,
-                                Tosa_Int64]>;
+                                Tosa_Int64],
+                               "TOSA signed integer type",
+                               "::mlir::IntegerType">;
 
 def Tosa_Bool : I<1>;
 
@@ -60,10 +62,13 @@ def Tosa_Bool : I<1>;
 def Tosa_Int : AnyTypeOf<[Tosa_Bool,
                           Tosa_UInt8,
                           Tosa_UInt16,
-                          Tosa_SignedInt]>;
+                          Tosa_SignedInt],
+                         "TOSA integer type",     // AnyTypeOf summary
+                         "::mlir::IntegerType">;  // AnyTypeOf cppType
 
-def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
-                   	        Tosa_Int64]>;
+def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, Tosa_Int64],
+                               "TOSA 32/64 bit integer type", // AnyTypeOf summary
+                               "::mlir::IntegerType">;        // AnyTypeOf cppType
 
 //===----------------------------------------------------------------------===//
 // Quantized Integer Types.
@@ -77,19 +82,18 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
 // int8  : symmetric  per tensor/per channel, signed
 // int16 : symmetric  per tensor,             signed
 //===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt	: AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
-                                     Tosa_QuantizedType<"int4", [4, 0], 1>,
-                                     Tosa_QuantizedType<"int8", [8, 0], 1>,
-                                     Tosa_QuantizedType<"int16", [16, 0], 1>,
-                                     Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
+                                   Tosa_QuantizedType<"int4", [4, 0], 1>, // last arg: signedness? TODO confirm
+                                   Tosa_QuantizedType<"int8", [8, 0], 1>,
+                                   Tosa_QuantizedType<"int16", [16, 0], 1>,
+                                   Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
 //===----------------------------------------------------------------------===//
 // Floating-point types.
 //===----------------------------------------------------------------------===//
-def Tosa_Float : AnyTypeOf<[
-                            F32,
-			    F16,
-			    BF16]>;
+def Tosa_Float : AnyTypeOf<[F32,   // 32-bit float
+                            F16,   // 16-bit float
+                            BF16]>; // bfloat16
 
 //===----------------------------------------------------------------------===//
 // Multi-category types.
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 06ec53d19b1e95..24d15c4f08a27f 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -412,6 +412,32 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   }
 };
 
+// Lowers tosa.dim to tensor.extract, tensor.dim and tensor.from_elements.
+struct DimConverter : public OpConversionPattern<tosa::DimOp> {
+  using OpConversionPattern<tosa::DimOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::DimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type indexType = rewriter.getIndexType();
+
+    // Scalarize the rank-0 axis tensor and convert it to an index.
+    Value axis = rewriter.create<tensor::ExtractOp>(loc, adaptor.getAxis(),
+                                                    ValueRange{});
+    Value axisIdx = rewriter.create<arith::IndexCastOp>(loc, indexType, axis);
+
+    // Query the extent at runtime and cast it to the result element type.
+    Value extent =
+        rewriter.create<tensor::DimOp>(loc, adaptor.getInput1(), axisIdx);
+    Value extentInt = rewriter.create<arith::IndexCastOp>(
+        loc, op.getType().getElementType(), extent);
+    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, op.getType(),
+                                                        ValueRange{extentInt});
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
@@ -420,4 +446,5 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
       patterns->getContext());
 
   patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+  patterns->add<DimConverter>(patterns->getContext());
 }
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94e..1455dd3ff9ffd5 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -36,6 +36,7 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::ConcatOp>();
+    target.addIllegalOp<tosa::DimOp>();
     target.addIllegalOp<tosa::ReshapeOp>();
     target.addIllegalOp<tosa::SliceOp>();
     target.addIllegalOp<tosa::PadOp>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..3aff47fd2fe4d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -836,6 +836,27 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+// Folds to a constant when the axis is constant and that dimension is static.
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
+  // dyn_cast: a constant axis is not guaranteed to be a DenseIntElementsAttr.
+  auto axis =
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getAxis());
+  if (!axis)
+    return {};
+  int64_t axisIndex = axis.getSplatValue<int32_t>();
+  auto inputType = getInput1().getType();
+  auto elementType = getType().getElementType();
+
+  // Bail if the axis is out of range or the size is dynamic/unrepresentable.
+  if (axisIndex < 0 || axisIndex >= inputType.getRank() ||
+      inputType.isDynamicDim(axisIndex) ||
+      !llvm::isIntN(elementType.getIntOrFloatBitWidth(),
+                    inputType.getDimSize(axisIndex)))
+    return {};
+  auto attr = IntegerAttr::get(elementType, inputType.getDimSize(axisIndex));
+  return DenseIntElementsAttr::get(getType(), {static_cast<Attribute>(attr)});
+}
+
 // Fold away cases where a tosa.resize operation returns a copy
 // of the input image.
 OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 661126f4df9976..945e1194b4b841 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -806,6 +806,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::DimOp::inferReturnTypeComponents(
+    MLIRContext *ctx, ::std::optional<Location> location,
+    DimOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t> rankZeroShape; // Result is always rank 0.
+  inferredReturnShapes.emplace_back(rankZeroShape);
+  return success();
+}
+
 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
   return to_vector(llvm::map_range(shape, [](int64_t dim) {
     return dim == -1 ? ShapedType::kDynamic : dim;
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index daaa68a7260b71..ce6a88cfb9957f 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -313,3 +313,20 @@ func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf
   %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i32}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @dim_op
+// CHECK-SAME: (%[[INPUT:[0-9a-zA-Z_]*]]:
+// CHECK-SAME:  %[[AXIS:[0-9a-zA-Z_]*]]:
+func.func @dim_op(%arg0: tensor<?x1xf32>, %arg1: tensor<i32>) -> (tensor<i32>) {
+  // CHECK: %[[AXIS_SCALAR:.+]] = tensor.extract %[[AXIS]][] : tensor<i32>
+  // CHECK: %[[AXIS_INDEX:.+]] = arith.index_cast %[[AXIS_SCALAR]] : i32 to index
+  // CHECK: %[[SIZE:.+]] = tensor.dim %[[INPUT]], %[[AXIS_INDEX]] : tensor<?x1xf32>
+  // CHECK: %[[SIZE_I32:.+]] = arith.index_cast %[[SIZE]] : index to i32
+  // CHECK: %[[PACKED:.+]] = tensor.from_elements %[[SIZE_I32]] : tensor<i32>
+  // CHECK: return %[[PACKED]] : tensor<i32>
+
+  %size = tosa.dim %arg0, %arg1 : (tensor<?x1xf32>, tensor<i32>) -> tensor<i32>
+  return %size : tensor<i32>
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca058..325994dfb33d1a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
+// RUN: mlir-opt -canonicalize="test-convergence" -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
 func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
@@ -613,3 +613,23 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
   %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
   return %1 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_dim_op
+func.func @fold_dim_op(%arg0: tensor<1x2x3xf32>) -> tensor<i32> {
+  // CHECK: %[[SIZE:[0-9a-zA-Z_]+]] = {{.*}}dense<2> : tensor<i32>
+  // CHECK-NEXT: return %[[SIZE]] : tensor<i32>
+  %axis = arith.constant dense<1> : tensor<i32>
+  %1 = tosa.dim %arg0, %axis : (tensor<1x2x3xf32>, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
+
+// -----
+// CHECK-LABEL: @nofold_dim_op_dynamic
+func.func @nofold_dim_op_dynamic(%arg0: tensor<1x?x3xf32>) -> tensor<i32> {
+  // CHECK: tosa.dim
+  %axis = arith.constant dense<1> : tensor<i32>
+  %1 = tosa.dim %arg0, %axis : (tensor<1x?x3xf32>, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a3d0b5e5447aa3..25f828337f1255 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -661,3 +661,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> {
   %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>)
   return %0 : tensor<10xi32>
 }
+
+// -----
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<10xi32>, %arg1: tensor<i32>) -> tensor<i32> {
+  %size = tosa.dim %arg0, %arg1 : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
+  return %size : tensor<i32>
+}



More information about the Mlir-commits mailing list