[Mlir-commits] [mlir] 2e7baf6 - [MLIR][Shape] Allow `shape.add` to operate on indices

Frederik Gossen llvmlistbot at llvm.org
Wed Jul 29 03:23:57 PDT 2020


Author: Frederik Gossen
Date: 2020-07-29T10:23:37Z
New Revision: 2e7baf61970a08a77540f3afdbefbb20b4a78348

URL: https://github.com/llvm/llvm-project/commit/2e7baf61970a08a77540f3afdbefbb20b4a78348
DIFF: https://github.com/llvm/llvm-project/commit/2e7baf61970a08a77540f3afdbefbb20b4a78348.diff

LOG: [MLIR][Shape] Allow `shape.add` to operate on indices

Differential Revision: https://reviews.llvm.org/D84441

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed:
    (none)


################################################################################
# ShapeOps.td: generalize `shape.add` from `!shape.size`-only operands to
# `size` or `index` operands.
# - The `SameOperandsAndResultType` trait is dropped (operand and result types
#   may now differ) and `NoSideEffect` is added.
# - Operands and result use `Shape_SizeOrIndexType` instead of `Shape_SizeType`.
# - Because the types can differ, the assembly format now spells out all three:
#   `shape.add %lhs, %rhs : <lhs type>, <rhs type> -> <result type>`.
# - Per the new description, a `size` operand requires a `size` result; that
#   check is delegated to `verifySizeOrIndexOp`, which is not part of this diff
#   (presumably a shared helper in ShapeOps.cpp -- verify there).
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 7b676a2b0598..9ba838dbbb26 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -53,18 +53,23 @@ def Shape_WitnessType : DialectType<ShapeDialect,
 class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<ShapeDialect, mnemonic, traits>;
 
-def Shape_AddOp : Shape_Op<"add", [Commutative, SameOperandsAndResultType]> {
-  let summary = "Addition of sizes";
+def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
+  let summary = "Addition of sizes and indices";
   let description = [{
-    Adds two valid sizes as follows:
-    * lhs + rhs = unknown if either lhs or rhs unknown;
-    * lhs + rhs = (int)lhs + (int)rhs if known;
+    Adds two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`. If
+    at least one of the operands can hold an error, i.e. if it is of type `size`,
+    then also the result must be of type `size`.
   }];
 
-  let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs);
-  let results = (outs Shape_SizeType:$result);
+  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+
+  let verifier = [{ return verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {

# invalid.mlir: this diff does two things.
# - It drops the `_error_possible` suffix from three existing negative tests.
# - It adds a negative `shape.add` test: a `!shape.size` operand combined with
#   an `index` result must be rejected.
# The `expected-error@+1` directives below are restored. The list archive had
# rewritten `@` as ` at ` (as it did in the sender address). That rewrite
# breaks the diagnostic-verifier directive syntax and makes the context lines
# mismatch the real file, so the patch would not apply.
diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 4a45181d4587..d804efbd9e4e 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -105,7 +105,7 @@ func @rank(%arg : !shape.shape) {
 
 // -----
 
-func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
+func @get_extent(%arg : tensor<?xindex>) -> index {
   %c0 = shape.const_size 0
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
@@ -114,7 +114,7 @@ func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
 
 // -----
 
-func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
+func @mul(%lhs : !shape.size, %rhs : index) -> index {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.mul %lhs, %rhs : !shape.size, index -> index
   return %result : index
@@ -122,9 +122,17 @@ func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
 
 // -----
 
-func @num_elements_error_possible(%arg : !shape.shape) -> index {
+func @num_elements(%arg : !shape.shape) -> index {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.num_elements %arg : !shape.shape -> index
   return %result : index
 }
 
+// -----
+
+func @add(%lhs : !shape.size, %rhs : index) -> index {
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.add %lhs, %rhs : !shape.size, index -> index
+  return %result : index
+}
+

# ops.mlir: adds a positive test for the generalized `shape.add`. It covers
# size + size -> size, index + index -> index, and mixed size + index -> size.
# NOTE(review): the new function has no CHECK lines, so it presumably only
# checks that these forms parse and verify; the RUN line is not visible here,
# so confirm.
diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 87af623fe0f7..229f3948d31d 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -120,6 +120,15 @@ func @mul(%size_arg : !shape.size, %index_arg : index) {
   return
 }
 
+func @add(%size_arg : !shape.size, %index_arg : index) {
+  %size_sum = shape.add %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_sum = shape.add %index_arg, %index_arg : index, index -> index
+  %mixed_sum = shape.add %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  return
+}
+
 func @const_size() {
   // CHECK: %c1 = shape.const_size 1
   // CHECK: %c2 = shape.const_size 2


        


More information about the Mlir-commits mailing list