[Mlir-commits] [mlir] 1b00b94 - [mlir][tosa] Tosa shape propagation for tosa.cond_if
Rob Suderman
llvmlistbot at llvm.org
Tue Aug 3 17:55:58 PDT 2021
Author: Rob Suderman
Date: 2021-08-03T17:54:54-07:00
New Revision: 1b00b94ffc2d60e07ec8e486dad0fcbcbfb99c62
URL: https://github.com/llvm/llvm-project/commit/1b00b94ffc2d60e07ec8e486dad0fcbcbfb99c62
DIFF: https://github.com/llvm/llvm-project/commit/1b00b94ffc2d60e07ec8e486dad0fcbcbfb99c62.diff
LOG: [mlir][tosa] Tosa shape propagation for tosa.cond_if
We can propagate the shape from tosa.cond_if operands into the true/false
regions, then through the connected blocks. Then, using the tosa.yield ops,
we can determine all possible return types.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D105940
Added:
mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7a29350467814..e2e8eee074d62 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1789,6 +1789,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
// Further described in docs/Rationale/RationaleTOSADialect.md .
//===----------------------------------------------------------------------===//
def Tosa_IfOp : Tosa_Op<"cond_if", [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
SingleBlockImplicitTerminator<"YieldOp">,
RecursiveSideEffects]> {
let summary = "Conditional if operator";
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
new file mode 100644
index 0000000000000..b7f742e4e31c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -0,0 +1,178 @@
+//===-- ShapeUtils.h - TOSA shape support declarations ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Class declarations for shape utilities meant to assist shape propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
+#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace tosa {
+/// Statically known information for a particular Value.
+///
+/// This struct currently tracks only information relevant for tensor/array-like
+/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
+/// type as long as it is in the default "no knowledge" state returned by
+/// `getPessimisticValueState`. The important invariant is that we cannot
+/// claim to know something about a value which is false.
+///
+/// This class could also be called "dataflow facts", "lattice value", etc.
+struct ValueKnowledge {
+  ValueKnowledge() = delete;
+  ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
+      : hasError(false), hasRank(hasRank),
+        sizes(newSizes.begin(), newSizes.end()), dtype(dtype) {}
+
+  // True when no error has been recorded in this knowledge.
+  operator bool() const { return !hasError; }
+
+  // Get the static knowledge intrinsic to `type`. A non-shaped type yields the
+  // pessimistic state (unknown rank, null dtype).
+  static ValueKnowledge getKnowledgeFromType(Type type) {
+    ValueKnowledge result = getPessimisticValueState();
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (shapedType.hasRank()) {
+        result.hasRank = true;
+        llvm::ArrayRef<int64_t> shape = shapedType.getShape();
+        result.sizes.assign(shape.begin(), shape.end());
+      }
+      result.dtype = shapedType.getElementType();
+    }
+    return result;
+  }
+
+  // Return a pessimistic/conservative value state without assuming any
+  // knowledge about the IR: no error, unknown rank, no sizes, null dtype.
+  static ValueKnowledge getPessimisticValueState() {
+    return ValueKnowledge(false, {}, Type());
+  }
+
+  // Materialize the knowledge as a ranked tensor type when the rank is known,
+  // and as an unranked tensor type otherwise.
+  Type getType() const {
+    if (hasRank)
+      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
+    return UnrankedTensorType::get(dtype);
+  }
+
+  // Note: `hasError` does not take part in the comparison.
+  bool operator==(const ValueKnowledge &rhs) const {
+    return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
+  }
+
+  // Given two pieces of static knowledge, calculate conservatively the
+  // information we can be sure about. The result carries an error if either
+  // input has one, the dtypes differ, or two static sizes of the same
+  // dimension conflict. Differing ranks yield an unranked result.
+  static ValueKnowledge join(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    // Mental model: All conditions are checking how to change from the safe "no
+    // knowledge" default-initialized state to a state with more knowledge
+    // consistent with lhs and rhs.
+    ValueKnowledge result = getPessimisticValueState();
+    result.hasError = true;
+
+    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
+      return result;
+
+    result.hasError = false;
+    result.dtype = lhs.dtype;
+
+    if (!lhs.hasRank && !rhs.hasRank)
+      return result;
+
+    if (!rhs.hasRank) {
+      result.hasRank = true;
+      result.sizes = lhs.sizes;
+      return result;
+    }
+
+    if (!lhs.hasRank) {
+      result.hasRank = true;
+      result.sizes = rhs.sizes;
+      return result;
+    }
+
+    if (lhs.sizes.size() != rhs.sizes.size())
+      return result;
+
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
+      int64_t lhsSize = lhs.sizes[i];
+      int64_t rhsSize = rhs.sizes[i];
+      int64_t &resultSize = result.sizes[i];
+      if (lhsSize == ShapedType::kDynamicSize) {
+        resultSize = rhsSize;
+      } else if (rhsSize == ShapedType::kDynamicSize) {
+        resultSize = lhsSize;
+      } else if (lhsSize == rhsSize) {
+        resultSize = lhsSize;
+      } else {
+        // Two different static sizes for the same dimension.
+        result.hasError = true;
+      }
+    }
+
+    return result;
+  }
+
+  // Given two pieces of static knowledge, generate a new ValueKnowledge that
+  // covers both cases. Dimensions whose sizes differ become
+  // `ShapedType::kDynamicSize`; if either rank is unknown or the ranks differ,
+  // the resulting tensor has unknown rank. The result carries an error if
+  // either input has one or the dtypes differ.
+  static ValueKnowledge meet(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    ValueKnowledge result = getPessimisticValueState();
+    result.hasError = true;
+
+    // Both inputs must be error-free (this previously tested `rhs` twice).
+    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
+      return result;
+
+    result.hasError = false;
+    result.dtype = lhs.dtype;
+
+    // `result.hasRank` is already false from the pessimistic state.
+    if (!lhs.hasRank || !rhs.hasRank || lhs.sizes.size() != rhs.sizes.size())
+      return result;
+
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
+      // Only sizes both sides agree on survive; the rest stay dynamic.
+      if (lhs.sizes[i] == rhs.sizes[i])
+        result.sizes[i] = lhs.sizes[i];
+    }
+
+    return result;
+  }
+
+  // Whether the value information has an error.
+  bool hasError;
+  // Whether the value has known rank.
+  bool hasRank;
+  // If `hasRank`, the sizes along each rank. Unknown sizes are represented as
+  // `ShapedType::kDynamicSize`.
+  llvm::SmallVector<int64_t> sizes;
+  // The dtype of a tensor.
+  // This is equal to nullptr if we don't know that it is a specific concrete
+  // type.
+  Type dtype;
+};
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9ae2e95d146ae..d6f9905506ca7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -1301,6 +1302,54 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
return success();
}
+
+/// Infers the result shapes of tosa.cond_if from the tosa.yield terminators of
+/// its regions. The operand types of all yields are met, so only shape
+/// information shared by every yield is kept. Fails if no yield is found or if
+/// the yields disagree on the number of operands.
+LogicalResult IfOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Collect every tosa.yield that terminates a block of either region.
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (Region *region : regions) {
+    for (auto &block : *region)
+      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+        yieldOps.push_back(returnOp);
+  }
+
+  if (yieldOps.empty())
+    return failure();
+
+  // Get the initial type information for the yield op.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      size_t index = it.index();
+      auto meet = ValueKnowledge::meet(
+          resultKnowledge[index],
+          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
+      // When the yields cannot be met (e.g. differing element types), keeping
+      // the previous knowledge would claim a shape that this yield does not
+      // satisfy. Fall back to knowing nothing about this result instead;
+      // later meets against that unranked state stay unranked.
+      resultKnowledge[index] =
+          meet ? meet : ValueKnowledge::getPessimisticValueState();
+    }
+  }
+
+  for (const ValueKnowledge &result : resultKnowledge) {
+    if (result.hasRank) {
+      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
+    } else {
+      inferredReturnShapes.push_back(ShapedTypeComponents());
+    }
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index eca63e1e8ab39..390950e7550de 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -30,137 +31,57 @@ using namespace mlir::tosa;
namespace {
-// -----------------------------------------------------------------------------
-// Analysis.
-// -----------------------------------------------------------------------------
-
-static Type joinElementTypes(Type lhs, Type rhs) {
- return lhs == rhs ? lhs : Type();
-}
-
-namespace {
-// Statically known information for a particular Value.
-//
-// This struct currently tracks only information relevant for tensor/array-like
-// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type as long as it is in the default "no knowledge" state returned by
-// `getPessimisticValueState`. The important invariant is that we cannot
-// claim to know something about a value which is false.
-//
-// This class could also be called "dataflow facts", "lattice value", etc.
-struct ValueKnowledge {
- ValueKnowledge() = delete;
- ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
- : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
- assert(sizes.size() == 0 || hasSizes);
- }
-
- // Get the static knowledge intrinsic to `type`.
- static ValueKnowledge getKnowledgeFromType(Type type) {
- ValueKnowledge result = getPessimisticValueState(type.getContext());
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- if (shapedType.hasRank()) {
- result.hasSizes = true;
- result.sizes = shapedType.getShape();
- }
- result.dtype = shapedType.getElementType();
+void propagateShapesInRegion(Region &region);
+
+/// If `op` is a tosa.cond_if, refines the entry-block argument types of each of
+/// its regions by joining them with the types of the corresponding op operands
+/// (operand 0, the condition, has no block argument), then propagates shapes
+/// through the operations of that region. Any other op is left untouched.
+void propagateShapesToTosaIf(Operation &op) {
+  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+  if (!ifOp)
+    return;
+
+  for (auto &region : op.getRegions()) {
+    Block &frontBlock = region.front();
+    // Each non-condition operand must map onto one block argument; bail out
+    // (skipping any remaining regions) when the counts do not line up.
+    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+      return;
+
+    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+          ifOp.getOperand(i + 1).getType());
+      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+          frontBlock.getArgument(i).getType());
+      ValueKnowledge joinedKnowledge =
+          ValueKnowledge::join(operandKnowledge, blockKnowledge);
+      // Contradictory knowledge: keep the argument's current type.
+      if (!joinedKnowledge)
+        continue;
+      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
     }
- return result;
- }
- // Return a pessimistic/conservative value state without assuming any knowlege
- // about the IR.
- static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
- return ValueKnowledge(false, {}, Type());
+    propagateShapesInRegion(region);
   }
- Type getType() const {
- if (hasSizes) {
- return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
- }
- return UnrankedTensorType::get(dtype);
- }
-
- bool operator==(const ValueKnowledge &rhs) const {
- return std::make_tuple(hasSizes, sizes, dtype) ==
- std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
- }
-
- // Given two pieces of static knowledge, calculate conservatively the
- // information we can be sure about.
- static ValueKnowledge join(const ValueKnowledge &lhs,
- const ValueKnowledge &rhs) {
- // Mental model: All conditions are checking how to change from the safe "no
- // knowledge" default-initialized state to a state with more knowledge
- // consistent with lhs and rhs.
- ValueKnowledge result = getPessimisticValueState(nullptr);
-
- if (lhs.hasSizes && !rhs.hasSizes) {
- result.hasSizes = true;
- result.sizes = lhs.sizes;
- } else if (!lhs.hasSizes && rhs.hasSizes) {
- result.hasSizes = true;
- result.sizes = rhs.sizes;
- } else if (lhs.hasSizes && rhs.hasSizes &&
- lhs.sizes.size() == rhs.sizes.size()) {
- result.hasSizes = true;
- result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
- for (int i = 0, e = result.sizes.size(); i != e; i++) {
- int64_t lhsSize = lhs.sizes[i];
- int64_t rhsSize = rhs.sizes[i];
- int64_t &resultSize = result.sizes[i];
- if (lhsSize == ShapedType::kDynamicSize) {
- resultSize = rhsSize;
- } else if (rhsSize == ShapedType::kDynamicSize) {
- resultSize = lhsSize;
- } else if (lhsSize == rhsSize) {
- resultSize = lhsSize;
- }
- }
- }
-
- result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
- return result;
- }
-
- // Whether the Value is known to have a list of sizes.
- bool hasSizes;
- // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
- // `ShapedType::kDynamicSize`.
- std::vector<int64_t> sizes;
- // The dtype of a tensor.
- // This is equal to nullptr if we don't know that it is a specific concrete
- // type.
- Type dtype;
-};
-
-} // namespace
+}
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
-public:
- void runOnFunction() override {
- FuncOp func = getOperation();
+void propagateShapesInRegion(Region ®ion) {
+ for (auto &block : region) {
+ for (Operation &op : block) {
+ if (op.getDialect()->getNamespace() !=
+ tosa::TosaDialect::getDialectNamespace())
+ continue;
- IRRewriter rewriter(func.getContext());
+ propagateShapesToTosaIf(op);
- func.walk([&](Operation *op) {
- if (op->getDialect()->getNamespace() !=
- tosa::TosaDialect::getDialectNamespace())
- return;
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
if (!shapeInterface)
- return;
+ continue;
SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(
- op->getContext(), op->getLoc(), op->getOperands(),
- op->getAttrDictionary(), op->getRegions(), returnedShapes)
+ op.getContext(), op.getLoc(), op.getOperands(),
+ op.getAttrDictionary(), op.getRegions(), returnedShapes)
.succeeded()) {
- for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
+ for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
ShapedTypeComponents predictedShape = std::get<1>(it);
@@ -183,11 +104,10 @@ struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
ValueKnowledge::getKnowledgeFromType(resultTy);
// Compute the knowledge based on the inferred type.
- auto inferredKnowledge =
- ValueKnowledge::getPessimisticValueState(op->getContext());
+ auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype =
resultTy.cast<ShapedType>().getElementType();
- inferredKnowledge.hasSizes = predictedShape.hasRank();
+ inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
@@ -200,10 +120,25 @@ struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
// Compute the new type based on the joined version.
auto newKnowledge =
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+ if (!newKnowledge)
+ continue;
result.setType(newKnowledge.getType());
}
}
- });
+ }
+ }
+}
+
+/// Pass that performs shape propagation across TOSA operations. This includes
+/// migrating to within the regions of if/while operations.
+struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
+public:
+ void runOnFunction() override {
+ FuncOp func = getOperation();
+
+ IRRewriter rewriter(func.getContext());
+
+ propagateShapesInRegion(func.body());
// Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
// the FuncOp type.
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d42dd7e7f88c3..50189d46a3882 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -774,7 +774,6 @@ func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32
// -----
-
// CHECK-LABEL: @conv2d_strided
func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
// CHECK: -> tensor<1x5x7x1xf32>
@@ -1033,12 +1032,71 @@ func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
%0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
return
}
-
-// -----
-
+// tosa.resize with a non-zero offset_fp and fractional stride_fp; the output
+// size is expected to be inferred statically as 1x4x6x1.
 // CHECK-LABEL: @resize_fp_offsetted
 func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
 // CHECK: -> tensor<1x4x6x1xi32>
 %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
 return
 }
+
+// -----
+
+// Both regions yield tensor<f32>, so the tensor<*xf32> result of
+// tosa.cond_if is expected to be refined to tensor<f32>.
+// CHECK-LABEL: @if_test_simple
+func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+ // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+ "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+ }, {
+ ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+ "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+ return
+}
+
+// -----
+
+// The regions yield tensor<2xf32> and tensor<3xf32>: the ranks agree but the
+// sizes do not, so the result is expected to become tensor<?xf32>.
+// CHECK-LABEL: @if_test_dynamic
+func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+ // CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
+ "tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
+ }, {
+ ^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
+ "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+ }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
+ return
+}
+
+// -----
+
+// The regions yield tensor<f32> (rank 0) and tensor<3xf32> (rank 1); with
+// differing ranks the result is expected to stay unranked.
+// CHECK-LABEL: @if_test_unranked
+func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+ // CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
+ "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+ }, {
+ ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
+ "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
+ return
+}
+
+// -----
+
+// The block arguments start unranked. The operand types (tensor<f32>) are
+// expected to be propagated into them, through tosa.add / tosa.sub, and out
+// through tosa.yield, refining the result to tensor<f32>.
+// CHECK-LABEL: @if_test_propagate
+func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+ // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+ %1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+ }, {
+ ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+ %1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+ "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+ return
+}
More information about the Mlir-commits
mailing list