[Mlir-commits] [mlir] f67d57c - [mlir][Shape] Add a pattern to turn extract from shape_of into tensor.dim
Benjamin Kramer
llvmlistbot at llvm.org
Tue Oct 12 10:09:27 PDT 2021
Author: Benjamin Kramer
Date: 2021-10-12T19:09:21+02:00
New Revision: f67d57c95f50fabdfa0bbd454faa564f5059d5f4
URL: https://github.com/llvm/llvm-project/commit/f67d57c95f50fabdfa0bbd454faa564f5059d5f4
DIFF: https://github.com/llvm/llvm-project/commit/f67d57c95f50fabdfa0bbd454faa564f5059d5f4.diff
LOG: [mlir][Shape] Add a pattern to turn extract from shape_of into tensor.dim
If I remember correctly, this wasn't done previously because dim used to
be in the memref dialect.
Differential Revision: https://reviews.llvm.org/D111651
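The equivalence the new canonicalization relies on can be illustrated outside MLIR with a small C++ model. This is a minimal sketch under stated assumptions: `ToyTensor`, `shapeOf`, `extract`, `dim` and `rewrittenExtract` are hypothetical stand-ins for `shape.shape_of`, `tensor.extract` and `tensor.dim`, not MLIR APIs, and a tensor is modeled only by its runtime extents. It shows why taking the front of the index range is safe: the extent tensor produced by `shape_of` is always 1-D, so `tensor.extract` on it carries exactly one index.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy model, not the MLIR API: a ranked tensor is represented
// only by its runtime extents, e.g. {3, 5} for a 3x5 tensor<?x?xf64>.
struct ToyTensor {
  std::vector<int64_t> extents;
};

// Models shape.shape_of: the result is always a 1-D extent tensor whose
// length equals the rank of the operand.
std::vector<int64_t> shapeOf(const ToyTensor &t) { return t.extents; }

// Models tensor.extract on a 1-D tensor. A 1-D operand takes exactly one
// index, which is why the pattern can use the front of the index range.
int64_t extract(const std::vector<int64_t> &tensor,
                const std::vector<int64_t> &indices) {
  assert(indices.size() == 1 && "1-D tensor expects exactly one index");
  return tensor.at(static_cast<std::size_t>(indices.front()));
}

// Models tensor.dim: the extent of dimension `index` of `t`.
int64_t dim(const ToyTensor &t, int64_t index) {
  return t.extents.at(static_cast<std::size_t>(index));
}

// Models the rewritten IR: extract(shape_of(arg), indices) becomes
// dim(arg, indices.front()), mirroring the TakeFront helper.
int64_t rewrittenExtract(const ToyTensor &arg,
                         const std::vector<int64_t> &indices) {
  return dim(arg, indices.front());
}
```

For a 3x5 value, `extract(shapeOf(t), {1})` and `rewrittenExtract(t, {1})` both yield 5. In the rewritten form the extent tensor is no longer needed, so `shape.shape_of` becomes dead once it has no other users.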
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 62b4e022408aa..0a5da5d32c426 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1473,7 +1473,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+               ExtractFromShapeOfExtentTensor>(context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 7460dc5f3d33d..0825f0f680979 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -16,6 +16,9 @@ def HasStaticShape : Constraint<CPred< [{
   $0.getType().dyn_cast<ShapedType>().hasStaticShape()
 }]>>;
 
+// Helper that takes the first element of a range.
+def TakeFront : NativeCodeCall<"$0.front()">;
+
 // Canonicalization patterns.
 
 def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -43,3 +46,9 @@ def SizeToIndexToSizeCanonicalization : Pat<
 def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
+
+// tensor.extract from shape_of -> tensor.dim. We can take the first index
+// because shape_of always returns a 1D tensor.
+def ExtractFromShapeOfExtentTensor : Pat<
+  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
+  (Tensor_DimOp $arg, (TakeFront $indices))>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b0c2181b5b7ba..a6b93e850761e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1380,3 +1380,17 @@ func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
       -> tensor<?xindex>
   return %0 : tensor<?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_shapeof
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64>
+func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
+  %c1 = constant 1 : index
+// CHECK: %[[C1:.*]] = constant 1
+  %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+  %result = tensor.extract %shape[%c1] : tensor<2xindex>
+// CHECK: return %[[DIM]]
+  return %result : index
+}