[Mlir-commits] [mlir] TosaToTensor: Support reshape on tensors of unsigned integers (PR #91734)

Matthias Gehre llvmlistbot at llvm.org
Thu May 23 23:53:37 PDT 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/91734

>From b06875a5c12f8ffc00b0caef26033c62b17909b5 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Tue, 7 May 2024 16:27:53 +0200
Subject: [PATCH 1/2] TosaToTensor: Support reshape on tensors of unsigned
 integers

---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  3 ++
 .../Conversion/TosaToTensor/TosaToTensor.h    |  4 ++-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 34 +++++++++++++++++++
 .../Conversion/TosaToTensor/CMakeLists.txt    |  1 +
 .../Conversion/TosaToTensor/TosaToTensor.cpp  | 33 +++++++++++-------
 .../TosaToTensor/TosaToTensorPass.cpp         |  6 +++-
 .../TosaToTensor/tosa-to-tensor.mlir          | 14 ++++++++
 7 files changed, 80 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 5fd77c8a0211a..d3024c7389b9c 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -18,6 +18,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOLINALG
 #define GEN_PASS_DECL_TOSATOLINALGNAMED
@@ -52,6 +53,8 @@ void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 void populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
 
+void populateTosaToLinalgTypeConversion(TypeConverter &converter);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index 3953c83f3aa10..76a4b1b156336 100644
--- a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -16,6 +16,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOTENSOR
 #include "mlir/Conversion/Passes.h.inc"
@@ -24,7 +25,11 @@ namespace tosa {
 
 std::unique_ptr<Pass> createTosaToTensor();
 
-void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
+/// Populates `patterns` with the TOSA-to-tensor conversion patterns. All
+/// patterns are constructed with `converter`; the tosa.reshape lowering uses
+/// it to convert its result type (e.g. unsigned to signless integers).
+void populateTosaToTensorConversionPatterns(TypeConverter &converter,
+                                            RewritePatternSet *patterns);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e6ba6e6bc602d..dcb15012bda88 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2617,3 +2617,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       TileConverter>(patterns->getContext());
   // clang-format on
 }
+
+void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
+  converter.addConversion([&](Type type) -> std::optional<Type> {
+    if (type.isUnsignedInteger()) {
+      return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
+                              IntegerType::SignednessSemantics::Signless);
+    }
+    return type;
+  });
+  converter.addConversion([&](TensorType type) -> std::optional<Type> {
+    auto converted = converter.convertType(type.getElementType());
+    if (!converted)
+      return {};
+    return type.clone(converted);
+  });
+  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                                         ValueRange inputs,
+                                         Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+}
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
index 2870baa20757b..b1e7c9cba1a78 100644
--- a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRTosaToTensor
   MLIRIR
   MLIRPass
   MLIRTosaDialect
+  MLIRTosaToLinalg
   MLIRTosaTransforms
   MLIRSupport
   )
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index cd6da35582469..33f388faf6648 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -225,8 +225,17 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = reshape.getLoc();
-    auto resultType = reshape.getResult().getType();
-    auto input = reshape.getInput1();
+    auto resultType = dyn_cast_if_present<ShapedType>(
+        getTypeConverter()->convertType(reshape.getType()));
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(reshape.getLoc(),
+                                         "could not convert result type");
+    }
+    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
+    if (!input) {
+      return rewriter.notifyMatchFailure(reshape.getLoc(),
+                                         "expected input type to be tensor");
+    }
     auto newShape = reshape.getNewShape();
 
     // Infer all intermediate types
@@ -289,12 +298,13 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
   }
 };
 
-class PadConverter : public OpRewritePattern<tosa::PadOp> {
+class PadConverter : public OpConversionPattern<tosa::PadOp> {
 public:
-  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(tosa::PadOp padOp,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
     auto padding = padOp.getPadding();
@@ -429,11 +439,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 } // namespace
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
-    RewritePatternSet *patterns) {
-  patterns->add<
-    ConcatConverter,
-    PadConverter,
-    ReshapeConverter,
-    SliceConverter
-  >(patterns->getContext());
+    TypeConverter &converter, RewritePatternSet *patterns) {
+  patterns
+      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
+          converter, patterns->getContext());
 }
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 50dc55667fb94..9ae5edcce291e 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <mlir/Conversion/TosaToLinalg/TosaToLinalg.h>
 
 namespace mlir {
 #define GEN_PASS_DEF_TOSATOTENSOR
@@ -42,7 +43,10 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
     target.addLegalDialect<arith::ArithDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
-    mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+    TypeConverter converter;
+    mlir::tosa::populateTosaToLinalgTypeConversion(converter);
+
+    mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index b8c3d56f21f10..2eddde9a55660 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -405,6 +405,20 @@ func.func @test_reshape_6d_down_s2s_explicit(%arg0: tensor<1x2x3x5x7x11xf32>) ->
 
 // -----
 
+// CHECK-LABEL: @test_reshape_samerank_unsigned
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>)
+func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> {
+  // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8>
+  // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8>
+  // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] output_shape {{\[}}2, 3] : tensor<6xi8> into tensor<2x3xi8>
+  // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8>
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xui8>) -> tensor<2x3xui8>
+  // CHECK-NEXT: return %[[CAST2]]
+  return %0 : tensor<2x3xui8>
+}
+
+// -----
+
 // CHECK-LABEL: func @slice
 func.func @slice(%arg0: tensor<6xf32>) ->() {
   // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]

>From 3757481ace0674e795f06fe929f6e1ba62336c45 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 24 May 2024 08:53:02 +0200
Subject: [PATCH 2/2] Move TypeConverter to
 mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp

---
 .../Conversion/TosaToLinalg/TosaToLinalg.h    |  3 --
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |  3 ++
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 34 ------------
 .../Conversion/TosaToTensor/CMakeLists.txt    |  1 -
 .../TosaToTensor/TosaToTensorPass.cpp         |  3 +-
 .../Dialect/Tosa/Transforms/CMakeLists.txt    |  1 +
 .../Tosa/Transforms/TosaTypeConverters.cpp    | 52 +++++++++++++++++++
 7 files changed, 57 insertions(+), 40 deletions(-)
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp

diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index d3024c7389b9c..5fd77c8a0211a 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -18,7 +18,6 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
-class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOLINALG
 #define GEN_PASS_DECL_TOSATOLINALGNAMED
@@ -53,8 +52,6 @@ void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
 void populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options);
 
-void populateTosaToLinalgTypeConversion(TypeConverter &converter);
-
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index fbfc56dfe2cf4..1f9522b51a4cf 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -18,6 +18,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 namespace tosa {
 
 #define GEN_PASS_DECL
@@ -38,6 +39,11 @@ void populateTosaConstantReduction(MLIRContext *ctx,
                                    RewritePatternSet &patterns,
                                    bool aggressiveReduceConstant);
 
+/// Adds type conversions to `converter` that map unsigned integer types, also
+/// when used as tensor element types, to signless integers of the same bit
+/// width, together with unrealized_conversion_cast materializations.
+void populateTosaTypeConversion(TypeConverter &converter);
+
 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass(
     const TosaLayerwiseConstantFoldPassOptions &options);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index dcb15012bda88..e6ba6e6bc602d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2617,37 +2617,3 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       TileConverter>(patterns->getContext());
   // clang-format on
 }
-
-void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) {
-  converter.addConversion([&](Type type) -> std::optional<Type> {
-    if (type.isUnsignedInteger()) {
-      return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
-                              IntegerType::SignednessSemantics::Signless);
-    }
-    return type;
-  });
-  converter.addConversion([&](TensorType type) -> std::optional<Type> {
-    auto converted = converter.convertType(type.getElementType());
-    if (!converted)
-      return {};
-    return type.clone(converted);
-  });
-  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                                         ValueRange inputs,
-                                         Location loc) -> std::optional<Value> {
-    if (inputs.size() != 1)
-      return std::nullopt;
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
-  });
-  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                                         ValueRange inputs,
-                                         Location loc) -> std::optional<Value> {
-    if (inputs.size() != 1)
-      return std::nullopt;
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
-  });
-}
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
index b1e7c9cba1a78..2870baa20757b 100644
--- a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -15,7 +15,6 @@ add_mlir_conversion_library(MLIRTosaToTensor
   MLIRIR
   MLIRPass
   MLIRTosaDialect
-  MLIRTosaToLinalg
   MLIRTosaTransforms
   MLIRSupport
   )
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 9ae5edcce291e..fa1c2cf7fba98 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -20,7 +20,6 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <mlir/Conversion/TosaToLinalg/TosaToLinalg.h>
 
 namespace mlir {
 #define GEN_PASS_DEF_TOSATOTENSOR
@@ -44,7 +43,7 @@ struct TosaToTensor : public impl::TosaToTensorBase<TosaToTensor> {
     target.addLegalDialect<tensor::TensorDialect>();
 
     TypeConverter converter;
-    mlir::tosa::populateTosaToLinalgTypeConversion(converter);
+    mlir::tosa::populateTosaTypeConversion(converter);
 
     mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns);
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 0e6510ba1e925..c78a74b874aff 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
   TosaOptionalDecompositions.cpp
+  TosaTypeConverters.cpp
   TosaValidation.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
new file mode 100644
index 0000000000000..d2650de8cd7f0
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
@@ -0,0 +1,52 @@
+//===- TosaTypeConverters.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Type conversions shared by the TOSA lowerings (e.g. TosaToTensor): unsigned
+// integer types are mapped to signless integers of the same bit width.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) {
+  // Unsigned integers become signless integers of the same bit width; any
+  // other type is returned unchanged.
+  converter.addConversion([](Type type) -> std::optional<Type> {
+    if (type.isUnsignedInteger()) {
+      return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(),
+                              IntegerType::SignednessSemantics::Signless);
+    }
+    return type;
+  });
+  // Tensors keep their shape and only convert the element type. Conversions
+  // are tried in reverse registration order, so this runs before the above.
+  converter.addConversion([&converter](TensorType type) -> std::optional<Type> {
+    auto converted = converter.convertType(type.getElementType());
+    if (!converted)
+      return {};
+    return type.clone(converted);
+  });
+
+  // Bridge converted and unconverted values with a single
+  // unrealized_conversion_cast, in both directions.
+  auto materializeWithCast = [](OpBuilder &builder, Type resultType,
+                                ValueRange inputs,
+                                Location loc) -> std::optional<Value> {
+    if (inputs.size() != 1)
+      return std::nullopt;
+
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  };
+  converter.addSourceMaterialization(materializeWithCast);
+  converter.addTargetMaterialization(materializeWithCast);
+}



More information about the Mlir-commits mailing list