[Mlir-commits] [mlir] 23a6564 - [MLIR][Shape] Allow `shape.rank` to operate on extent tensors

Frederik Gossen <llvmlistbot@llvm.org>
Fri Jul 24 03:44:02 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T10:43:39Z
New Revision: 23a65648c0cd412becbef9b914e2e33f6114ba87

URL: https://github.com/llvm/llvm-project/commit/23a65648c0cd412becbef9b914e2e33f6114ba87
DIFF: https://github.com/llvm/llvm-project/commit/23a65648c0cd412becbef9b914e2e33f6114ba87.diff

LOG: [MLIR][Shape] Allow `shape.rank` to operate on extent tensors

Differential Revision: https://reviews.llvm.org/D84429
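
With this change, `shape.rank` accepts both `!shape.shape` and extent-tensor
operands, and the result type (`!shape.size` or `index`) is now spelled
explicitly in the assembly format. A minimal sketch of the two forms, based
on the tests in the diff below:

    %a = shape.rank %shape : !shape.shape -> !shape.size
    %b = shape.rank %extents : tensor<?xindex> -> index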

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index 3cbebe723921..4c00c1f3c8f7 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -122,4 +122,6 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
                                                Shape_ExtentTensorType],
                                               "shape or extent tensor">;
 
+def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
+
 #endif // SHAPE_BASE_TD

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 70f8d75748f4..014b72cd1339 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -201,12 +201,13 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   }];
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
-  let results = (outs Shape_SizeType:$rank);
+  let results = (outs Shape_SizeOrIndexType:$rank);
 
-  let assemblyFormat = "$shape `:` type($shape) attr-dict";
+  let assemblyFormat = "$shape `:` type($shape) `->` type($rank) attr-dict";
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
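
The verifier wired up via `let verifier` above rejects exactly one type
combination: a `!shape.shape` operand with an `index` result, since only a
`!shape.size` result can propagate potential errors. A minimal sketch of the
rejected form, mirroring the invalid.mlir test further below:

    // error: if operand is of type `shape` then the result must be of
    // type `size` to propagate potential errors
    %0 = shape.rank %arg : !shape.shape -> index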

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c5f11a9a95d3..a7a9cb97e76b 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -52,6 +53,8 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
   if (type.isa<WitnessType>())
     return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+  if (type.isa<IndexType>())
+    return builder.create<ConstantOp>(loc, type, value);
   return nullptr;
 }
 
@@ -563,7 +566,17 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
 // RankOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+static LogicalResult verify(shape::RankOp op) {
+  Type argTy = op.shape().getType();
+  Type resultTy = op.rank().getType();
+  if (argTy.isa<ShapeType>() && !resultTy.isa<SizeType>())
+    return op.emitOpError()
+           << "if operand is of type `shape` then the result must be of type "
+              "`size` to propagate potential errors";
+  return success();
+}
+
+OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!shape)
     return {};
@@ -587,10 +600,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
 /// %rank = shape.const_size 3
 
 namespace {
-struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
-  using OpRewritePattern<RankOp>::OpRewritePattern;
+struct RankShapeOfCanonicalizationPattern
+    : public OpRewritePattern<shape::RankOp> {
+  using OpRewritePattern<shape::RankOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(RankOp op,
+  LogicalResult matchAndRewrite(shape::RankOp op,
                                 PatternRewriter &rewriter) const override {
     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
     if (!shapeOfOp)
@@ -599,15 +613,18 @@ struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
     if (!rankedTensorType)
       return failure();
+    assert(op.getType().isa<IndexType>() &&
+           "expected `rank(shape_of( ... )]` based on a shaped argument to "
+           "yield an index type");
     int64_t rank = rankedTensorType.getRank();
-    rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
     return success();
   }
 };
 } // namespace
 
-void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
-                                         MLIRContext *context) {
+void shape::RankOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
 }
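
Because the result can now be `index`, the rank(shape_of) canonicalization
materializes a standard `constant` instead of a `shape.const_size`. A sketch
of the rewrite on a ranked tensor argument, mirroring the canonicalize.mlir
tests below:

    %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
    %rank = shape.rank %shape : tensor<?xindex> -> index
    // canonicalizes to:
    %rank = constant 3 : index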
 

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 67fb7cdb7910..934a28a6ed44 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -122,12 +122,12 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
 // Convert `rank` to `dim` of the first dimension.
 // CHECK-LABEL: @rank
 // CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @rank(%shape : !shape.shape) -> !shape.size {
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
-  // CHECK-DAG: return %[[RESULT]] : index
-  %rank = shape.rank %shape : !shape.shape
-  return %rank : !shape.size
+func @rank(%shape : tensor<?xindex>) -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
 
 // -----

diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index e2874e09cc8a..e5b77a870a85 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -496,10 +496,10 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
 // Fold `rank` based on constant shape.
 // CHECK-LABEL: @fold_rank
 func @fold_rank() -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  // CHECK: %[[RESULT:.*]] = shape.const_size 5
+  // CHECK: return %[[RESULT]] : !shape.size
   %shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
   return %rank : !shape.size
 }
 
@@ -509,38 +509,64 @@ func @fold_rank() -> !shape.size {
 // CHECK-LABEL: @dont_fold_rank
 // CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size
 func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %rank = shape.rank %shape : !shape.shape
+  // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
+  // CHECK: return %[[RESULT]] : !shape.size
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
   return %rank : !shape.size
 }
 
 // -----
 
+// Fold `rank` based on constant extent tensor.
+// CHECK-LABEL: @fold_rank
+func @fold_rank() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 5 : index
+  // CHECK: return %[[RESULT]] : index
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
+}
+
+// -----
+
+// Do not fold `rank` for non-constant extent tensors.
+// CHECK-LABEL: @dont_fold_rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @dont_fold_rank(%shape : tensor<?xindex>) -> index {
+  // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]] : tensor<?xindex> -> index
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
+}
+
+// -----
+
 // Canonicalize `rank` when shape is derived from ranked tensor.
 // CHECK-LABEL: @canonicalize_rank
-func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {
+  // CHECK: %[[RESULT:.*]] = constant 3 : index
+  // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
 
 // -----
 
 // Do not canonicalize `rank` when shape is derived from unranked tensor.
 // CHECK-LABEL: @dont_canonicalize_rank
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
-func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
-  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
-  // CHECK-DAG: return %[[SIZE]] : !shape.size
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
+  // CHECK: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK: return %[[SIZE]] : index
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
 
+// -----
+
 // Canonicalize redundant conversion from `index` to `size` and back.
 // CHECK-LABEL: @index_to_size_to_index
 // CHECK-SAME: (%[[IDX:.*]]: index) -> index

diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index 23d0daf79378..ae25ba90c360 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -95,3 +95,10 @@ func @shape_of(%value_arg : !shape.value_shape,
   %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
 }
 
+// -----
+
+func @rank(%arg : !shape.shape) {
+  // expected-error @+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+  %0 = shape.rank %arg : !shape.shape -> index
+}
+

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index d0275aaf692e..f023c02c510c 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -137,13 +137,13 @@ func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
 }
 
 func @rank(%shape : !shape.shape) -> !shape.size {
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
   return %rank : !shape.size
 }
 
-func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }