[Mlir-commits] [mlir] 630afc6 - [MLIR][Shape] Canonicalize casted dynamic extent tensor
Frederik Gossen
llvmlistbot at llvm.org
Mon Mar 29 04:59:35 PDT 2021
Author: Frederik Gossen
Date: 2021-03-29T13:59:19+02:00
New Revision: 630afc61a85429c2b0e6dbc9ef08e6013be4ad52
URL: https://github.com/llvm/llvm-project/commit/630afc61a85429c2b0e6dbc9ef08e6013be4ad52
DIFF: https://github.com/llvm/llvm-project/commit/630afc61a85429c2b0e6dbc9ef08e6013be4ad52.diff
LOG: [MLIR][Shape] Canonicalize casted dynamic extent tensor
Differential Revision: https://reviews.llvm.org/D99161
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a1419322afb3..bb7ed5cf05ce 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -987,11 +987,43 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
return success();
}
};
+
+// Canonicalize a `tensor.cast` whose source is a `shape.shape_of` result by
+// folding the cast into the `shape.shape_of` itself, e.g.
+// ```
+// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+// ```
+// to
+// ```
+// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+// ```
+// The cast may also go from a dynamic to a static extent tensor
+// (tensor<?xindex> to tensor<3xindex>), provided the static extent count
+// equals the rank of `%arg`. The pattern does not fire when `%arg` is
+// unranked, or when its rank conflicts with the cast's static extent count.
+//
+// NOTE: this pattern is rooted at `tensor.cast`, not at `shape.shape_of`; it
+// is registered via `ShapeOfOp::getCanonicalizationPatterns`.
+struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+ using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::CastOp op,
+ PatternRewriter &rewriter) const override {
+ // Only casts producing a ranked, rank-1 tensor (an extent tensor) qualify.
+ auto ty = op.getType().dyn_cast<RankedTensorType>();
+ if (!ty || ty.getRank() != 1)
+ return failure();
+
+ // The cast operand must be produced directly by a `shape.shape_of`.
+ auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
+ if (!shapeOfOp)
+ return failure();
+
+ // The `shape.shape_of` argument must be ranked. If the cast result has a
+ // static extent count, that count must equal the argument's rank, so that
+ // the new `shape.shape_of` result type is consistent with its argument.
+ // A dynamic (`?`) cast result accepts any rank.
+ auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+ if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
+ return failure();
+
+ // Replace the cast with a fresh `shape.shape_of` of the same argument that
+ // yields the cast's result type directly. The original `shape.shape_of` is
+ // left untouched; if the cast was its only use, it becomes dead.
+ rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
+ return success();
+ }
+};
} // namespace
+// Collects the canonicalization patterns associated with `shape.shape_of`.
+// ShapeOfCastedExtentTensor matches `tensor.cast` ops rather than
+// `shape.shape_of`. It is listed here because it only rewrites casts whose
+// operand is a `shape.shape_of` result.
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfWithTensor>(context);
+ patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 39f17e9d253f..b0c12ea0b149 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -648,7 +648,7 @@ func @f() {
// CHECK: shape.cstr_broadcastable
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
- %cs0 = shape.const_shape [8, 1] : !shape.shape
+ %cs0 = shape.const_shape [8, 1] : !shape.shape
%cs1 = shape.const_shape [1, 8] : !shape.shape
%cs2 = shape.const_shape [1, -1] : !shape.shape
%0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
@@ -1144,3 +1144,47 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
"use"(%0) : (tensor<?xindex>) -> ()
return
}
+
+// -----
+
+// Static-to-dynamic cast: the cast is folded away and shape_of produces
+// tensor<?xindex> directly.
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+ // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
+ // CHECK: return %[[RESULT]] : tensor<?xindex>
+ %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+ %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// Dynamic-to-static cast: folded because the static extent count (3) matches
+// the rank of the shape_of argument (3).
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+ // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
+ // CHECK: return %[[RESULT]] : tensor<3xindex>
+ %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+ %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+ return %1 : tensor<3xindex>
+}
+
+// -----
+
+// Negative case: the argument has rank 4 but the cast claims 3 extents. The
+// types conflict, so the cast must be kept.
+// CHECK-LABEL: @casted_extent_tensor
+func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+ // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+ %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
+ %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+ return %1 : tensor<3xindex>
+}
+
+// -----
+
+// Negative case: the argument is unranked, so its rank cannot be compared
+// with the static extent count and the cast must be kept.
+// CHECK-LABEL: @casted_extent_tensor
+func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+ // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+ %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+ %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+ return %1 : tensor<3xindex>
+}
More information about the Mlir-commits mailing list