[Mlir-commits] [mlir] 126e7ea - [tosa] Add option to disable tosa.apply_scale lowering in TosaToStandard
Rob Suderman
llvmlistbot at llvm.org
Mon Apr 4 12:29:50 PDT 2022
Author: Rob Suderman
Date: 2022-04-04T12:22:12-07:00
New Revision: 126e7eaf0d4ea8033a376e47636db8e4c33cfeba
URL: https://github.com/llvm/llvm-project/commit/126e7eaf0d4ea8033a376e47636db8e4c33cfeba
DIFF: https://github.com/llvm/llvm-project/commit/126e7eaf0d4ea8033a376e47636db8e4c33cfeba.diff
LOG: [tosa] Add option to disable tosa.apply_scale lowering in TosaToStandard
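The tosa.apply_scale lowering should be optional when lowering via TosaToStandard.
In most cases the op should persist until the lowering to a specific backend.
To support this, TosaToStandard is split into TosaToArith (which lowers tosa.apply_scale only when include-apply-rescale=true) and TosaToTensor.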
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D122948
Added:
mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
mlir/lib/Conversion/TosaToArith/CMakeLists.txt
mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index d104220737518..ae05446b7b236 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -47,9 +47,10 @@
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
-#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f4bd4d8c65ff2..5dcf9d7415964 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -692,6 +692,30 @@ def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// TosaToArith
+//===----------------------------------------------------------------------===//
+
+def TosaToArith : Pass<"tosa-to-arith"> {
+ let summary = "Lower TOSA to the Arith dialect";
+ let dependentDialects = [
+ "arith::ArithmeticDialect",
+ ];
+ let description = [{
+ Pass that converts TOSA operations to the equivalent operations using the
+ operations in the Arith dialect. The ApplyScale operator is optionally
+ included as it is often preserved until the final invocation.
+ }];
+
+ let options = [
+ Option<"includeApplyRescale", "include-apply-rescale",
+ "bool", /*default=*/"false",
+ "Whether to include the lowering for tosa.apply_scale to arith">
+ ];
+
+ let constructor = "tosa::createTosaToArith()";
+}
+
//===----------------------------------------------------------------------===//
// TosaToLinalg
//===----------------------------------------------------------------------===//
@@ -738,21 +762,20 @@ def TosaToSCF : Pass<"tosa-to-scf"> {
}
//===----------------------------------------------------------------------===//
-// TosaToStandard
+// TosaToTensor
//===----------------------------------------------------------------------===//
-def TosaToStandard : Pass<"tosa-to-standard"> {
- let summary = "Lower TOSA to the Standard dialect";
+def TosaToTensor : Pass<"tosa-to-tensor"> {
+ let summary = "Lower TOSA to the Tensor dialect";
let dependentDialects = [
- "arith::ArithmeticDialect",
"tensor::TensorDialect",
];
let description = [{
Pass that converts TOSA operations to the equivalent operations using the
- operations in the Standard dialect.
+ operations in the Tensor dialect.
}];
- let constructor = "tosa::createTosaToStandard()";
+ let constructor = "tosa::createTosaToTensor()";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
new file mode 100644
index 0000000000000..91099fbb4b378
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
@@ -0,0 +1,30 @@
+//===-- TosaToArith.h - TOSA to Arith legalization --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Arith dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOARITH_TOSATOARITH_H
+#define MLIR_CONVERSION_TOSATOARITH_TOSATOARITH_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToArith();
+
+void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
+
+void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOARITH_TOSATOARITH_H
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
similarity index 51%
rename from mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
rename to mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
index fc1284417896c..3a686e5ef00e9 100644
--- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
+++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h
@@ -1,4 +1,4 @@
-//===-- TosaToStandard.h - TOSA optimization pass declarations --*- C++ -*-===//
+//===-- TosaToTensor.h - TOSA to Tensor legalization ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,25 +10,19 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
-#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#ifndef MLIR_CONVERSION_TOSATOTENSOR_TOSATOTENSOR_H
+#define MLIR_CONVERSION_TOSATOTENSOR_TOSATOTENSOR_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace tosa {
-std::unique_ptr<Pass> createTosaToStandard();
+std::unique_ptr<Pass> createTosaToTensor();
-void populateTosaToStandardConversionPatterns(RewritePatternSet *patterns);
-
-void populateTosaRescaleToStandardConversionPatterns(
- RewritePatternSet *patterns);
-
-/// Populates passes to convert from TOSA to Standard.
-void addTosaToStandardPasses(OpPassManager &pm);
+void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns);
} // namespace tosa
} // namespace mlir
-#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#endif // MLIR_CONVERSION_TOSATOTENSOR_TOSATOTENSOR_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6335c4ed371ce..5ef84273cbe0c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,9 +36,10 @@ add_subdirectory(SCFToSPIRV)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(TensorToSPIRV)
+add_subdirectory(TosaToArith)
add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToSCF)
-add_subdirectory(TosaToStandard)
+add_subdirectory(TosaToTensor)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToGPU)
diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToArith/CMakeLists.txt
similarity index 72%
rename from mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
rename to mlir/lib/Conversion/TosaToArith/CMakeLists.txt
index 0c96bfc43546b..d5e9b4bc33b1b 100644
--- a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToArith/CMakeLists.txt
@@ -1,6 +1,6 @@
-add_mlir_conversion_library(MLIRTosaToStandard
- TosaToStandard.cpp
- TosaToStandardPass.cpp
+add_mlir_conversion_library(MLIRTosaToArith
+ TosaToArith.cpp
+ TosaToArithPass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
similarity index 84%
rename from mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
rename to mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 88207512b2547..225cd66158340 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -1,4 +1,4 @@
-//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
-// These rewriters lower from the Tosa to the Standard dialect.
+// These rewriters lower from the Tosa to the Arith dialect.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -33,24 +32,6 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
}
};
-class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
-public:
- using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
- PatternRewriter &rewriter) const final {
- Value input = sliceOp.input();
- SmallVector<int64_t> strides;
- strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
-
- rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
- sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
- ValueRange({}), sliceOp.start(), sliceOp.size(),
- rewriter.getI64ArrayAttr(strides));
- return success();
- }
-};
-
Type matchContainerType(Type element, Type container) {
if (auto shapedTy = container.dyn_cast<ShapedType>())
return shapedTy.clone(element);
@@ -171,13 +152,12 @@ class ApplyScaleOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
} // namespace
-void mlir::tosa::populateTosaToStandardConversionPatterns(
+void mlir::tosa::populateTosaToArithConversionPatterns(
RewritePatternSet *patterns) {
- patterns->add<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
- patterns->getContext());
+ patterns->add<ConstOpConverter>(patterns->getContext());
}
-void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
+void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<ApplyScaleOpConverter>(patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
new file mode 100644
index 0000000000000..03bdd6d2a8053
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
@@ -0,0 +1,52 @@
+//===- TosaToArithPass.cpp - Lowering Tosa to Arith Dialect ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Arith dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToArith/TosaToArith.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToArith : public TosaToArithBase<TosaToArith> {
+public:
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ ConversionTarget target(getContext());
+ target.addIllegalOp<tosa::ConstOp>();
+ target.addLegalDialect<arith::ArithmeticDialect>();
+
+ mlir::tosa::populateTosaToArithConversionPatterns(&patterns);
+
+ if (this->includeApplyRescale) {
+ mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns);
+ target.addIllegalOp<tosa::ApplyScaleOp>();
+ }
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToArith() {
+ return std::make_unique<TosaToArith>();
+}
diff --git a/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
new file mode 100644
index 0000000000000..118635e9ce10d
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToTensor/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToTensor
+ TosaToTensor.cpp
+ TosaToTensorPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRTensor
+ MLIRIR
+ MLIRPass
+ MLIRTosa
+ MLIRTosaTransforms
+ MLIRSupport
+ )
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
new file mode 100644
index 0000000000000..c02108eee265e
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -0,0 +1,47 @@
+//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Tensor dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
+public:
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const final {
+ Value input = sliceOp.input();
+ SmallVector<int64_t> strides;
+ strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
+
+ rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+ sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
+ ValueRange({}), sliceOp.start(), sliceOp.size(),
+ rewriter.getI64ArrayAttr(strides));
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToTensorConversionPatterns(
+ RewritePatternSet *patterns) {
+ patterns->add<SliceOpConverter>(patterns->getContext());
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
similarity index 62%
rename from mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
rename to mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
index 1f3e9ed151ebd..6fe862b46a2ac 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp
@@ -1,4 +1,4 @@
-//===- TosaToStandardPass.cpp - Lowering Tosa to Linalg Dialect -----------===//
+//===- TosaToTensorPass.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
-// This transformation pass legalizes Tosa operations to the Standard dialect.
+// This transformation pass legalizes Tosa operations to the Tensor dialect.
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
-#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
@@ -26,18 +25,16 @@ using namespace mlir;
using namespace tosa;
namespace {
-struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+struct TosaToTensor : public TosaToTensorBase<TosaToTensor> {
public:
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
- target.addIllegalOp<tosa::ConstOp>();
target.addIllegalOp<tosa::SliceOp>();
- target.addIllegalOp<tosa::ApplyScaleOp>();
- target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<tensor::TensorDialect>();
- mlir::tosa::populateTosaToStandardConversionPatterns(&patterns);
+ mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
+
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -45,10 +42,6 @@ struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
};
} // namespace
-std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
- return std::make_unique<TosaToStandard>();
-}
-
-void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
- pm.addNestedPass<FuncOp>(createTosaToStandard());
+std::unique_ptr<Pass> mlir::tosa::createTosaToTensor() {
+ return std::make_unique<TosaToTensor>();
}
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
similarity index 95%
rename from mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
rename to mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 99b476b3d0f55..21bb255bcb210 100644
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true" %s -verify-diagnostics -o -| FileCheck %s
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s
// CHECK-LABEL: func @const_test
func @const_test() -> (tensor<i32>) {
@@ -11,14 +12,6 @@ func @const_test() -> (tensor<i32>) {
// -----
-func @slice(%arg0: tensor<6xf32>) ->() {
- // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
- %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
- return
-}
-
-// -----
-
// CHECK-LABEL: @apply_scale_test_i32
func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8
@@ -50,6 +43,7 @@ func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
// CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]]
+ // SCALE: "tosa.apply_scale"
%0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32
return %0 : i32
}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
new file mode 100644
index 0000000000000..2290937b180ac
--- /dev/null
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-tensor %s -o -| FileCheck %s
+
+// CHECK-LABEL: func @slice
+func @slice(%arg0: tensor<6xf32>) ->() {
+ // CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
+ %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8130870884b2d..403250becaff8 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2343,9 +2343,10 @@ cc_library(
":SPIRVToLLVM",
":ShapeToStandard",
":TensorToSPIRV",
+ ":TosaToArith",
":TosaToLinalg",
":TosaToSCF",
- ":TosaToStandard",
+ ":TosaToTensor",
":VectorToGPU",
":VectorToLLVM",
":VectorToROCDL",
@@ -7437,6 +7438,30 @@ cc_library(
],
)
+cc_library(
+ name = "TosaToArith",
+ srcs = glob([
+ "lib/Conversion/TosaToArith/*.cpp",
+ "lib/Conversion/TosaToArith/*.h",
+ ]) + [":ConversionPassDetail"],
+ hdrs = glob([
+ "include/mlir/Conversion/TosaToArith/*.h",
+ ]),
+ includes = [
+ "include",
+ "lib/Conversion/TosaToArith",
+ ],
+ deps = [
+ ":ArithmeticDialect",
+ ":ConversionPassIncGen",
+ ":FuncDialect",
+ ":IR",
+ ":Pass",
+ ":TosaDialect",
+ ":Transforms",
+ ],
+)
+
cc_library(
name = "TosaToLinalg",
srcs = glob([
@@ -7492,20 +7517,19 @@ cc_library(
)
cc_library(
- name = "TosaToStandard",
+ name = "TosaToTensor",
srcs = glob([
- "lib/Conversion/TosaToStandard/*.cpp",
- "lib/Conversion/TosaToStandard/*.h",
+ "lib/Conversion/TosaToTensor/*.cpp",
+ "lib/Conversion/TosaToTensor/*.h",
]) + [":ConversionPassDetail"],
hdrs = glob([
- "include/mlir/Conversion/TosaToStandard/*.h",
+ "include/mlir/Conversion/TosaToTensor/*.h",
]),
includes = [
"include",
- "lib/Conversion/TosaToStandard",
+ "lib/Conversion/TosaToTensor",
],
deps = [
- ":ArithmeticDialect",
":ConversionPassIncGen",
":FuncDialect",
":IR",
More information about the Mlir-commits
mailing list