[Mlir-commits] [mlir] [mlir][tosa] Add tosa.dim to the TOSA dialect (PR #77706)
llvmlistbot at llvm.org
Wed Jan 10 16:07:02 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-tosa
Author: Spenser Bauman (sabauma)
The TOSA spec defines a dim operation that is currently missing from the TOSA dialect definition.
This change adds:
1. the tosa.dim operation to the dialect
2. folding rules for tosa.dim
3. a lowering for tosa.dim in the tosa-to-tensor conversion (sketched below)
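For illustration, here is a small sketch of the new operation and of the lowering produced by the tosa-to-tensor pattern, based on the tests added in this PR (function and value names are illustrative only):

```mlir
// tosa.dim takes a ranked tensor and a rank-0 i32 axis tensor, and returns
// the runtime size of that dimension as a rank-0 i32 tensor.
func.func @dim_example(%input: tensor<?x?xf32>, %axis: tensor<i32>) -> tensor<i32> {
  %size = tosa.dim %input, %axis : (tensor<?x?xf32>, tensor<i32>) -> tensor<i32>
  return %size : tensor<i32>
}

// After the tosa-to-tensor conversion, the op is rewritten along these lines:
//   %a    = tensor.extract %axis[] : tensor<i32>
//   %i    = arith.index_cast %a : i32 to index
//   %d    = tensor.dim %input, %i : tensor<?x?xf32>
//   %v    = arith.index_cast %d : index to i32
//   %size = tensor.from_elements %v : tensor<i32>
// With a constant axis and a statically known dimension size, the folder
// instead rewrites the op to a rank-0 constant.
```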
---
Full diff: https://github.com/llvm/llvm-project/pull/77706.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+29-2)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+17-13)
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+27)
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp (+1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+21)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+9)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+17)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+21-1)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+7)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9dde59f634d739..b17ad427bd515a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1253,7 +1253,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 2.8
+// TOSA Spec Section 2.9
// Operator Class: Reduction Ops.
//===----------------------------------------------------------------------===//
@@ -1464,7 +1464,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
}
//===----------------------------------------------------------------------===//
-// TOSA Spec Section 2.9
+// TOSA Spec Section 2.10
// Operator Class: Data Layout / Memory Reinterpretation.
//===----------------------------------------------------------------------===//
@@ -1542,6 +1542,33 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: dim
+//===----------------------------------------------------------------------===//
+def Tosa_DimOp : Tosa_InferShapedTypeOp<"dim"> {
+ let summary = "The size of the specified dimension.";
+
+ let description = [{
+ Returns a rank zero tensor whose value is the runtime size of the input tensor
+ along the dimension of the axis input.
+
+ ```mlir
+ %size = tosa.dim %input1, %axis : (tensor<?x?xf32>, tensor<i32>) -> (tensor<i32>)
+ ```
+ }];
+
+ let arguments = (ins
+ Tosa_RankedTensor:$input1,
+ TensorRankOf<[Tosa_Int32], [0]>:$axis
+ );
+
+ let results = (outs
+ TensorRankOf<[Tosa_SignedInt], [0]>:$output
+ );
+
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: reshape
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index c55ddaafdda76e..de916f0e51411a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -52,7 +52,9 @@ def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
Tosa_Int16,
Tosa_Int32,
Tosa_Int48,
- Tosa_Int64]>;
+ Tosa_Int64],
+ "TOSA signed integer type",
+ "::mlir::IntegerType">;
def Tosa_Bool : I<1>;
@@ -60,10 +62,13 @@ def Tosa_Bool : I<1>;
def Tosa_Int : AnyTypeOf<[Tosa_Bool,
Tosa_UInt8,
Tosa_UInt16,
- Tosa_SignedInt]>;
+ Tosa_SignedInt],
+ "TOSA integer type",
+ "::mlir::IntegerType">;
-def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
- Tosa_Int64]>;
+def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, Tosa_Int64],
+ "TOSA 32/64 bit integer type",
+ "::mlir::IntegerType">;
//===----------------------------------------------------------------------===//
// Quantized Integer Types.
@@ -77,19 +82,18 @@ def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
// int8 : symmetric per tensor/per channel, signed
// int16 : symmetric per tensor, signed
//===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
- Tosa_QuantizedType<"int4", [4, 0], 1>,
- Tosa_QuantizedType<"int8", [8, 0], 1>,
- Tosa_QuantizedType<"int16", [16, 0], 1>,
- Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
+ Tosa_QuantizedType<"int4", [4, 0], 1>,
+ Tosa_QuantizedType<"int8", [8, 0], 1>,
+ Tosa_QuantizedType<"int16", [16, 0], 1>,
+ Tosa_QuantizedType<"int32", [32, 0], 1>]>;
//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
-def Tosa_Float : AnyTypeOf<[
- F32,
- F16,
- BF16]>;
+def Tosa_Float : AnyTypeOf<[F32,
+ F16,
+ BF16]>;
//===----------------------------------------------------------------------===//
// Multi-category types.
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 06ec53d19b1e95..24d15c4f08a27f 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -412,6 +412,32 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
}
};
+struct DimConverter : public OpConversionPattern<tosa::DimOp> {
+ using OpConversionPattern<tosa::DimOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::DimOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const override {
+ auto loc = op.getLoc();
+ auto input = adaptor.getInput1();
+ auto axis = adaptor.getAxis();
+
+ auto axisScalar = rewriter.create<tensor::ExtractOp>(loc, axis, ValueRange{});
+ auto axisScalarIndex = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), axisScalar);
+
+ auto dimSize = rewriter.create<tensor::DimOp>(loc, input, axisScalarIndex);
+ auto dimSizeAsInteger = rewriter.create<arith::IndexCastOp>(
+ loc, op.getType().getElementType(), dimSize);
+
+ rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(
+ op, op.getType(), ValueRange{dimSizeAsInteger});
+
+ return success();
+ }
+};
+
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
@@ -420,4 +446,5 @@ void mlir::tosa::populateTosaToTensorConversionPatterns(
patterns->getContext());
patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+ patterns->add<DimConverter>(patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94e..1455dd3ff9ffd5 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -36,6 +36,7 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConcatOp>();
+ target.addIllegalOp<tosa::DimOp>();
target.addIllegalOp<tosa::ReshapeOp>();
target.addIllegalOp<tosa::SliceOp>();
target.addIllegalOp<tosa::PadOp>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 26c39ff3523434..3aff47fd2fe4d1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -836,6 +836,27 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
return {};
}
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
+ auto axis = adaptor.getAxis();
+ if (!axis)
+ return {};
+
+ auto axisIndex =
+ llvm::cast<DenseIntElementsAttr>(axis).getSplatValue<int32_t>();
+
+ auto inputType = getInput1().getType();
+
+ // Do not fold if the axis is out of bounds or the axis
+ // size is not known statically.
+ if (axisIndex < 0 || axisIndex >= inputType.getRank() ||
+ inputType.isDynamicDim(axisIndex))
+ return {};
+
+ auto elementType = getType().getElementType();
+ auto attr = IntegerAttr::get(elementType, inputType.getDimSize(axisIndex));
+ return DenseIntElementsAttr::get(getType(), {static_cast<Attribute>(attr)});
+}
+
// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 661126f4df9976..945e1194b4b841 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -806,6 +806,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::DimOp::inferReturnTypeComponents(
+ MLIRContext *ctx, ::std::optional<Location> location,
+ DimOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t> outputShape;
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ return success();
+}
+
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return dim == -1 ? ShapedType::kDynamic : dim;
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index daaa68a7260b71..ce6a88cfb9957f 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -313,3 +313,20 @@ func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf
%0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i32}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @dim_op
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+func.func @dim_op(%arg0: tensor<?x1xf32>, %arg1: tensor<i32>) -> (tensor<i32>) {
+ // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG1]][] : tensor<i32>
+ // CHECK: %[[AS_INDEX:.+]] = arith.index_cast %[[EXTRACTED]] : i32 to index
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[AS_INDEX]] : tensor<?x1xf32>
+ // CHECK: %[[AS_INT:.+]] = arith.index_cast %[[DIM]] : index to i32
+ // CHECK: %[[RESULT:.+]] = tensor.from_elements %[[AS_INT]] : tensor<i32>
+ // CHECK: return %[[RESULT]] : tensor<i32>
+
+ %0 = tosa.dim %arg0, %arg1 : (tensor<?x1xf32>, tensor<i32>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index fd51d287bca058..325994dfb33d1a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
+// RUN: mlir-opt -canonicalize="test-convergence" -split-input-file %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
@@ -613,3 +613,23 @@ func.func nested @fold_tile_rank_zero() -> tensor<i32> {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @fold
+func.func nested @fold_tile_rank_zero() -> tensor<i32> {
+ // CHECK-NOT: tosa.tile
+ %0 = tensor.empty() : tensor<i32>
+ %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_dim_op
+func.func @fold_dim_op(%arg0: tensor<1x2x3xf32>) -> tensor<i32> {
+ // CHECK-NOT: tosa.dim
+ %axis = arith.constant dense<0> : tensor<i32>
+ %1 = tosa.dim %arg0, %axis : (tensor<1x2x3xf32>, tensor<i32>) -> tensor<i32>
+ return %1 : tensor<i32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a3d0b5e5447aa3..25f828337f1255 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -661,3 +661,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> {
%0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>)
return %0 : tensor<10xi32>
}
+
+// -----
+// CHECK-LABEL: test_dim
+func.func @test_dim(%arg0: tensor<10xi32>, %arg1: tensor<i32>) -> tensor<i32> {
+ %0 = tosa.dim %arg0, %arg1 : (tensor<10xi32>, tensor<i32>) -> (tensor<i32>)
+ return %0 : tensor<i32>
+}
``````````
https://github.com/llvm/llvm-project/pull/77706