[Mlir-commits] [mlir] b73e832 - [mlir][tosa] Updates tosa.equal to use the InferTensorType interface

Jacques Pienaar llvmlistbot at llvm.org
Mon Aug 8 16:12:07 PDT 2022


Author: not-jenni
Date: 2022-08-08T16:11:30-07:00
New Revision: b73e8325fb25b27f898190d81a0d62e92c06694c

URL: https://github.com/llvm/llvm-project/commit/b73e8325fb25b27f898190d81a0d62e92c06694c
DIFF: https://github.com/llvm/llvm-project/commit/b73e8325fb25b27f898190d81a0d62e92c06694c.diff

LOG: [mlir][tosa] Updates tosa.equal to use the InferTensorType interface

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D130373

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c88f8729c83da..3f00587b342b3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1139,10 +1139,8 @@ def Tosa_SelectOp : Tosa_Op<"select", [
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+    Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1157,6 +1155,12 @@ def Tosa_EqualOp : Tosa_Op<"equal", [
   let results = (outs
     I1Tensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op; used by
+    /// InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 11ff2a00fcba1..90fe70d1a77dc 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -339,6 +340,44 @@ static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
   }
 }
 
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
+                                           SmallVector<int64_t> &outShape) {
+  int64_t outRank = 0; // Max operand rank; becomes the result rank.
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    if (!shape.hasRank()) {
+      // TODO(jennik): Better handle invalid operands and unranked tensors;
+      // for now, any unranked operand makes shape resolution fail.
+      return failure();
+    }
+    outRank = std::max<int64_t>(outRank, shape.getRank());
+  }
+  // NOTE(review): assumes outShape arrives empty, so it becomes all 1s.
+  outShape.resize(outRank, 1);
+  // Fold each operand in, aligning trailing dimensions via rankDiff.
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    auto rankDiff = outShape.size() - shape.getRank();
+    // NOTE(review): the inner loop's `i` shadows the operand index above.
+    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
+      auto dim1 = outShape[i + rankDiff];
+      auto dim2 = shape.getDimSize(i);
+      auto resolvedDim = dim1;
+      // A size-1 dim takes the other size; otherwise sizes must match.
+      if (dim1 == 1) {
+        resolvedDim = dim2;
+      } else if (dim2 == 1) {
+        resolvedDim = dim1;
+      } else if (dim1 != dim2) {
+        return failure();
+      }
+      outShape[i + rankDiff] = resolvedDim;
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -421,6 +460,27 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(operands, outShape).failed()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success(); // Not an error: the result shape is left unknown.
+  }
+  // Result is an i1 tensor with the broadcast shape of the operands.
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
+  return success();
+}
+// Each side must hold exactly one type; only their shapes are compared.
+bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -870,42 +930,6 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 
-static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
-                                           SmallVector<int64_t> &outShape) {
-  int64_t outRank = 0;
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    if (!shape.hasRank()) {
-      return failure();
-    }
-    outRank = std::max<int64_t>(outRank, shape.getRank());
-  }
-
-  outShape.resize(outRank, 1);
-
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    auto rankDiff = outShape.size() - shape.getRank();
-
-    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
-      auto dim1 = outShape[i + rankDiff];
-      auto dim2 = shape.getDimSize(i);
-      auto resolvedDim = dim1;
-
-      if (dim1 == 1) {
-        resolvedDim = dim2;
-      } else if (dim2 == 1) {
-        resolvedDim = dim1;
-      } else if (dim1 != dim2) {
-        return failure();
-      }
-      outShape[i + rankDiff] = resolvedDim;
-    }
-  }
-
-  return success();
-}
-
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -939,7 +963,6 @@ NARY_SHAPE_INFER(tosa::CeilOp)
 NARY_SHAPE_INFER(tosa::ClampOp)
 NARY_SHAPE_INFER(tosa::ClzOp)
 NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
 NARY_SHAPE_INFER(tosa::ExpOp)
 NARY_SHAPE_INFER(tosa::FloorOp)
 NARY_SHAPE_INFER(tosa::GreaterEqualOp)


        


More information about the Mlir-commits mailing list