[Mlir-commits] [mlir] 71e7a37 - [MLIR][Shape] Allow `shape.rank` to accept extent tensors `tensor<?xindex>`

Frederik Gossen llvmlistbot at llvm.org
Mon Jul 20 07:47:36 PDT 2020


Author: Frederik Gossen
Date: 2020-07-20T14:47:19Z
New Revision: 71e7a37e7eafdffab9b9382d4a0abc0462eb78ce

URL: https://github.com/llvm/llvm-project/commit/71e7a37e7eafdffab9b9382d4a0abc0462eb78ce
DIFF: https://github.com/llvm/llvm-project/commit/71e7a37e7eafdffab9b9382d4a0abc0462eb78ce.diff

LOG: [MLIR][Shape] Allow `shape.rank` to accept extent tensors `tensor<?xindex>`

Differential Revision: https://reviews.llvm.org/D84156

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 46400c8e9eff..703353c35f5a 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -195,13 +195,13 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   let summary = "Gets the rank of a shape";
   let description = [{
-    Returns the rank of the shape, i.e. the number of extents.
+    Returns the rank of the shape or extent tensor, i.e. the number of extents.
   }];
 
-  let arguments = (ins Shape_ShapeType:$shape);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeType:$rank);
 
-  let assemblyFormat = "attr-dict $shape";
+  let assemblyFormat = "$shape `:` type($shape) attr-dict";
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 22206637adbf..0619a7314e40 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -124,7 +124,7 @@ func @rank(%shape : !shape.shape) -> !shape.size {
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
   // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
   // CHECK-DAG: return %[[RESULT]] : index
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 156063ea002c..80b7cb9ddb94 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -499,7 +499,7 @@ func @fold_rank() -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
   // CHECK-DAG: return %[[RESULT]] : !shape.size
   %shape = shape.const_shape [3, 4, 5, 6, 7]
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
@@ -511,7 +511,7 @@ func @fold_rank() -> !shape.size {
 func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
   // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
@@ -520,11 +520,11 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
 // Canonicalize `rank` when shape is derived from ranked tensor.
 // CHECK-LABEL: @canonicalize_rank
 func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
-// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
-// CHECK-DAG: return %[[RESULT]] : !shape.size
-%shape = shape.shape_of %arg : tensor<1x2x?xf32>
-%rank = shape.rank %shape
-return %rank : !shape.size
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
+  %rank = shape.rank %shape : !shape.shape
+  return %rank : !shape.size
 }
 
 // -----
@@ -533,12 +533,12 @@ return %rank : !shape.size
 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
 func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
-// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
-// CHECK-DAG: return %[[SIZE]] : !shape.size
-%shape = shape.shape_of %arg : tensor<*xf32>
-%rank = shape.rank %shape
-return %rank : !shape.size
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK-DAG: return %[[SIZE]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  %rank = shape.rank %shape : !shape.shape
+  return %rank : !shape.size
 }
 
 // Canonicalize redundant conversion from `index` to `size` and back.

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 30cf29a083ec..1187d7ad92bb 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -130,10 +130,16 @@ func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
 }
 
 func @rank(%shape : !shape.shape) -> !shape.size {
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
+  %rank = shape.rank %shape : tensor<?xindex>
+  return %rank : !shape.size
+}
+
+
 func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1


        


More information about the Mlir-commits mailing list