[Mlir-commits] [mlir] 0e1a42e - [MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 01:35:06 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T08:34:37Z
New Revision: 0e1a42efd8b8702a6adcb09802b84bdd8e727a19

URL: https://github.com/llvm/llvm-project/commit/0e1a42efd8b8702a6adcb09802b84bdd8e727a19
DIFF: https://github.com/llvm/llvm-project/commit/0e1a42efd8b8702a6adcb09802b84bdd8e727a19.diff

LOG: [MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors

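`shape.get_extent` now accepts an extent tensor (`tensor<?xindex>`) as its `shape` operand, in addition to `!shape.shape`. Because the operand type can now vary, the textual assembly format requires it to be written explicitly, e.g. `shape.get_extent %shape, %idx : !shape.shape`.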

Differential Revision: https://reviews.llvm.org/D84158

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 1bdfd9a071e3..2302c5110f65 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -228,17 +228,15 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
 }
 
 def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
-  let summary = "Gets the specified extent from a shape";
+  let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
-    Gets the extent indexed by `dim` from `shape`.
-    If the shape is an error, it returns an error size.
+    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
+    an error then it returns an error size.
   }];
-  let arguments = (ins
-    Shape_ShapeType:$shape,
-    Shape_SizeType:$dim
-  );
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+                       Shape_SizeType:$dim);
   let results = (outs Shape_SizeType:$extent);
-  let assemblyFormat = "$shape `,` $dim attr-dict";
+  let assemblyFormat = "$shape `,` $dim `:` type($shape)  attr-dict";
 
   let builders = [
     // Builder that allows passing a constant dimension as a simple integer.

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 0619a7314e40..f50b6530d9d7 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -139,7 +139,7 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<2x3xf32>
-  %result = shape.get_extent %shape, %idx
+  %result = shape.get_extent %shape, %idx : !shape.shape
   return %result : !shape.size
 }
 
@@ -154,7 +154,7 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
   // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.from_extent_tensor %extents : tensor<?xindex>
-  %result = shape.get_extent %shape, %idx
+  %result = shape.get_extent %shape, %idx : !shape.shape
   return %result : !shape.size
 }
 

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 20f21bbc877e..9e691b88b016 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -239,9 +239,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
-  %0 = shape.const_shape [0, 1, 2] : !shape.shape
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
   %c2 = shape.const_size 2
-  %1 = shape.get_extent %0, %c2
+  %1 = shape.get_extent %0, %c2 : tensor<?xindex>
   return %1 : !shape.size
 }
 
@@ -252,9 +252,9 @@ func @basic() -> !shape.size {
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : !shape.shape
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
   %c3 = shape.const_size 3
-  %1 = shape.get_extent %0, %c3
+  %1 = shape.get_extent %0, %c3 : tensor<?xindex>
   return %1 : !shape.size
 }
 
@@ -262,10 +262,10 @@ func @out_of_bounds() -> !shape.size {
 
 // Should not fold.
 // CHECK-LABEL: func @not_const
-func @not_const(%arg0: !shape.shape) -> !shape.size {
+func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
   // CHECK: shape.get_extent
   %c3 = shape.const_size 3
-  %0 = shape.get_extent %arg0, %c3
+  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
   return %0 : !shape.size
 }
 

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index aace26de0ea2..66b5834ff653 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -161,3 +161,15 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
   return %result : i1
 }
+
+func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
+  %c0 = shape.const_size 0
+  %result = shape.get_extent %arg, %c0 : !shape.shape
+  return %result : !shape.size
+}
+
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+  %c0 = shape.const_size 0
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+  return %result : !shape.size
+}


        


More information about the Mlir-commits mailing list