[Mlir-commits] [mlir] 16abaca - [MLIR][TOSA] Resubmit Tosa to Standard/SCF Lowerings (const, if, while)"

Rob Suderman llvmlistbot at llvm.org
Fri Feb 26 17:50:01 PST 2021


Author: Rob Suderman
Date: 2021-02-26T17:44:12-08:00
New Revision: 16abacaea9db653b41808fc37277b68168438059

URL: https://github.com/llvm/llvm-project/commit/16abacaea9db653b41808fc37277b68168438059
DIFF: https://github.com/llvm/llvm-project/commit/16abacaea9db653b41808fc37277b68168438059.diff

LOG: [MLIR][TOSA] Resubmit Tosa to Standard/SCF Lowerings (const, if, while)"

Includes a lowering for tosa.const, tosa.if, and tosa.while to Standard/SCF dialects. TosaToStandard is
used for constant lowerings and TosaToSCF handles the if/while ops.

Resubmission of https://reviews.llvm.org/D97518 with ASAN fixes.

Differential Revision: https://reviews.llvm.org/D97529

Added: 
    mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
    mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
    mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
    mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
    mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
    mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
    mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
    mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
    mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
    mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 121dae6f46f8..21e604eabecd 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -31,6 +31,8 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index aa228784e48a..f37283868a8e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -440,6 +440,36 @@ def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
   let constructor = "tosa::createTosaToLinalgOnTensors()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToSCF
+//===----------------------------------------------------------------------===//
+
+def TosaToSCF : Pass<"tosa-to-scf"> {
+  let summary = "Lower TOSA to the SCF dialect";
+  // One list entry per dialect whose ops the patterns create: tensor.extract
+  // for reading i1 conditions, and scf.if / scf.while / scf.yield.
+  let dependentDialects = ["tensor::TensorDialect", "scf::SCFDialect"];
+  let description = [{
+    Pass that converts TOSA's control flow operations to the equivalent SCF
+    operations.
+  }];
+
+  let constructor = "tosa::createTosaToSCF()";
+}
+
+//===----------------------------------------------------------------------===//
+// TosaToStandard
+//===----------------------------------------------------------------------===//
+
+def TosaToStandard : Pass<"tosa-to-standard"> {
+  let summary = "Lower TOSA to the Standard dialect";
+  // tosa.const is rewritten into the Standard dialect's constant op, so the
+  // lowering creates ops from this dialect.
+  let dependentDialects = ["StandardOpsDialect"];
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations using the
+    operations in the Standard dialect.
+  }];
+
+  let constructor = "tosa::createTosaToStandard()";
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
new file mode 100644
index 000000000000..68ed0e0b6525
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
@@ -0,0 +1,32 @@
+//===-- TosaToSCF.h - TOSA to SCF dialect lowerings -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to SCF Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+#define MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+/// Creates a pass that rewrites TOSA control flow (tosa.cond_if and
+/// tosa.while_loop) into scf.if / scf.while via a partial conversion.
+std::unique_ptr<Pass> createTosaToSCF();
+
+/// Adds the TOSA-to-SCF control-flow patterns (IfOp and WhileOp lowerings) to
+/// `patterns`. The generated code reads tensor-held conditions with
+/// tensor.extract.
+void populateTosaToSCFConversionPatterns(MLIRContext *context,
+                                         OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to SCF: adds the TosaToSCF pass to
+/// `pm`, nested on FuncOp.
+void addTosaToSCFPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H

diff  --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
new file mode 100644
index 000000000000..82555003661e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -0,0 +1,32 @@
+//===-- TosaToStandard.h - TOSA to Standard dialect lowerings ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Standard Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+/// Creates a pass that rewrites tosa.const into the Standard dialect's
+/// constant op via a partial conversion.
+std::unique_ptr<Pass> createTosaToStandard();
+
+/// Adds the TOSA-to-Standard rewrite patterns (currently only the tosa.const
+/// lowering) to `patterns`.
+void populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to Standard: adds the TosaToStandard
+/// pass to `pm`, nested on FuncOp.
+void addTosaToStandardPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6ba8d415e30b..2f8008489df5 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -22,6 +22,8 @@ add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
 add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToSCF)
+add_subdirectory(TosaToStandard)
 add_subdirectory(ArmSVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index c0e1791dc59b..7c1db73d486a 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -59,6 +59,14 @@ namespace spirv {
 class SPIRVDialect;
 } // end namespace spirv
 
+// Forward declaration for tensor::TensorDialect, which the TosaToSCF pass
+// definition lists in its dependentDialects.
+namespace tensor {
+class TensorDialect;
+} // end namespace tensor
+
+// NOTE(review): no dependentDialects entry visible in this change names
+// tosa::TosaDialect; presumably another pass definition needs it — confirm.
+namespace tosa {
+class TosaDialect;
+} // end namespace tosa
+
 namespace vector {
 class VectorDialect;
 } // end namespace vector

diff  --git a/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
new file mode 100644
index 000000000000..189c25c2d89c
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Conversion library lowering TOSA control flow (tosa.cond_if,
+# tosa.while_loop) to SCF. DEPENDS on MLIRConversionPassIncGen for the
+# generated TosaToSCFBase pass class; MLIRTensor is linked because the
+# patterns emit tensor.extract.
+# NOTE(review): neither TosaToSCF source includes a Standard dialect header;
+# MLIRStandard may be unnecessary here — confirm.
+add_mlir_conversion_library(MLIRTosaToSCF
+  TosaToSCF.cpp
+  TosaToSCFPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSCF
+  MLIRStandard
+  MLIRPass
+  MLIRTensor
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )

diff  --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
new file mode 100644
index 000000000000..55ed64b10322
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -0,0 +1,109 @@
+//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+/// Moves a copy of one tosa.cond_if branch (`srcRegion`) into the matching
+/// scf.if region (`dstRegion`) and adapts it to scf.if form:
+///  - the block that was already in `dstRegion` (assumed to be the single
+///    block created by scf::IfOp's builder — confirm) is erased;
+///  - uses of the branch's block arguments are replaced with `operands`
+///    (the cond_if inputs), after which the arguments are erased;
+///  - the entry block's tosa.yield terminator becomes an scf.yield.
+/// Assumes the branch's entry block ends in tosa.yield (`cast` asserts).
+static void inlineIfCase(Region &srcRegion, Region &dstRegion,
+                         OperandRange operands, PatternRewriter &rewriter) {
+  // Clone the TOSA branch in front of the existing block, then drop that
+  // pre-existing (last) block.
+  rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
+  rewriter.eraseBlock(&dstRegion.back());
+
+  // The branch's block arguments stand for the cond_if inputs; substitute the
+  // actual operand values so the arguments become unused.
+  Block *headBlock = &dstRegion.front();
+  for (auto it : llvm::zip(headBlock->getArguments(), operands))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  // Swap the tosa.yield terminator for an scf.yield of the same values.
+  auto yield = cast<YieldOp>(headBlock->getTerminator());
+  rewriter.setInsertionPoint(yield);
+  rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+  rewriter.eraseOp(yield);
+
+  // Every argument's uses were replaced above, so erase all of them.
+  headBlock->eraseArguments(
+      llvm::to_vector<4>(llvm::seq<unsigned>(0, headBlock->getNumArguments())));
+}
+
+/// Replaces the placeholder block in `dstRegion` (the block WhileOpConverter
+/// creates with createBlock) with a copy of the tosa.while_loop region
+/// `srcRegion`, then rewrites the copied entry block's tosa.yield:
+///  - `isCond` (cond region): the yielded condition tensor is read with
+///    tensor.extract and becomes the scf.condition predicate; the block's
+///    arguments are forwarded unchanged to the loop body.
+///  - otherwise (body): tosa.yield becomes scf.yield with the same values.
+/// The copied block keeps its arguments.
+static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
+                            PatternRewriter &rewriter, bool isCond) {
+  // Clone in front of the placeholder block, then drop the placeholder.
+  rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
+  rewriter.eraseBlock(&dstRegion.back());
+
+  Block *headBlock = &dstRegion.front();
+
+  auto yield = cast<YieldOp>(headBlock->getTerminator());
+  Location loc = yield.getLoc();
+  // The replacement terminator goes right before the tosa.yield; this single
+  // insertion point serves both branches below.
+  rewriter.setInsertionPoint(yield);
+  if (isCond) {
+    auto condition =
+        rewriter.create<tensor::ExtractOp>(loc, yield.getOperand(0));
+    rewriter.create<scf::ConditionOp>(loc, condition,
+                                      headBlock->getArguments());
+  } else {
+    rewriter.create<scf::YieldOp>(loc, yield.inputs());
+  }
+  rewriter.eraseOp(yield);
+}
+
+namespace {
+
+/// Lowers tosa.cond_if to scf.if. The predicate is read out of the condition
+/// tensor with tensor.extract, and each TOSA branch is moved into the matching
+/// scf.if region by inlineIfCase.
+struct IfOpConverter : public OpRewritePattern<tosa::IfOp> {
+  using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::IfOp ifOp,
+                                PatternRewriter &rewriter) const final {
+    Location loc = ifOp.getLoc();
+    auto predicate = rewriter.create<tensor::ExtractOp>(loc, ifOp.cond());
+    // Request an else region so both TOSA branches have a destination.
+    auto scfIf = rewriter.create<scf::IfOp>(loc, ifOp.getResultTypes(),
+                                            predicate, /*withElseRegion=*/true);
+
+    inlineIfCase(ifOp.then_branch(), scfIf.thenRegion(), ifOp.inputs(),
+                 rewriter);
+    inlineIfCase(ifOp.else_branch(), scfIf.elseRegion(), ifOp.inputs(),
+                 rewriter);
+
+    rewriter.replaceOp(ifOp, scfIf.getResults());
+    return success();
+  }
+};
+
+/// Lowers tosa.while_loop to scf.while. An empty placeholder block is created
+/// in each of the before/after regions; inlineWhileCase then swaps it for a
+/// copy of the TOSA cond/body region.
+struct WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
+  using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::WhileOp whileOp,
+                                PatternRewriter &rewriter) const final {
+    auto scfWhile = rewriter.create<scf::WhileOp>(
+        whileOp.getLoc(), whileOp.getResultTypes(), whileOp.inputs());
+    Region &beforeRegion = scfWhile.before();
+    Region &afterRegion = scfWhile.after();
+    rewriter.createBlock(&beforeRegion);
+    rewriter.createBlock(&afterRegion);
+
+    inlineWhileCase(whileOp.cond(), beforeRegion, rewriter, /*isCond=*/true);
+    inlineWhileCase(whileOp.body(), afterRegion, rewriter, /*isCond=*/false);
+
+    rewriter.replaceOp(whileOp, scfWhile.getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the TOSA control-flow lowerings (cond_if first, then while_loop).
+void mlir::tosa::populateTosaToSCFConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<IfOpConverter, WhileOpConverter>(context);
+}

diff  --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
new file mode 100644
index 000000000000..f403a4658b97
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -0,0 +1,53 @@
+//===- TosaToSCFPass.cpp - Lowering Tosa to SCF Dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Applies the TOSA-to-SCF patterns as a partial conversion. Only
+/// tosa.cond_if and tosa.while_loop are illegal; tensor and SCF ops are legal,
+/// and every other op (including the rest of TOSA) is accepted as legal.
+struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    Operation *root = getOperation();
+
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
+    target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    OwningRewritePatternList patterns;
+    mlir::tosa::populateTosaToSCFConversionPatterns(ctx, &patterns);
+
+    if (failed(applyPartialConversion(root, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates a new instance of the TosaToSCF conversion pass.
+std::unique_ptr<Pass> mlir::tosa::createTosaToSCF() {
+  return std::make_unique<TosaToSCF>();
+}
+
+/// Adds the TosaToSCF pass to `pm`, nested so that it runs on FuncOp ops
+/// (via addNestedPass<FuncOp>).
+void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToSCF());
+}

diff  --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
new file mode 100644
index 000000000000..43032f0f5656
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Conversion library lowering tosa.const to the Standard dialect's constant
+# op. DEPENDS on MLIRConversionPassIncGen for the generated TosaToStandardBase
+# pass class.
+# NOTE(review): TosaToStandardPass.cpp includes mlir/Dialect/Tensor/IR/Tensor.h
+# but MLIRTensor is not linked here; the include looks unused — confirm.
+add_mlir_conversion_library(MLIRTosaToStandard
+  TosaToStandard.cpp
+  TosaToStandardPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRStandard
+  MLIRPass
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
new file mode 100644
index 000000000000..21a8da291aee
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -0,0 +1,40 @@
+//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+/// Replaces tosa.const with the Standard dialect's constant op built from the
+/// same `value` attribute.
+struct ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
+  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConstOp constOp,
+                                PatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<mlir::ConstantOp>(constOp, constOp.value());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the TOSA-to-Standard patterns; currently only the tosa.const
+/// lowering (ConstOpConverter).
+void mlir::tosa::populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<ConstOpConverter>(context);
+}

diff  --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
new file mode 100644
index 000000000000..225855e78bda
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -0,0 +1,52 @@
+//===- TosaToStandardPass.cpp - Lowering Tosa to Standard Dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+/// Applies the TOSA-to-Standard patterns as a partial conversion: tosa.const
+/// is marked illegal and the Standard constant op legal.
+struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    Operation *root = getOperation();
+
+    ConversionTarget target(*ctx);
+    target.addIllegalOp<tosa::ConstOp>();
+    target.addLegalOp<ConstantOp>();
+
+    OwningRewritePatternList patterns;
+    mlir::tosa::populateTosaToStandardConversionPatterns(ctx, &patterns);
+
+    if (failed(applyPartialConversion(root, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates a new instance of the TosaToStandard conversion pass.
+std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
+  return std::make_unique<TosaToStandard>();
+}
+
+/// Adds the TosaToStandard pass to `pm`, nested so that it runs on FuncOp ops
+/// (via addNestedPass<FuncOp>).
+void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToStandard());
+}

diff  --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
new file mode 100644
index 000000000000..82fa2c9f0bb5
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -| FileCheck %s
+
+// tosa.while_loop lowers to scf.while: the cond region's tosa.yield becomes a
+// tensor.extract feeding scf.condition, and the body's becomes scf.yield.
+// CHECK-LABEL: func @while_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
+func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
+  // CHECK: [[WHILE:%.+]] = scf.while ([[ARG1:%.+]] = [[ARG0]])
+  %1 = "tosa.while_loop"(%arg0) ( {
+  ^bb0(%arg2: tensor<i32>):
+    // CHECK: "tosa.const"
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: [[COMPARE:%.+]] = "tosa.greater_equal"
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+
+    // CHECK: [[EX:%.+]] = tensor.extract [[COMPARE]]
+    // CHECK: scf.condition([[EX]]) [[ARG1]]
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  // CHECK: ^bb0([[ARG1:%.+]]: tensor<i32>)
+  ^bb0(%arg2: tensor<i32>):
+    // CHECK: tosa.const
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: [[ADD:%.+]] = "tosa.add"
+    %3 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+
+    // CHECK: scf.yield [[ADD]]
+    "tosa.yield"(%3) : (tensor<i32>) -> ()
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return %1 : tensor<i32>
+}
+
+// -----
+
+// tosa.cond_if lowers to scf.if on the extracted predicate. Each branch's
+// block arguments are replaced by the cond_if operands, so every scf.yield
+// refers to the function arguments directly.
+// CHECK-LABEL: func @if_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<i1>)
+func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
+  // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
+  // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+  // CHECK:   scf.yield [[ARG0]]
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+
+  // CHECK: } else {
+  }, {
+
+  // CHECK:   scf.yield [[ARG1]]
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+
+  // CHECK: }
+  // CHECK: return [[IF]]
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>)
+
+  return %0 : tensor<f32>
+}

diff  --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
new file mode 100644
index 000000000000..86304dcba862
--- /dev/null
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s
+
+// tosa.const is replaced by a Standard-dialect constant carrying the same
+// dense value attribute.
+// CHECK-LABEL: func @const_test
+func @const_test() -> (tensor<i32>) {
+  // CHECK: [[C3:%.+]] = constant dense<3> : tensor<i32>
+  %0 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+  // CHECK: return [[C3]]
+  return %0 : tensor<i32>
+}


        


More information about the Mlir-commits mailing list