[Mlir-commits] [mlir] a813e9b - [MLIR][TOSA] Added Tosa to Standard/SCF Lowerings (const, if, while)
Rob Suderman
llvmlistbot at llvm.org
Thu Feb 25 14:43:06 PST 2021
Author: Rob Suderman
Date: 2021-02-25T14:35:21-08:00
New Revision: a813e9be5bc91203508bde239c1a15c5b8f8c0cc
URL: https://github.com/llvm/llvm-project/commit/a813e9be5bc91203508bde239c1a15c5b8f8c0cc
DIFF: https://github.com/llvm/llvm-project/commit/a813e9be5bc91203508bde239c1a15c5b8f8c0cc.diff
LOG: [MLIR][TOSA] Added Tosa to Standard/SCF Lowerings (const, if, while)
Includes a lowering for tosa.const, tosa.if, and tosa.while to Standard/SCF dialects. TosaToStandard is
used for constant lowerings and TosaToSCF handles the if/while ops.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D97352
Added:
mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/PassDetail.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 121dae6f46f8..21e604eabecd 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -31,6 +31,8 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index aa228784e48a..f37283868a8e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -440,6 +440,36 @@ def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
let constructor = "tosa::createTosaToLinalgOnTensors()";
}
+//===----------------------------------------------------------------------===//
+// TosaToSCF
+//===----------------------------------------------------------------------===//
+
+// The TOSA->SCF rewrites create tensor.extract as well as scf ops, so both
+// dialects are declared as dependencies, one list entry per dialect.
+def TosaToSCF : Pass<"tosa-to-scf"> {
+ let summary = "Lower TOSA to the SCF dialect";
+ let dependentDialects = ["tensor::TensorDialect", "scf::SCFDialect"];
+ let description = [{
+ Pass that converts TOSA's control flow operations to the equivalent SCF
+ operations.
+ }];
+
+ let constructor = "tosa::createTosaToSCF()";
+}
+
+//===----------------------------------------------------------------------===//
+// TosaToStandard
+//===----------------------------------------------------------------------===//
+
+// Lowers TOSA ops that have a direct Standard-dialect counterpart (currently
+// only tosa.const) to that dialect.
+def TosaToStandard : Pass<"tosa-to-standard"> {
+ let summary = "Lower TOSA to the Standard dialect";
+ let constructor = "tosa::createTosaToStandard()";
+ let dependentDialects = ["StandardOpsDialect"];
+
+ let description = [{
+ Pass that converts TOSA operations to the equivalent operations using the
+ operations in the Standard dialect.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// VectorToSCF
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
new file mode 100644
index 000000000000..68ed0e0b6525
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
@@ -0,0 +1,32 @@
+//===-- TosaToSCF.h - TOSA to SCF dialect lowerings -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to SCF Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+#define MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+/// Creates a pass that lowers TOSA control flow (tosa.cond_if and
+/// tosa.while_loop) to scf.if / scf.while via partial conversion; other ops
+/// are left in place.
+std::unique_ptr<Pass> createTosaToSCF();
+
+/// Adds to `patterns` the rewrites turning tosa.cond_if into scf.if and
+/// tosa.while_loop into scf.while.
+void populateTosaToSCFConversionPatterns(MLIRContext *context,
+ OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to SCF. The conversion pass is added
+/// nested, so it runs on each FuncOp of the module.
+void addTosaToSCFPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
new file mode 100644
index 000000000000..82555003661e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -0,0 +1,32 @@
+//===-- TosaToStandard.h - TOSA to Standard dialect lowerings ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Standard Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+/// Creates a pass that lowers TOSA ops with a Standard-dialect equivalent
+/// (currently tosa.const -> constant) via partial conversion.
+std::unique_ptr<Pass> createTosaToStandard();
+
+/// Adds to `patterns` the TOSA -> Standard rewrites (currently only the
+/// tosa.const lowering).
+void populateTosaToStandardConversionPatterns(
+ MLIRContext *context, OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to Standard. The conversion pass is
+/// added nested, so it runs on each FuncOp of the module.
+void addTosaToStandardPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6ba8d415e30b..2f8008489df5 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -22,6 +22,8 @@ add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToSCF)
+add_subdirectory(TosaToStandard)
add_subdirectory(ArmSVEToLLVM)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index c0e1791dc59b..7c1db73d486a 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -59,6 +59,14 @@ namespace spirv {
class SPIRVDialect;
} // end namespace spirv
+namespace tensor {
+class TensorDialect;
+} // end namespace tensor
+
+namespace tosa {
+class TosaDialect;
+} // end namespace tosa
+
namespace vector {
class VectorDialect;
} // end namespace vector
diff --git a/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
new file mode 100644
index 000000000000..189c25c2d89c
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
@@ -0,0 +1,21 @@
+# TOSA -> SCF conversion: the rewrite patterns (TosaToSCF.cpp) and the pass
+# that applies them (TosaToSCFPass.cpp).
+add_mlir_conversion_library(MLIRTosaToSCF
+ TosaToSCF.cpp
+ TosaToSCFPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+ # Generated pass base classes (TosaToSCFBase) come from this target.
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSCF
+ MLIRStandard
+ MLIRPass
+ MLIRTensor
+ MLIRTosa
+ MLIRTosaTransforms
+ MLIRSupport
+ )
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
new file mode 100644
index 000000000000..dfc97dceab84
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -0,0 +1,113 @@
+//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+/// Moves the body of a tosa.cond_if branch (`srcRegion`) into the matching
+/// scf.if region (`dstRegion`).
+///
+/// Uses of the branch's entry-block arguments are replaced by `operands` (the
+/// tosa.cond_if inputs) and the arguments are then erased, leaving the new
+/// region argument-free. The tosa.yield terminating each moved block becomes
+/// an scf.yield of the same values.
+///
+/// Only block terminators are rewritten. A recursive walk would also hit the
+/// tosa.yield ops of nested tosa.cond_if / tosa.while_loop regions; e.g. a
+/// nested while's condition yield would become scf.yield instead of
+/// scf.condition. Those yields are left for the nested op's own lowering.
+static void inlineIfCase(Region &srcRegion, Region &dstRegion,
+                         OperandRange operands, PatternRewriter &rewriter) {
+  dstRegion.takeBody(srcRegion);
+  Block *headBlock = &dstRegion.front();
+  for (auto it : llvm::zip(headBlock->getArguments(), operands))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  for (Block &block : dstRegion) {
+    if (block.empty())
+      continue;
+    auto yield = dyn_cast<tosa::YieldOp>(&block.back());
+    if (!yield)
+      continue;
+    rewriter.setInsertionPoint(yield);
+    rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+    rewriter.eraseOp(yield);
+  }
+
+  headBlock->eraseArguments(
+      llvm::to_vector<4>(llvm::seq<unsigned>(0, headBlock->getNumArguments())));
+}
+
+/// Moves a tosa.while_loop region (`srcRegion`) into the matching scf.while
+/// region (`dstRegion`) and rewrites the terminator of each moved block:
+///   - isCond == true (condition region): the first yielded value is read
+///     with tensor.extract and passed to scf.condition, which also forwards
+///     the block's arguments;
+///   - isCond == false (loop body): tosa.yield becomes scf.yield.
+///
+/// Block arguments are kept as-is; `operands` is currently unused.
+///
+/// Only block terminators are rewritten, never tosa.yield ops nested inside
+/// inner TOSA control-flow regions. A recursive walk in the condition region
+/// would otherwise turn a nested op's yield into tensor.extract +
+/// scf.condition.
+static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
+                            OperandRange operands, PatternRewriter &rewriter,
+                            bool isCond) {
+  dstRegion.takeBody(srcRegion);
+
+  for (Block &block : dstRegion) {
+    if (block.empty())
+      continue;
+    auto yield = dyn_cast<tosa::YieldOp>(&block.back());
+    if (!yield)
+      continue;
+
+    rewriter.setInsertionPoint(yield);
+    if (isCond) {
+      auto condition = rewriter.create<tensor::ExtractOp>(
+          yield.getLoc(), yield.getOperand(0));
+      rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
+                                        block.getArguments());
+    } else {
+      rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+    }
+    rewriter.eraseOp(yield);
+  }
+}
+
+namespace {
+
+/// Lowers tosa.cond_if to scf.if. The condition tensor is read with
+/// tensor.extract and both branches are moved into the new op.
+///
+/// The scf.if is fully populated before the tosa op is replaced. A plain
+/// PatternRewriter (e.g. under the greedy driver) erases a replaced op
+/// immediately, so reading its regions or inputs after replaceOp* would touch
+/// freed memory.
+class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
+public:
+  using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::IfOp op,
+                                PatternRewriter &rewriter) const final {
+    auto condition = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.cond());
+    auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
+                                            condition, /*withElseRegion=*/true);
+
+    inlineIfCase(op.then_branch(), newIf.thenRegion(), op.inputs(), rewriter);
+    inlineIfCase(op.else_branch(), newIf.elseRegion(), op.inputs(), rewriter);
+
+    // Replace only after the tosa op's regions and inputs are no longer read.
+    rewriter.replaceOp(op, newIf->getResults());
+    return success();
+  }
+};
+
+/// Lowers tosa.while_loop to scf.while. The condition region becomes the
+/// "before" region and the loop body becomes the "after" region. As with
+/// IfOpConverter, the tosa op is replaced last so its regions are still valid
+/// while being moved.
+class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
+public:
+  using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::WhileOp op,
+                                PatternRewriter &rewriter) const final {
+    auto newWhile = rewriter.create<scf::WhileOp>(
+        op.getLoc(), op.getResultTypes(), op.inputs());
+
+    inlineWhileCase(op.cond(), newWhile.before(), op.inputs(), rewriter, true);
+    inlineWhileCase(op.body(), newWhile.after(), op.inputs(), rewriter, false);
+
+    rewriter.replaceOp(op, newWhile->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the tosa.cond_if and tosa.while_loop lowerings to SCF.
+void mlir::tosa::populateTosaToSCFConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<IfOpConverter>(context);
+  patterns->insert<WhileOpConverter>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
new file mode 100644
index 000000000000..a69f15f57b20
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -0,0 +1,53 @@
+//===- TosaToSCFPass.cpp - Lowering Tosa to SCF Dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Pass lowering tosa.cond_if / tosa.while_loop to SCF. It uses partial
+/// conversion, so TOSA ops without an SCF lowering are left untouched.
+struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
+    target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
+
+    Operation *op = getOperation();
+    mlir::tosa::populateTosaToSCFConversionPatterns(op->getContext(),
+                                                    &patterns);
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToSCF() {
+  return std::make_unique<TosaToSCF>();
+}
+
+/// Runs the TOSA -> SCF conversion on every FuncOp nested in `pm`'s anchor.
+void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToSCF());
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
new file mode 100644
index 000000000000..43032f0f5656
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
@@ -0,0 +1,19 @@
+# TOSA -> Standard conversion: the rewrite patterns (TosaToStandard.cpp) and
+# the pass that applies them (TosaToStandardPass.cpp).
+add_mlir_conversion_library(MLIRTosaToStandard
+ TosaToStandard.cpp
+ TosaToStandardPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+ # Generated pass base classes (TosaToStandardBase) come from this target.
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRStandard
+ MLIRPass
+ MLIRTosa
+ MLIRTosaTransforms
+ MLIRSupport
+ )
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
new file mode 100644
index 000000000000..21a8da291aee
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -0,0 +1,40 @@
+//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+/// Replaces a tosa.const with a Standard-dialect constant that carries the
+/// very same `value` attribute.
+struct ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
+  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConstOp constOp,
+                                PatternRewriter &rewriter) const final {
+    auto valueAttr = constOp.value();
+    rewriter.replaceOpWithNewOp<mlir::ConstantOp>(constOp, valueAttr);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  // tosa.const is currently the only op with a Standard lowering here.
+  patterns->insert<ConstOpConverter>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
new file mode 100644
index 000000000000..cb6ffc6f0441
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -0,0 +1,52 @@
+//===- TosaToStandardPass.cpp - Lowering Tosa to Standard Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Pass that legalizes tosa.const into a Standard constant. Partial
+/// conversion is used, so every other operation is left as-is.
+struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+    target.addLegalOp<ConstantOp>();
+    target.addIllegalOp<tosa::ConstOp>();
+
+    Operation *root = getOperation();
+    OwningRewritePatternList patterns;
+    mlir::tosa::populateTosaToStandardConversionPatterns(root->getContext(),
+                                                         &patterns);
+    if (failed(applyPartialConversion(root, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
+  return std::make_unique<TosaToStandard>();
+}
+
+/// Schedules the TOSA -> Standard conversion on each FuncOp.
+void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(mlir::tosa::createTosaToStandard());
+}
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
new file mode 100644
index 000000000000..82fa2c9f0bb5
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -| FileCheck %s
+
+// tosa.while_loop lowers to scf.while: the condition region's tosa.yield
+// becomes tensor.extract + scf.condition; the body's becomes scf.yield.
+// CHECK-LABEL: func @while_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
+func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
+ // CHECK: [[WHILE:%.+]] = scf.while ([[ARG1:%.+]] = [[ARG0]])
+ %1 = "tosa.while_loop"(%arg0) ( {
+ ^bb0(%arg2: tensor<i32>):
+ // CHECK: "tosa.const"
+ %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+ // CHECK: [[COMPARE:%.+]] = "tosa.greater_equal"
+ %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+
+ // CHECK: [[EX:%.+]] = tensor.extract [[COMPARE]]
+ // CHECK: scf.condition([[EX]]) [[ARG1]]
+ "tosa.yield"(%3) : (tensor<i1>) -> ()
+ }, {
+ // CHECK: ^bb0([[ARG1:%.+]]: tensor<i32>)
+ ^bb0(%arg2: tensor<i32>):
+ // CHECK: tosa.const
+ %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+ // CHECK: [[ADD:%.+]] = "tosa.add"
+ %3 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+
+ // CHECK: scf.yield [[ADD]]
+ "tosa.yield"(%3) : (tensor<i32>) -> ()
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return %1 : tensor<i32>
+}
+
+// -----
+
+// tosa.cond_if lowers to scf.if on the extracted condition; uses of each
+// branch's block arguments are replaced by the matching cond_if inputs.
+// CHECK-LABEL: func @if_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<i1>)
+func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
+ // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
+ // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+ // CHECK: scf.yield [[ARG0]]
+ ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+ "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+
+ // CHECK: } else {
+ }, {
+
+ // CHECK: scf.yield [[ARG1]]
+ ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+ "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+
+ // CHECK: }
+ // CHECK: return [[IF]]
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>)
+
+ return %0 : tensor<f32>
+}
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
new file mode 100644
index 000000000000..86304dcba862
--- /dev/null
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s
+
+// tosa.const is rewritten to a Standard constant with the same dense value.
+// CHECK-LABEL: func @const_test
+func @const_test() -> (tensor<i32>) {
+ // CHECK: [[C3:%.+]] = constant dense<3> : tensor<i32>
+ %0 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+ // CHECK: return [[C3]]
+ return %0 : tensor<i32>
+}
More information about the Mlir-commits
mailing list