[Mlir-commits] [mlir] f67d57c - [mlir][Shape] Add a pattern to turn extract from shape_of into tensor.dim

Benjamin Kramer llvmlistbot at llvm.org
Tue Oct 12 10:09:27 PDT 2021


Author: Benjamin Kramer
Date: 2021-10-12T19:09:21+02:00
New Revision: f67d57c95f50fabdfa0bbd454faa564f5059d5f4

URL: https://github.com/llvm/llvm-project/commit/f67d57c95f50fabdfa0bbd454faa564f5059d5f4
DIFF: https://github.com/llvm/llvm-project/commit/f67d57c95f50fabdfa0bbd454faa564f5059d5f4.diff

LOG: [mlir][Shape] Add a pattern to turn extract from shape_of into tensor.dim

If I remember correctly, this wasn't done previously because dim used to
be in the memref dialect.

Differential Revision: https://reviews.llvm.org/D111651

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 62b4e022408aa..0a5da5d32c426 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1473,7 +1473,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+               ExtractFromShapeOfExtentTensor>(context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(

diff  --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 7460dc5f3d33d..0825f0f680979 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -16,6 +16,9 @@ def HasStaticShape : Constraint<CPred< [{
   $0.getType().dyn_cast<ShapedType>().hasStaticShape()
 }]>>;
 
+// Helper that takes the first element of a range.
+def TakeFront : NativeCodeCall<"$0.front()">;
+
 // Canonicalization patterns.
 
 def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -43,3 +46,9 @@ def SizeToIndexToSizeCanonicalization : Pat<
 def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
+
+// tensor.extract from shape_of -> tensor.dim. We can take the first index
+// because shape_of always returns a 1D tensor.
+def ExtractFromShapeOfExtentTensor : Pat<
+  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
+  (Tensor_DimOp $arg, (TakeFront $indices))>;

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b0c2181b5b7ba..a6b93e850761e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1380,3 +1380,17 @@ func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
       -> tensor<?xindex>
   return %0 : tensor<?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_shapeof
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?xf64>
+func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
+ %c1 = constant 1 : index
+// CHECK:        %[[C1:.*]] = constant 1
+ %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK:        %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+ %result = tensor.extract %shape[%c1] : tensor<2xindex>
+// CHECK:        return %[[DIM]]
+ return %result : index
+}
