[Mlir-commits] [mlir] 28fe1a4 - [mlir] Add trait SameOperandsAndResultRank
Eric Kunze
llvmlistbot at llvm.org
Thu Sep 7 10:01:53 PDT 2023
Author: Tai Ly
Date: 2023-09-07T16:55:54Z
New Revision: 28fe1a4e5e8af39a6a0fa253b3538cb0905069dc
URL: https://github.com/llvm/llvm-project/commit/28fe1a4e5e8af39a6a0fa253b3538cb0905069dc
DIFF: https://github.com/llvm/llvm-project/commit/28fe1a4e5e8af39a6a0fa253b3538cb0905069dc.diff
LOG: [mlir] Add trait SameOperandsAndResultRank
This adds a native op trait SameOperandsAndResultRank
and an associated verifier that checks that an operation's
operands and result types have the same rank, wherever their
ranks are known.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I2d536f77be10f3710d0c8d84c907ff492a984fda
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D156369
Added:
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/lib/IR/Operation.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/types.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 67d923adbf9374b..84ba46f4d6f3ec1 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -341,6 +341,7 @@ LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
+LogicalResult verifySameOperandsAndResultRank(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
@@ -1114,6 +1115,17 @@ class SameOperandsAndResultType
}
};
+/// Native op trait: verifies that the op has at least one operand and that
+/// every operand and result type that is a ranked shaped type has one rank.
+template <typename ConcreteType>
+class SameOperandsAndResultRank
+    : public TraitBase<ConcreteType, SameOperandsAndResultRank> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultRank(op);
+  }
+};
+
/// This class verifies that any results of the specified op have a boolean
/// type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1ceaf25a994e01a..54c1c13fd029dbc 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -369,4 +369,7 @@ def ReifyRankedShapedTypeOpInterface :
// TODO: Change from hard coded to utilizing type inference trait.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
+// Op has the same rank for all operand and result types whose rank is known.
+def SameOperandsAndResultRank : NativeOpTrait<"SameOperandsAndResultRank">;
+
#endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 888c146f49539f6..ef98a89f4bb49b6 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1082,6 +1082,51 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
return success();
}
+// Verifies that every operand and result whose type is a ShapedType with a
+// known rank has the same rank; non-shaped and unranked types are skipped.
+// The op must also have at least one operand.
+LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
+  if (failed(verifyAtLeastNOperands(op, 1)))
+    return failure();
+
+  // Returns true if `type` is a shaped type with a known rank.
+  auto hasRank = [](Type type) {
+    auto shapedType = dyn_cast<ShapedType>(type);
+    return shapedType && shapedType.hasRank();
+  };
+
+  // Returns the rank of a type that satisfies `hasRank`.
+  auto getRank = [](Type type) { return cast<ShapedType>(type).getRank(); };
+
+  auto rankedOperandTypes =
+      llvm::make_filter_range(op->getOperandTypes(), hasRank);
+  auto rankedResultTypes =
+      llvm::make_filter_range(op->getResultTypes(), hasRank);
+
+  // If all operands and results are unranked, there is nothing to verify.
+  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+    return success();
+
+  // The reference rank comes from the first ranked operand, or from the first
+  // ranked result when no operand is ranked.
+  int64_t rank = !rankedOperandTypes.empty()
+                     ? getRank(*rankedOperandTypes.begin())
+                     : getRank(*rankedResultTypes.begin());
+
+  // Every ranked operand and result must match the reference rank.
+  for (Type type : rankedOperandTypes) {
+    if (getRank(type) != rank)
+      return op->emitOpError("operands don't have matching ranks");
+  }
+
+  for (Type type : rankedResultTypes) {
+    if (getRank(type) != rank)
+      return op->emitOpError("result type has different rank than operands");
+  }
+
+  return success();
+}
+
LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
Block *block = op->getBlock();
// Verify that the operation is at the end of the respective parent block.
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2d7f5b0043ba0f6..9ceadab8fa4a086 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -692,6 +692,12 @@ def OperandZeroAndResultHaveSameRank :
let results = (outs AnyShaped:$res);
}
+def OperandsAndResultHaveSameRank :
+    TEST_Op<"operands_and_result_have_same_rank", [SameOperandsAndResultRank]> {
+  let results = (outs AnyShaped:$res);
+  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+}
+
def OperandZeroAndResultHaveSameShape :
TEST_Op<"operand0_and_result_have_same_shape",
[AllShapesMatch<["x", "res"]>]> {
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index d36a390dff8f185..7652a87037c9295 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -377,6 +377,33 @@ func.func @same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
// -----
+// CHECK-LABEL: same_rank_if_known_success
+func.func @same_rank_if_known_success(%t1xi : tensor<1xi32>, %t2xf : tensor<2xf32>, %m3xi : memref<3xi32>, %t1x2xf : tensor<1x2xf32>, %tuxi : tensor<*xi32>) {
+  %0 = "test.operands_and_result_have_same_rank"(%t1xi, %t2xf) : (tensor<1xi32>, tensor<2xf32>) -> (tensor<3xf64>)
+  %1 = "test.operands_and_result_have_same_rank"(%t1xi, %m3xi) : (tensor<1xi32>, memref<3xi32>) -> (tensor<3xi64>)
+  %3 = "test.operands_and_result_have_same_rank"(%tuxi, %t2xf) : (tensor<*xi32>, tensor<2xf32>) -> (tensor<2xf32>)
+  %4 = "test.operands_and_result_have_same_rank"(%t1x2xf, %tuxi) : (tensor<1x2xf32>, tensor<*xi32>) -> (tensor<1x2xf64>)
+  return
+}
+
+// -----
+
+func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
+  // expected-error@+1 {{operands don't have matching ranks}}
+  %0 = "test.operands_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
+  // expected-error@+1 {{result type has different rank than operands}}
+  %0 = "test.operands_and_result_have_same_rank"(%arg1, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2x3xf32>)
+  return
+}
+
+// -----
+
// CHECK-LABEL: same_shape_success
func.func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)
More information about the Mlir-commits
mailing list