[Mlir-commits] [mlir] af22e27 - TosaToTensor: Support reshape on tensors of unsigned integer (#91734)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 08:59:27 PDT 2024
Author: Matthias Gehre
Date: 2024-05-28T17:59:23+02:00
New Revision: af22e274e9c5643780f25066442e05b5bd453328
URL: https://github.com/llvm/llvm-project/commit/af22e274e9c5643780f25066442e05b5bd453328
DIFF: https://github.com/llvm/llvm-project/commit/af22e274e9c5643780f25066442e05b5bd453328.diff
LOG: TosaToTensor: Support reshape on tensors of unsigned integer (#91734)
This change:
- adds `mlir::tosa::populateTosaTypeConversion`, which populates a type
converter that converts tensors of unsigned integers into tensors of
signless integers
- modifies the `tosa.reshape` lowering in TosaToTensor to use the type
converter correctly
I chose to implement the type converter in the shared TOSA transforms
library (declared in `mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h`)
instead of `mlir/Conversion/TosaToTensor/TosaToTensor.h` because I need the
same type converter in the TosaToLinalg lowerings (future PR).
Alternatively, I could duplicate the type converter so it exists both in
TosaToLinalg and TosaToTensor. Let me know if you prefer that.
Added:
mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
Modified:
mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index 3953c83f3aa10..76a4b1b156336 100644
--- a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -16,6 +16,7 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class TypeConverter;
#define GEN_PASS_DECL_TOSATOTENSOR
#include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,8 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToTensor();
-void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToTensorConversionPatterns(TypeConverter &converter,
+ RewritePatternSet *patterns);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index fbfc56dfe2cf4..1f9522b51a4cf 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -18,6 +18,7 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+class TypeConverter;
namespace tosa {
#define GEN_PASS_DECL
@@ -38,6 +39,8 @@ void populateTosaConstantReduction(MLIRContext *ctx,
RewritePatternSet &patterns,
bool aggressiveReduceConstant);
+void populateTosaTypeConversion(TypeConverter &converter);
+
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(
const TosaLayerwiseConstantFoldPassOptions &options);
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 89f956a5e7017..c0c015ab34aab 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -224,8 +224,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = reshape.getLoc();
- auto resultType = reshape.getResult().getType();
- auto input = reshape.getInput1();
+ auto resultType = cast_if_present<ShapedType>(
+ getTypeConverter()->convertType(reshape.getType()));
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(reshape.getLoc(),
+ "could not convert result type");
+ }
+ auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
+ if (!input) {
+ return rewriter.notifyMatchFailure(reshape.getLoc(),
+ "expected input type to be tensor");
+ }
auto newShape = reshape.getNewShape();
// Infer all intermediate types
@@ -288,12 +297,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
}
};
-class PadConverter : public OpRewritePattern<tosa::PadOp> {
+class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
- using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+ using OpConversionPattern::OpConversionPattern;
- LogicalResult matchAndRewrite(tosa::PadOp padOp,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
auto loc = padOp.getLoc();
auto input = padOp.getInput1();
auto padding = padOp.getPadding();
@@ -428,11 +438,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
} // namespace
void mlir::tosa::populateTosaToTensorConversionPatterns(
- RewritePatternSet *patterns) {
- patterns->add<
- ConcatConverter,
- PadConverter,
- ReshapeConverter,
- SliceConverter
- >(patterns->getContext());
+ TypeConverter &converter, RewritePatternSet *patterns) {
+ patterns
+ ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
+ converter, patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94..fa1c2cf7fba98 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -42,7 +42,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<tensor::TensorDialect>();
- mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+ TypeConverter converter;
+ mlir::tosa::populateTosaTypeConversion(converter);
+
+ mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 0e6510ba1e925..c78a74b874aff 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaLayerwiseConstantFoldPass.cpp
TosaMakeBroadcastable.cpp
TosaOptionalDecompositions.cpp
+ TosaTypeConverters.cpp
TosaValidation.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
new file mode 100644
index 0000000000000..d2650de8cd7f0
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
@@ -0,0 +1,52 @@
+
+//===- TosaTypeConverters.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Type converters for lowering TOSA to linalg/arith.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) {
+ converter.addConversion([&](Type type) -> std::optional<Type> {
+ if (type.isUnsignedInteger()) {
+ return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+ return type;
+ });
+ converter.addConversion([&](TensorType type) -> std::optional<Type> {
+ auto converted = converter.convertType(type.getElementType());
+ if (!converted)
+ return {};
+ return type.clone(converted);
+ });
+ converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (inputs.size() != 1)
+ return std::nullopt;
+
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+ converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ if (inputs.size() != 1)
+ return std::nullopt;
+
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 72e7e4cc84088..1e62e25176a00 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -420,6 +420,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
// -----
+// CHECK-LABEL: @test_reshape_samerank_unsigned
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
+func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
+ // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
+ // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
+ // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
+ // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8>
+ %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
+ // CHECK-NEXT: return %[[CAST2]]
+ return %0 : tensor<2x3xui8>
+}
+
+// -----
+
// CHECK-LABEL: func @slice
func.func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
More information about the Mlir-commits
mailing list