[Mlir-commits] [mlir] 0e1a42e - [MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 01:35:06 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T08:34:37Z
New Revision: 0e1a42efd8b8702a6adcb09802b84bdd8e727a19
URL: https://github.com/llvm/llvm-project/commit/0e1a42efd8b8702a6adcb09802b84bdd8e727a19
DIFF: https://github.com/llvm/llvm-project/commit/0e1a42efd8b8702a6adcb09802b84bdd8e727a19.diff
LOG: [MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors
`shape.get_extent` now accepts extent tensors `tensor<?xindex>` as an argument.
Differential Revision: https://reviews.llvm.org/D84158
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 1bdfd9a071e3..2302c5110f65 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -228,17 +228,15 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
}
def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
- let summary = "Gets the specified extent from a shape";
+ let summary = "Gets the specified extent from a shape or extent tensor";
let description = [{
- Gets the extent indexed by `dim` from `shape`.
- If the shape is an error, it returns an error size.
+ Gets the extent indexed by `dim` from the `shape` operand. If the shape is
+ an error then it returns an error size.
}];
- let arguments = (ins
- Shape_ShapeType:$shape,
- Shape_SizeType:$dim
- );
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+ Shape_SizeType:$dim);
let results = (outs Shape_SizeType:$extent);
- let assemblyFormat = "$shape `,` $dim attr-dict";
+ let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
let builders = [
// Builder that allows passing a constant dimension as a simple integer.
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 0619a7314e40..f50b6530d9d7 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -139,7 +139,7 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
// CHECK: return %[[RESULT]] : index
%shape = shape.shape_of %arg : tensor<2x3xf32>
- %result = shape.get_extent %shape, %idx
+ %result = shape.get_extent %shape, %idx : !shape.shape
return %result : !shape.size
}
@@ -154,7 +154,7 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
// CHECK: return %[[RESULT]] : index
%shape = shape.from_extent_tensor %extents : tensor<?xindex>
- %result = shape.get_extent %shape, %idx
+ %result = shape.get_extent %shape, %idx : !shape.shape
return %result : !shape.size
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 20f21bbc877e..9e691b88b016 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -239,9 +239,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
- %0 = shape.const_shape [0, 1, 2] : !shape.shape
+ %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%c2 = shape.const_size 2
- %1 = shape.get_extent %0, %c2
+ %1 = shape.get_extent %0, %c2 : tensor<?xindex>
return %1 : !shape.size
}
@@ -252,9 +252,9 @@ func @basic() -> !shape.size {
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
- %0 = shape.const_shape [0, 1, 2] : !shape.shape
+ %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%c3 = shape.const_size 3
- %1 = shape.get_extent %0, %c3
+ %1 = shape.get_extent %0, %c3 : tensor<?xindex>
return %1 : !shape.size
}
@@ -262,10 +262,10 @@ func @out_of_bounds() -> !shape.size {
// Should not fold.
// CHECK-LABEL: func @not_const
-func @not_const(%arg0: !shape.shape) -> !shape.size {
+func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
// CHECK: shape.get_extent
%c3 = shape.const_size 3
- %0 = shape.get_extent %arg0, %c3
+ %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
return %0 : !shape.size
}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index aace26de0ea2..66b5834ff653 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -161,3 +161,15 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
return %result : i1
}
+
+func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
+ %c0 = shape.const_size 0
+ %result = shape.get_extent %arg, %c0 : !shape.shape
+ return %result : !shape.size
+}
+
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+ %c0 = shape.const_size 0
+ %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+ return %result : !shape.size
+}
More information about the Mlir-commits mailing list