[Mlir-commits] [mlir] 1b00b94 - [mlir][tosa] Tosa shape propagation for tosa.cond_if

Rob Suderman llvmlistbot at llvm.org
Tue Aug 3 17:55:58 PDT 2021


Author: Rob Suderman
Date: 2021-08-03T17:54:54-07:00
New Revision: 1b00b94ffc2d60e07ec8e486dad0fcbcbfb99c62

URL: https://github.com/llvm/llvm-project/commit/1b00b94ffc2d60e07ec8e486dad0fcbcbfb99c62
DIFF: https://github.com/llvm/llvm-project/commit/1b00b94ffc2d60e07ec8e486dad0fcbcbfb99c62.diff

LOG: [mlir][tosa] Tosa shape propagation for tosa.cond_if

We can propagate the shapes of the tosa.cond_if operands into the true/false
regions, and then through the connected blocks. Then, using the tosa.yield ops,
we can determine all of the possible return types.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105940

Added: 
    mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7a29350467814..e2e8eee074d62 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1789,6 +1789,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
 // Further described in docs/Rationale/RationaleTOSADialect.md .
 //===----------------------------------------------------------------------===//
 def Tosa_IfOp : Tosa_Op<"cond_if", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveSideEffects]> {
   let summary = "Conditional if operator";

diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
new file mode 100644
index 0000000000000..b7f742e4e31c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -0,0 +1,178 @@
+//===-- ShapeUtils.h - TOSA shape support declarations ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Class declarations for shape utilities meant to assist shape propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
+#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace tosa {
+/// Statically known information for a particular Value.
+///
+/// This struct currently tracks only information relevant for tensor/array-like
+/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
+/// type as long as it is in the default "no knowledge" state returned by
+/// `getPessimisticValueState`. The important invariant is that we cannot
+/// claim to know something about a value which is false.
+///
+/// This class could also be called "dataflow facts", "lattice value", etc.
+struct ValueKnowledge {
+  ValueKnowledge() = delete;
+
+  // Construct error-free knowledge from an explicit rank flag, per-dimension
+  // sizes and element type. `newSizes` is copied and is only meaningful when
+  // `hasRank` is true.
+  ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
+      : hasError(false), hasRank(hasRank),
+        sizes(newSizes.begin(), newSizes.end()), dtype(dtype) {}
+
+  // True when the knowledge is usable, i.e. no conflict was detected while it
+  // was computed.
+  operator bool() const { return !hasError; }
+
+  // Get the static knowledge intrinsic to `type`. A non-shaped type yields the
+  // pessimistic state; an unranked shaped type records only its element type.
+  static ValueKnowledge getKnowledgeFromType(Type type) {
+    ValueKnowledge result = getPessimisticValueState();
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      if (shapedType.hasRank()) {
+        result.hasRank = true;
+        llvm::ArrayRef<int64_t> shape = shapedType.getShape();
+        result.sizes.assign(shape.begin(), shape.end());
+      }
+      result.dtype = shapedType.getElementType();
+    }
+    return result;
+  }
+
+  // Return a pessimistic/conservative value state without assuming any
+  // knowledge about the IR.
+  static ValueKnowledge getPessimisticValueState() {
+    return ValueKnowledge(false, {}, Type());
+  }
+
+  // Build the tensor type described by this knowledge: a RankedTensorType when
+  // the rank is known, an UnrankedTensorType otherwise.
+  Type getType() const {
+    if (hasRank)
+      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
+    return UnrankedTensorType::get(dtype);
+  }
+
+  // Compare rank, sizes and element type. Note that `hasError` is not part of
+  // the comparison.
+  bool operator==(const ValueKnowledge &rhs) const {
+    return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
+  }
+
+  // Given two pieces of static knowledge, calculate conservatively the
+  // information we can be sure about (the more refined of the two).
+  // The result has an error if either input has an error, the element types
+  // differ, or two static dimension sizes conflict.
+  static ValueKnowledge join(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    // Mental model: All conditions are checking how to change from the safe "no
+    // knowledge" default-initialized state to a state with more knowledge
+    // consistent with lhs and rhs.
+    ValueKnowledge result = getPessimisticValueState();
+    result.hasError = true;
+
+    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
+      return result;
+
+    result.hasError = false;
+    result.dtype = lhs.dtype;
+
+    if (!lhs.hasRank && !rhs.hasRank)
+      return result;
+
+    // Only one side knows its rank: adopt that side's sizes.
+    if (!rhs.hasRank) {
+      result.hasRank = true;
+      result.sizes = lhs.sizes;
+      return result;
+    }
+
+    if (!lhs.hasRank) {
+      result.hasRank = true;
+      result.sizes = rhs.sizes;
+      return result;
+    }
+
+    // Both ranked but with different ranks: fall back to unranked.
+    if (lhs.sizes.size() != rhs.sizes.size())
+      return result;
+
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
+      int64_t lhsSize = lhs.sizes[i];
+      int64_t rhsSize = rhs.sizes[i];
+      int64_t &resultSize = result.sizes[i];
+      if (lhsSize == ShapedType::kDynamicSize) {
+        resultSize = rhsSize;
+      } else if (rhsSize == ShapedType::kDynamicSize) {
+        resultSize = lhsSize;
+      } else if (lhsSize == rhsSize) {
+        resultSize = lhsSize;
+      } else {
+        // Two different static sizes cannot both be true.
+        result.hasError = true;
+      }
+    }
+
+    return result;
+  }
+
+  // Given two types, generate a new ValueKnowledge general enough to cover
+  // both cases. E.g. if the ranks of the LHS and RHS differ, the resulting
+  // tensor has unknown rank; dimensions whose sizes differ become dynamic.
+  // The result has an error if either input has an error or the element
+  // types differ.
+  static ValueKnowledge meet(const ValueKnowledge &lhs,
+                             const ValueKnowledge &rhs) {
+    ValueKnowledge result = getPessimisticValueState();
+    result.hasError = true;
+
+    // Both inputs must be error-free and agree on the element type.
+    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
+      return result;
+
+    result.hasError = false;
+    result.dtype = lhs.dtype;
+
+    // Unknown or mismatched rank on either side: the result is unranked.
+    if (!lhs.hasRank || !rhs.hasRank ||
+        lhs.sizes.size() != rhs.sizes.size()) {
+      result.hasRank = false;
+      return result;
+    }
+
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
+    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
+      // Keep a size only when both sides agree; otherwise leave it dynamic.
+      if (lhs.sizes[i] == rhs.sizes[i])
+        result.sizes[i] = lhs.sizes[i];
+    }
+
+    return result;
+  }
+
+  // Whether the value information has an error.
+  bool hasError;
+  // Whether the value has known rank.
+  bool hasRank;
+  // If `hasRank`, the sizes along each rank. Unknown sizes are represented as
+  // `ShapedType::kDynamicSize`.
+  llvm::SmallVector<int64_t> sizes;
+  // The dtype of a tensor.
+  // This is equal to nullptr if we don't know that it is a specific concrete
+  // type.
+  Type dtype;
+};
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9ae2e95d146ae..d6f9905506ca7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1301,6 +1302,54 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   return success();
 }
 
+/// Infer the result shapes of a tosa.cond_if from the operands of every
+/// tosa.yield terminator found in its regions. Each result is the `meet` of
+/// the corresponding yielded types, so it is general enough to describe the
+/// value produced by either branch. Fails if no yield op is found or if the
+/// yield ops disagree on the number of yielded values.
+LogicalResult IfOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (Region *region : regions) {
+    for (auto &block : *region)
+      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+        yieldOps.push_back(returnOp);
+  }
+
+  if (yieldOps.empty())
+    return failure();
+
+  // Seed the result knowledge from the first yield op.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  // Generalize the knowledge so that it covers every yield op.
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      size_t index = it.index();
+      auto meet = ValueKnowledge::meet(
+          resultKnowledge[index],
+          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
+      // An incompatible meet (e.g. differing element types) leaves the
+      // accumulated knowledge for this result unchanged.
+      if (!meet)
+        continue;
+      resultKnowledge[index] = meet;
+    }
+  }
+
+  // Bind by reference: each ValueKnowledge owns a SmallVector of sizes.
+  for (const ValueKnowledge &result : resultKnowledge) {
+    if (result.hasRank) {
+      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
+    } else {
+      inferredReturnShapes.push_back(ShapedTypeComponents());
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index eca63e1e8ab39..390950e7550de 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -30,137 +31,57 @@ using namespace mlir::tosa;
 
 namespace {
 
-// -----------------------------------------------------------------------------
-// Analysis.
-// -----------------------------------------------------------------------------
-
-static Type joinElementTypes(Type lhs, Type rhs) {
-  return lhs == rhs ? lhs : Type();
-}
-
-namespace {
-// Statically known information for a particular Value.
-//
-// This struct currently tracks only information relevant for tensor/array-like
-// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type as long as it is in the default "no knowledge" state returned by
-// `getPessimisticValueState`. The important invariant is that we cannot
-// claim to know something about a value which is false.
-//
-// This class could also be called "dataflow facts", "lattice value", etc.
-struct ValueKnowledge {
-  ValueKnowledge() = delete;
-  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
-      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
-    assert(sizes.size() == 0 || hasSizes);
-  }
-
-  // Get the static knowledge intrinsic to `type`.
-  static ValueKnowledge getKnowledgeFromType(Type type) {
-    ValueKnowledge result = getPessimisticValueState(type.getContext());
-    if (auto shapedType = type.dyn_cast<ShapedType>()) {
-      if (shapedType.hasRank()) {
-        result.hasSizes = true;
-        result.sizes = shapedType.getShape();
-      }
-      result.dtype = shapedType.getElementType();
+void propagateShapesInRegion(Region &region);
+
+// If `op` is a tosa.cond_if, refine the types of each region's entry-block
+// arguments by joining them with the types of the matching cond_if operands,
+// then propagate shapes through that region. Operand 0 is skipped: it has no
+// matching block argument. Processing stops (via `return`) at the first region
+// whose entry block does not have exactly one argument fewer than the op has
+// operands. Ops that are not tosa.cond_if are ignored.
+void propagateShapesToTosaIf(Operation &op) {
+  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+  if (!ifOp)
+    return;
+
+  for (auto &region : op.getRegions()) {
+    // NOTE(review): region.front() assumes the region is non-empty.
+    Block &frontBlock = region.front();
+    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+      return;
+
+    // Block argument i corresponds to cond_if operand i + 1.
+    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+          ifOp.getOperand(i + 1).getType());
+      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+          frontBlock.getArgument(i).getType());
+      ValueKnowledge joinedKnowledge =
+          ValueKnowledge::join(operandKnowledge, blockKnowledge);
+      // A failed join (e.g. differing element types or conflicting static
+      // sizes) leaves the block argument's type unchanged.
+      if (!joinedKnowledge)
+        continue;
+      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
     }
-    return result;
-  }
 
-  // Return a pessimistic/conservative value state without assuming any knowlege
-  // about the IR.
-  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
-    return ValueKnowledge(false, {}, Type());
+    // Push the refined argument types through the ops inside the region.
+    propagateShapesInRegion(region);
   }
 
-  Type getType() const {
-    if (hasSizes) {
-      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
-    }
-    return UnrankedTensorType::get(dtype);
-  }
-
-  bool operator==(const ValueKnowledge &rhs) const {
-    return std::make_tuple(hasSizes, sizes, dtype) ==
-           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
-  }
-
-  // Given two pieces of static knowledge, calculate conservatively the
-  // information we can be sure about.
-  static ValueKnowledge join(const ValueKnowledge &lhs,
-                             const ValueKnowledge &rhs) {
-    // Mental model: All conditions are checking how to change from the safe "no
-    // knowledge" default-initialized state to a state with more knowledge
-    // consistent with lhs and rhs.
-    ValueKnowledge result = getPessimisticValueState(nullptr);
-
-    if (lhs.hasSizes && !rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = lhs.sizes;
-    } else if (!lhs.hasSizes && rhs.hasSizes) {
-      result.hasSizes = true;
-      result.sizes = rhs.sizes;
-    } else if (lhs.hasSizes && rhs.hasSizes &&
-               lhs.sizes.size() == rhs.sizes.size()) {
-      result.hasSizes = true;
-      result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
-      for (int i = 0, e = result.sizes.size(); i != e; i++) {
-        int64_t lhsSize = lhs.sizes[i];
-        int64_t rhsSize = rhs.sizes[i];
-        int64_t &resultSize = result.sizes[i];
-        if (lhsSize == ShapedType::kDynamicSize) {
-          resultSize = rhsSize;
-        } else if (rhsSize == ShapedType::kDynamicSize) {
-          resultSize = lhsSize;
-        } else if (lhsSize == rhsSize) {
-          resultSize = lhsSize;
-        }
-      }
-    }
-
-    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
-    return result;
-  }
-
-  // Whether the Value is known to have a list of sizes.
-  bool hasSizes;
-  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
-  // `ShapedType::kDynamicSize`.
-  std::vector<int64_t> sizes;
-  // The dtype of a tensor.
-  // This is equal to nullptr if we don't know that it is a specific concrete
-  // type.
-  Type dtype;
-};
-
-} // namespace
+  return;
+}
 
-/// Pass that enables broadcast by making all input arrays have the same
-/// number of dimensions. Insert RESHAPE operations to lower rank operand
-struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
-public:
-  void runOnFunction() override {
-    FuncOp func = getOperation();
+void propagateShapesInRegion(Region &region) {
+  for (auto &block : region) {
+    for (Operation &op : block) {
+      if (op.getDialect()->getNamespace() !=
+          tosa::TosaDialect::getDialectNamespace())
+        continue;
 
-    IRRewriter rewriter(func.getContext());
+      propagateShapesToTosaIf(op);
 
-    func.walk([&](Operation *op) {
-      if (op->getDialect()->getNamespace() !=
-          tosa::TosaDialect::getDialectNamespace())
-        return;
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
       if (!shapeInterface)
-        return;
+        continue;
 
       SmallVector<ShapedTypeComponents> returnedShapes;
       if (shapeInterface
               .inferReturnTypeComponents(
-                  op->getContext(), op->getLoc(), op->getOperands(),
-                  op->getAttrDictionary(), op->getRegions(), returnedShapes)
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getAttrDictionary(), op.getRegions(), returnedShapes)
               .succeeded()) {
-        for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
+        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
@@ -183,11 +104,10 @@ struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
               ValueKnowledge::getKnowledgeFromType(resultTy);
 
           // Compute the knowledge based on the inferred type.
-          auto inferredKnowledge =
-              ValueKnowledge::getPessimisticValueState(op->getContext());
+          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
           inferredKnowledge.dtype =
               resultTy.cast<ShapedType>().getElementType();
-          inferredKnowledge.hasSizes = predictedShape.hasRank();
+          inferredKnowledge.hasRank = predictedShape.hasRank();
           if (predictedShape.hasRank()) {
             for (auto dim : predictedShape.getDims()) {
               inferredKnowledge.sizes.push_back(dim);
@@ -200,10 +120,25 @@ struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
           // Compute the new type based on the joined version.
           auto newKnowledge =
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+          if (!newKnowledge)
+            continue;
           result.setType(newKnowledge.getType());
         }
       }
-    });
+    }
+  }
+}
+
+/// Pass that performs shape propagation across TOSA operations. This includes
+/// migrating to within the regions of if/while operations.
+struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getOperation();
+
+    IRRewriter rewriter(func.getContext());
+
+    propagateShapesInRegion(func.body());
 
     // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
     // the FuncOp type.

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d42dd7e7f88c3..50189d46a3882 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -774,7 +774,6 @@ func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32
 
 // -----
 
-
 // CHECK-LABEL: @conv2d_strided
 func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
   // CHECK: -> tensor<1x5x7x1xf32>
@@ -1033,12 +1032,71 @@ func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
   %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
   return
 }
-
-// -----
-
 // CHECK-LABEL: @resize_fp_offsetted
 func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
   // CHECK: -> tensor<1x4x6x1xi32>
   %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
   return
 }
+
+// -----
+
+// Both branches yield a tensor<f32>, so the unranked cond_if result is
+// expected to be refined to tensor<f32>.
+// CHECK-LABEL: @if_test_simple
+func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// The branches yield rank-1 tensors of different static sizes (2 and 3), so the
+// result keeps rank 1 but its dimension becomes dynamic: tensor<?xf32>.
+// CHECK-LABEL: @if_test_dynamic
+func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
+    "tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
+    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// The branches yield tensors of different rank (0 and 1), so no rank can be
+// inferred and the result stays tensor<*xf32>.
+// CHECK-LABEL: @if_test_unranked
+func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
+    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+// The region arguments start out unranked. The pass is expected to refine them
+// from the tensor<f32> cond_if operands, propagate that through the add/sub
+// ops, and so infer tensor<f32> for the cond_if result.
+// CHECK-LABEL: @if_test_propagate
+func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+    %1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+  }, {
+  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+    %1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list