[Mlir-commits] [mlir] 2e7baf6 - [MLIR][Shape] Allow `shape.add` to operate on indices
Frederik Gossen
llvmlistbot at llvm.org
Wed Jul 29 03:23:57 PDT 2020
Author: Frederik Gossen
Date: 2020-07-29T10:23:37Z
New Revision: 2e7baf61970a08a77540f3afdbefbb20b4a78348
URL: https://github.com/llvm/llvm-project/commit/2e7baf61970a08a77540f3afdbefbb20b4a78348
DIFF: https://github.com/llvm/llvm-project/commit/2e7baf61970a08a77540f3afdbefbb20b4a78348.diff
LOG: [MLIR][Shape] Allow `shape.add` to operate on indices
Differential Revision: https://reviews.llvm.org/D84441
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 7b676a2b0598..9ba838dbbb26 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -53,18 +53,23 @@ def Shape_WitnessType : DialectType<ShapeDialect,
class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ShapeDialect, mnemonic, traits>;
-def Shape_AddOp : Shape_Op<"add", [Commutative, SameOperandsAndResultType]> {
- let summary = "Addition of sizes";
+def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+ let summary = "Addition of sizes and indices";
let description = [{
- Adds two valid sizes as follows:
- * lhs + rhs = unknown if either lhs or rhs unknown;
- * lhs + rhs = (int)lhs + (int)rhs if known;
+ Adds two sizes or indices. If either operand is an error it will be
+ propagated to the result. The operands can be of type `size` or `index`. If
+ at least one of the operands can hold an error, i.e. if it is of type `size`,
+ the result must also be of type `size`.
}];
- let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
- let results = (outs Shape_SizeType:$result);
+ let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
+ let results = (outs Shape_SizeOrIndexType:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict";
+ let assemblyFormat = [{
+ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+ }];
+
+ let verifier = [{ return verifySizeOrIndexOp(*this); }];
}
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 4a45181d4587..d804efbd9e4e 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -105,7 +105,7 @@ func @rank(%arg : !shape.shape) {
// -----
-func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
+func @get_extent(%arg : tensor<?xindex>) -> index {
%c0 = shape.const_size 0
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
@@ -114,7 +114,7 @@ func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
// -----
-func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
+func @mul(%lhs : !shape.size, %rhs : index) -> index {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.mul %lhs, %rhs : !shape.size, index -> index
return %result : index
@@ -122,9 +122,17 @@ func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
// -----
-func @num_elements_error_possible(%arg : !shape.shape) -> index {
+func @num_elements(%arg : !shape.shape) -> index {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.num_elements %arg : !shape.shape -> index
return %result : index
}
+// -----
+
+func @add(%lhs : !shape.size, %rhs : index) -> index {
+ // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+ %result = shape.add %lhs, %rhs : !shape.size, index -> index
+ return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 87af623fe0f7..229f3948d31d 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -120,6 +120,15 @@ func @mul(%size_arg : !shape.size, %index_arg : index) {
return
}
+// Exercises valid operand/result type combinations of `shape.add`.
+func @add(%size_arg : !shape.size, %index_arg : index) {
+ // size + size -> size.
+ %size_sum = shape.add %size_arg, %size_arg
+ : !shape.size, !shape.size -> !shape.size
+ // index + index -> index.
+ %index_sum = shape.add %index_arg, %index_arg : index, index -> index
+ // Mixed size + index: the result is `size` so a possible error can propagate.
+ %mixed_sum = shape.add %size_arg, %index_arg
+ : !shape.size, index -> !shape.size
+ return
+}
+
func @const_size() {
// CHECK: %c1 = shape.const_size 1
// CHECK: %c2 = shape.const_size 2
More information about the Mlir-commits
mailing list