[Mlir-commits] [mlir] 126e7ea - [tosa] Add option to disable tosa.apply_scale lowering in TosaToStandard

Rob Suderman llvmlistbot at llvm.org
Mon Apr 4 12:29:50 PDT 2022


Author: Rob Suderman
Date: 2022-04-04T12:22:12-07:00
New Revision: 126e7eaf0d4ea8033a376e47636db8e4c33cfeba

URL: https://github.com/llvm/llvm-project/commit/126e7eaf0d4ea8033a376e47636db8e4c33cfeba
DIFF: https://github.com/llvm/llvm-project/commit/126e7eaf0d4ea8033a376e47636db8e4c33cfeba.diff

LOG: [tosa] Add option to disable tosa.apply_scale lowering in TosaToStandard

The tosa.apply_scale lowering should be optional when lowering via TosaToStandard.
In most cases the op should persist until the lowering to a specific backend.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D122948

Added: 
    mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
    mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
    mlir/lib/Conversion/TosaToArith/CMakeLists.txt
    mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
    mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
    mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
    mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
    mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
    mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
    mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
    mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
    mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
    mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
    mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index d104220737518..ae05446b7b236 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -47,9 +47,10 @@
 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
-#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f4bd4d8c65ff2..5dcf9d7415964 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -692,6 +692,30 @@ def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToArith
+//===----------------------------------------------------------------------===//
+
+def TosaToArith : Pass<"tosa-to-arith"> {
+  let summary = "Lower TOSA to the Arith dialect";
+  let dependentDialects = [
+    "arith::ArithmeticDialect",
+  ];
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations using the
+    operations in the Arith dialect. The ApplyScale operator is optionally
+    included as it is often preserved until the final invocation.
+  }];
+
+  let options = [
+    Option<"includeApplyRescale", "include-apply-rescale",
+           "bool", /*default=*/"false",
+           "Whether to include the lowering for tosa.apply_rescale to arith">
+  ];
+
+  let constructor = "tosa::createTosaToArith()";
+}
+
 //===----------------------------------------------------------------------===//
 // TosaToLinalg
 //===----------------------------------------------------------------------===//
@@ -738,21 +762,20 @@ def TosaToSCF : Pass<"tosa-to-scf"> {
 }
 
 //===----------------------------------------------------------------------===//
-// TosaToStandard
+// TosaToTensor
 //===----------------------------------------------------------------------===//
 
-def TosaToStandard : Pass<"tosa-to-standard"> {
-  let summary = "Lower TOSA to the Standard dialect";
+def TosaToTensor : Pass<"tosa-to-tensor"> {
+  let summary = "Lower TOSA to the Tensor dialect";
   let dependentDialects = [
-    "arith::ArithmeticDialect",
     "tensor::TensorDialect",
   ];
   let description = [{
     Pass that converts TOSA operations to the equivalent operations using the
-    operations in the Standard dialect.
+    operations in the Tensor dialect.
   }];
 
-  let constructor = "tosa::createTosaToStandard()";
+  let constructor = "tosa::createTosaToTensor()";
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
new file mode 100644
index 0000000000000..91099fbb4b378
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
@@ -0,0 +1,30 @@
+//===-- TosaToArith.h - TOSA optimization pass declarations --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Arith Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOARITH_TOSATOARITH_H
+#define MLIR_CONVERSION_TOSATOARITH_TOSATOARITH_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToArith();
+
+void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
+
+void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOARITH_TOSATOARITH_H

diff  --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
similarity index 51%
rename from mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
rename to mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index fc1284417896c..3a686e5ef00e9 100644
--- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -1,4 +1,4 @@
-//===-- TosaToStandard.h - TOSA optimization pass declarations --*- C++ -*-===//
+//===-- TosaToTensor.h - TOSA to Tensor legalization ------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,25 +10,19 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
-#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#ifndef MLIR_CONVERSION_TOSATOTENSOR_TOSATOTENSOR_H
+#define MLIR_CONVERSION_TOSATOTENSOR_TOSATOTENSOR_H
 
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 namespace tosa {
 
-std::unique_ptr<Pass> createTosaToStandard();
+std::unique_ptr<Pass> createTosaToTensor();
 
-void populateTosaToStandardConversionPatterns(RewritePatternSet *patterns);
-
-void populateTosaRescaleToStandardConversionPatterns(
-    RewritePatternSet *patterns);
-
-/// Populates passes to convert from TOSA to Standard.
-void addTosaToStandardPasses(OpPassManager &pm);
+void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
 
 } // namespace tosa
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#endif // MLIR_CONVERSION_TOSATOTENSOR_TOSATOTENSOR_H

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6335c4ed371ce..5ef84273cbe0c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,9 +36,10 @@ add_subdirectory(SCFToSPIRV)
 add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(TensorToSPIRV)
+add_subdirectory(TosaToArith)
 add_subdirectory(TosaToLinalg)
 add_subdirectory(TosaToSCF)
-add_subdirectory(TosaToStandard)
+add_subdirectory(TosaToTensor)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToGPU)

diff  --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToArith/CMakeLists.txt
similarity index 72%
rename from mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
rename to mlir/lib/Conversion/TosaToArith/CMakeLists.txt
index 0c96bfc43546b..d5e9b4bc33b1b 100644
--- a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToArith/CMakeLists.txt
@@ -1,6 +1,6 @@
-add_mlir_conversion_library(MLIRTosaToStandard
-  TosaToStandard.cpp
-  TosaToStandardPass.cpp
+add_mlir_conversion_library(MLIRTosaToArith
+  TosaToArith.cpp
+  TosaToArithPass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
similarity index 84%
rename from mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
rename to mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 88207512b2547..225cd66158340 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -1,4 +1,4 @@
-//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// These rewriters lower from the Tosa to the Standard dialect.
+// These rewriters lower from the Tosa to the Arith dialect.
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -33,24 +32,6 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
   }
 };
 
-class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
-public:
-  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
-    Value input = sliceOp.input();
-    SmallVector<int64_t> strides;
-    strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
-
-    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
-        ValueRange({}), sliceOp.start(), sliceOp.size(),
-        rewriter.getI64ArrayAttr(strides));
-    return success();
-  }
-};
-
 Type matchContainerType(Type element, Type container) {
   if (auto shapedTy = container.dyn_cast<ShapedType>())
     return shapedTy.clone(element);
@@ -171,13 +152,12 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
 } // namespace
 
-void mlir::tosa::populateTosaToStandardConversionPatterns(
+void mlir::tosa::populateTosaToArithConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
-      patterns->getContext());
+  patterns->add<ConstOpConverter>(patterns->getContext());
 }
 
-void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
+void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<ApplyScaleOpConverter>(patterns->getContext());
 }

diff  --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
new file mode 100644
index 0000000000000..03bdd6d2a8053
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
@@ -0,0 +1,52 @@
+//===- TosaToArithPass.cpp - Lowering Tosa to Arith Dialect ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Arith dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToArith : public TosaToArithBase<TosaToArith> {
+public:
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+    target.addIllegalOp<tosa::ConstOp>();
+    target.addLegalDialect<arith::ArithmeticDialect>();
+
+    mlir::tosa::populateTosaToArithConversionPatterns(&patterns);
+
+    if (this->includeApplyRescale) {
+      mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns);
+      target.addIllegalOp<tosa::ApplyScaleOp>();
+    }
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToArith() {
+  return std::make_unique<TosaToArith>();
+}

diff  --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
new file mode 100644
index 0000000000000..118635e9ce10d
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToTensor
+  TosaToTensor.cpp
+  TosaToTensorPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRTensor
+  MLIRIR
+  MLIRPass
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )

diff  --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
new file mode 100644
index 0000000000000..c02108eee265e
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -0,0 +1,47 @@
+//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Tensor dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+public:
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    Value input = sliceOp.input();
+    SmallVector<int64_t> strides;
+    strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
+        ValueRange({}), sliceOp.start(), sliceOp.size(),
+        rewriter.getI64ArrayAttr(strides));
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToTensorConversionPatterns(
+    RewritePatternSet *patterns) {
+  patterns->add<SliceOpConverter>(patterns->getContext());
+}

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
similarity index 62%
rename from mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
rename to mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 1f3e9ed151ebd..6fe862b46a2ac 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -1,4 +1,4 @@
-//===- TosaToStandardPass.cpp - Lowering Tosa to Linalg Dialect -----------===//
+//===- TosaToTensorPass.cpp - Lowering Tosa to Tensor Dialect -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This transformation pass legalizes Tosa operations to the Standard dialect.
+// This transformation pass legalizes Tosa operations to the Tensor dialect.
 //
 //===----------------------------------------------------------------------===//
 
 #include "../PassDetail.h"
-#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -26,18 +25,16 @@ using namespace mlir;
 using namespace tosa;
 
 namespace {
-struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+struct TosaToTensor : public TosaToTensorBase<TosaToTensor> {
 public:
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
-    target.addIllegalOp<tosa::ConstOp>();
     target.addIllegalOp<tosa::SliceOp>();
-    target.addIllegalOp<tosa::ApplyScaleOp>();
-    target.addLegalDialect<arith::ArithmeticDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
 
-    mlir::tosa::populateTosaToStandardConversionPatterns(&patterns);
+    mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -45,10 +42,6 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
-  return std::make_unique<TosaToStandard>();
-}
-
-void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
-  pm.addNestedPass<FuncOp>(createTosaToStandard());
+std::unique_ptr<Pass> mlir::tosa::createTosaToTensor() {
+  return std::make_unique<TosaToTensor>();
 }

diff  --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
similarity index 95%
rename from mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
rename to mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 99b476b3d0f55..21bb255bcb210 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s
 
 // CHECK-LABEL: func @const_test
 func @const_test() -> (tensor<i32>) {
@@ -11,14 +12,6 @@ func @const_test() -> (tensor<i32>) {
 
 // -----
 
-func @slice(%arg0: tensor<6xf32>) ->() {
-  // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
-  %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>)  -> (tensor<1xf32>)
-  return
-}
-
-// -----
-
 // CHECK-LABEL: @apply_scale_test_i32
 func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8
@@ -50,6 +43,7 @@ func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
   // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]]
 
+  // SCALE: "tosa.apply_scale"
   %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
   return %0 : i32
 }

diff  --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
new file mode 100644
index 0000000000000..2290937b180ac
--- /dev/null
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
+
+// CHECK-LABEL: func @slice
+func @slice(%arg0: tensor<6xf32>) ->() {
+  // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
+  %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>)  -> (tensor<1xf32>)
+  return
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8130870884b2d..403250becaff8 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2343,9 +2343,10 @@ cc_library(
         ":SPIRVToLLVM",
         ":ShapeToStandard",
         ":TensorToSPIRV",
+        ":TosaToArith",
         ":TosaToLinalg",
         ":TosaToSCF",
-        ":TosaToStandard",
+        ":TosaToTensor",
         ":VectorToGPU",
         ":VectorToLLVM",
         ":VectorToROCDL",
@@ -7437,6 +7438,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TosaToArith",
+    srcs = glob([
+        "lib/Conversion/TosaToArith/*.cpp",
+        "lib/Conversion/TosaToArith/*.h",
+    ]) + [":ConversionPassDetail"],
+    hdrs = glob([
+        "include/mlir/Conversion/TosaToArith/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/TosaToArith",
+    ],
+    deps = [
+        ":ArithmeticDialect",
+        ":ConversionPassIncGen",
+        ":FuncDialect",
+        ":IR",
+        ":Pass",
+        ":TosaDialect",
+        ":Transforms",
+    ],
+)
+
 cc_library(
     name = "TosaToLinalg",
     srcs = glob([
@@ -7492,20 +7517,19 @@ cc_library(
 )
 
 cc_library(
-    name = "TosaToStandard",
+    name = "TosaToTensor",
     srcs = glob([
-        "lib/Conversion/TosaToStandard/*.cpp",
-        "lib/Conversion/TosaToStandard/*.h",
+        "lib/Conversion/TosaToTensor/*.cpp",
+        "lib/Conversion/TosaToTensor/*.h",
     ]) + [":ConversionPassDetail"],
     hdrs = glob([
-        "include/mlir/Conversion/TosaToStandard/*.h",
+        "include/mlir/Conversion/TosaToTensor/*.h",
     ]),
     includes = [
         "include",
-        "lib/Conversion/TosaToStandard",
+        "lib/Conversion/TosaToTensor",
     ],
     deps = [
-        ":ArithmeticDialect",
         ":ConversionPassIncGen",
         ":FuncDialect",
         ":IR",


        


More information about the Mlir-commits mailing list