[Mlir-commits] [mlir] 7bca97d - [MLIR][Shape] Add canonicalization pattern for `shape.rank`

Frederik Gossen llvmlistbot at llvm.org
Thu Jun 25 01:40:22 PDT 2020


Author: Frederik Gossen
Date: 2020-06-25T08:39:35Z
New Revision: 7bca97d960ab9451185a997208057a89355b406a

URL: https://github.com/llvm/llvm-project/commit/7bca97d960ab9451185a997208057a89355b406a
DIFF: https://github.com/llvm/llvm-project/commit/7bca97d960ab9451185a997208057a89355b406a.diff

LOG: [MLIR][Shape] Add canonicalization pattern for `shape.rank`

Replace any `rank(shape_of(%tensor))` whose `shape_of` operand is a ranked tensor
with a `shape.const_size` holding that tensor's (statically known) rank.

Differential Revision: https://reviews.llvm.org/D82077

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 379f861151dc..2430fe62f13b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -130,6 +130,10 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
   let arguments = (ins IndexAttr:$value);
   let results = (outs Shape_SizeType:$result);
 
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value">
+  ];
+
   let assemblyFormat = "$value attr-dict";
   let hasFolder = 1;
 }
@@ -181,6 +185,7 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   let assemblyFormat = "attr-dict $shape";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index cdbc89289c4e..2d952183050e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -364,6 +364,11 @@ OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
 // ConstSizeOp
 //===----------------------------------------------------------------------===//
 
+/// Convenience builder for `shape.const_size`: wraps the plain integer `value`
+/// in an index attribute and delegates to the attribute-taking `build`.
+void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
+                        int64_t value) {
+  auto indexAttr = builder.getIndexAttr(value);
+  build(builder, result, indexAttr);
+}
+
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) {
   // Operands are irrelevant: the op folds to its own stored value attribute.
   return valueAttr();
 }
 
 void ConstSizeOp::getAsmResultNames(
@@ -450,6 +455,45 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
   return builder.getIndexAttr(rank);
 }
 
+namespace {
+/// Evaluates `shape.rank` at compile time when its operand is the shape of a
+/// ranked tensor. Constant folding fails in cases where only the rank is
+/// constant, not the shape itself (e.g. the shape has dynamic extents).
+/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
+///
+/// Example:
+///
+///   %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
+///   %rank = shape.rank %shape
+///
+/// becomes
+///
+///   %rank = shape.const_size 3
+struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
+  using OpRewritePattern<RankOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(RankOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only shapes produced directly by `shape.shape_of` are handled.
+    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
+    if (!shapeOfOp)
+      return failure();
+    // The rank is static only for ranked tensor types; anything else (e.g. an
+    // unranked `tensor<*x...>`) is left for runtime evaluation.
+    auto rankedTensorType =
+        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorType)
+      return failure();
+    int64_t rank = rankedTensorType.getRank();
+    // Typed ops convert implicitly to `Operation *`; no need for
+    // `getOperation()`.
+    rewriter.replaceOpWithNewOp<ConstSizeOp>(op, rank);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the `shape.rank` canonicalizations; currently only the rewrite of
+/// `rank(shape_of(%ranked_tensor))` into a `shape.const_size`.
+void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                         MLIRContext *context) {
+  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 00f6b361a93e..9fb48e6c896e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -466,3 +466,29 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+// -----
+
+// Canonicalize `rank` when the shape is derived from a ranked tensor. The
+// `shape_of`/`rank` pair must be replaced entirely by a single constant, so the
+// function body is pinned to exactly `const_size` followed by `return`.
+// CHECK-LABEL: @canonicalize_rank
+// CHECK-SAME: (%{{.*}}: tensor<1x2x?xf32>) -> !shape.size
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
+  // CHECK-NEXT: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-NEXT: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}
+
+// -----
+
+// Leave `rank` alone when the shape is derived from an unranked tensor: its
+// rank is not known statically, so both ops must survive canonicalization.
+// CHECK-LABEL: @dont_canonicalize_rank
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK-DAG: return %[[SIZE]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}


        


More information about the Mlir-commits mailing list