[Mlir-commits] [mlir] 55ced04 - [MLIR][Shape] Allow `num_elements` to operate on extent tensors
Jacques Pienaar
llvmlistbot at llvm.org
Sat Jul 25 14:42:03 PDT 2020
Author: Frederik Gossen
Date: 2020-07-25T14:41:05-07:00
New Revision: 55ced04d6bc13fd0f9396a0cfc393b44378d8784
URL: https://github.com/llvm/llvm-project/commit/55ced04d6bc13fd0f9396a0cfc393b44378d8784
DIFF: https://github.com/llvm/llvm-project/commit/55ced04d6bc13fd0f9396a0cfc393b44378d8784.diff
LOG: [MLIR][Shape] Allow `num_elements` to operate on extent tensors
Differential Revision: https://reviews.llvm.org/D84445
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
mlir/test/Dialect/Shape/shape-to-shape.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 797dc0bc0cb6..abbd8f093109 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -333,19 +333,19 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
let summary = "Returns the number of elements for a given shape";
let description = [{
Returns the number of elements for a given shape which is the product of its
- dimensions.
-
- ```mlir
- %product = shape.mul %lhs, %rhs
- ```
+ extents. If the argument is of type `shape` then the result will be of type
+ `size` and potential errors will be propagated. Otherwise, if the argument
+ is an extent tensor `tensor<?xindex>` then the result will be of type
+ `index`.
}];
- let arguments = (ins Shape_ShapeType:$shape);
- let results = (outs Shape_SizeType:$result);
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
+ let results = (outs Shape_SizeOrIndexType:$result);
- let assemblyFormat = "$shape attr-dict";
+ let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict";
let hasFolder = 1;
+ let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
}
def Shape_ReduceOp : Shape_Op<"reduce",
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 577656a0b362..e147fbeb81ac 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -217,7 +217,7 @@ func @num_elements() -> !shape.size {
// CHECK-NOT: shape.const_shape
%shape = shape.const_shape [4, 5, 6] : !shape.shape
// CHECK-NOT: shape.num_elements
- %num_elements = shape.num_elements %shape
+ %num_elements = shape.num_elements %shape : !shape.shape -> !shape.size
// CHECK: %[[NUM:.*]] = shape.const_size 120
// CHECK-NEXT: return %[[NUM]] : !shape.size
return %num_elements : !shape.size
@@ -229,7 +229,7 @@ func @num_elements() -> !shape.size {
// CHECK-LABEL: func @nonfoldable_num_elements
func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// CHECK-NOT: shape.const_{{.*}}
- %num_elements = shape.num_elements %shape
+ %num_elements = shape.num_elements %shape : !shape.shape -> !shape.size
return %num_elements : !shape.size
}
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index b4900e491fb8..0bbd6cec777d 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -146,3 +146,19 @@ func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
return %result : index
}
+// -----
+
+func @num_elements_error_free(%arg : tensor<?xindex>) -> !shape.size {
+ // expected-error @+1 {{if none of the operands can hold error values then the result must be of type `index`}}
+ %result = shape.num_elements %arg : tensor<?xindex> -> !shape.size
+ return %result : !shape.size
+}
+
+// -----
+
+func @num_elements_error_possible(%arg : !shape.shape) -> index {
+ // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+ %result = shape.num_elements %arg : !shape.shape -> index
+ return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 3a0cb7781ec7..f57826097d34 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -195,3 +195,14 @@ func @any() {
return
}
+func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
+ %result = shape.num_elements %arg : tensor<?xindex> -> index
+ return %result : index
+}
+
+func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
+ %result = shape.num_elements %arg : !shape.shape -> !shape.size
+ return %result : !shape.size
+}
+
+
diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
index 9a75f0b9ca1b..d1b00bc12a22 100644
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
-// CHECK-LABEL: func @num_elements_to_reduce(
-// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> !shape.size {
+// CHECK-LABEL: func @num_elements_to_reduce
+// CHECK-SAME: ([[ARG:%.*]]: !shape.shape) -> !shape.size
func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
- %num_elements = shape.num_elements %shape
+ %num_elements = shape.num_elements %shape : !shape.shape -> !shape.size
return %num_elements : !shape.size
}
// CHECK: [[C1:%.*]] = shape.const_size 1
More information about the Mlir-commits
mailing list