[Mlir-commits] [mlir] b73e832 - [mlir][tosa] Updates tosa.equal to use the InferTensorType interface
Jacques Pienaar
llvmlistbot at llvm.org
Mon Aug 8 16:12:07 PDT 2022
Author: not-jenni
Date: 2022-08-08T16:11:30-07:00
New Revision: b73e8325fb25b27f898190d81a0d62e92c06694c
URL: https://github.com/llvm/llvm-project/commit/b73e8325fb25b27f898190d81a0d62e92c06694c
DIFF: https://github.com/llvm/llvm-project/commit/b73e8325fb25b27f898190d81a0d62e92c06694c.diff
LOG: [mlir][tosa] Updates tosa.equal to use the InferTensorType interface
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D130373
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c88f8729c83da..3f00587b342b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1139,10 +1139,8 @@ def Tosa_SelectOp : Tosa_Op<"select", [
//===----------------------------------------------------------------------===//
// Operator: equal
//===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+ Commutative, NoSideEffect]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
@@ -1157,6 +1155,12 @@ def Tosa_EqualOp : Tosa_Op<"equal", [
let results = (outs
I1Tensor:$output
);
+
+ let extraClassDeclaration = [{
+ /// Returns whether two result types are compatible for this op; method used
+ /// by InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 11ff2a00fcba1..90fe70d1a77dc 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -339,6 +340,44 @@ static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
}
}
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
+ SmallVector<int64_t> &outShape) {
+ int64_t outRank = 0;
+ for (int i = 0, e = operands.size(); i != e; ++i) {
+ auto shape = operands.getShape(i);
+ if (!shape.hasRank()) {
+ // TODO(jennik): Update function to have better case handling for invalid
+ // operands and for ranked tensors.
+ return failure();
+ }
+ outRank = std::max<int64_t>(outRank, shape.getRank());
+ }
+
+ outShape.resize(outRank, 1);
+
+ for (int i = 0, e = operands.size(); i != e; ++i) {
+ auto shape = operands.getShape(i);
+ auto rankDiff = outShape.size() - shape.getRank();
+
+ for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
+ auto dim1 = outShape[i + rankDiff];
+ auto dim2 = shape.getDimSize(i);
+ auto resolvedDim = dim1;
+
+ if (dim1 == 1) {
+ resolvedDim = dim2;
+ } else if (dim2 == 1) {
+ resolvedDim = dim1;
+ } else if (dim1 != dim2) {
+ return failure();
+ }
+ outShape[i + rankDiff] = resolvedDim;
+ }
+ }
+
+ return success();
+}
+
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -421,6 +460,27 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+ MLIRContext *context, ::llvm::Optional<Location> location,
+ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ llvm::SmallVector<int64_t> outShape;
+ if (resolveBroadcastShape(operands, outShape).failed()) {
+ inferredReturnShapes.push_back(ShapedTypeComponents());
+ return success();
+ }
+
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
+ return success();
+}
+
+bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != r.size() || l.size() != 1)
+ return false;
+ return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -870,42 +930,6 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
-static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
- SmallVector<int64_t> &outShape) {
- int64_t outRank = 0;
- for (int i = 0, e = operands.size(); i != e; ++i) {
- auto shape = operands.getShape(i);
- if (!shape.hasRank()) {
- return failure();
- }
- outRank = std::max<int64_t>(outRank, shape.getRank());
- }
-
- outShape.resize(outRank, 1);
-
- for (int i = 0, e = operands.size(); i != e; ++i) {
- auto shape = operands.getShape(i);
- auto rankDiff = outShape.size() - shape.getRank();
-
- for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
- auto dim1 = outShape[i + rankDiff];
- auto dim2 = shape.getDimSize(i);
- auto resolvedDim = dim1;
-
- if (dim1 == 1) {
- resolvedDim = dim2;
- } else if (dim2 == 1) {
- resolvedDim = dim1;
- } else if (dim1 != dim2) {
- return failure();
- }
- outShape[i + rankDiff] = resolvedDim;
- }
- }
-
- return success();
-}
-
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -939,7 +963,6 @@ NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
More information about the Mlir-commits mailing list