[Mlir-commits] [mlir] d876e32 - [MLIR][Shape] Lower `shape.get_extent` to `extract_element` when possible

Frederik Gossen llvmlistbot at llvm.org
Mon Jun 29 01:39:40 PDT 2020


Author: Frederik Gossen
Date: 2020-06-29T08:39:22Z
New Revision: d876e3202af3057cc180d7540d0de8b20234f114

URL: https://github.com/llvm/llvm-project/commit/d876e3202af3057cc180d7540d0de8b20234f114
DIFF: https://github.com/llvm/llvm-project/commit/d876e3202af3057cc180d7540d0de8b20234f114.diff

LOG: [MLIR][Shape] Lower `shape.get_extent` to `extract_element` when possible

When a shape is produced directly from an extent tensor (by
`shape.from_extent_tensor`), `shape.get_extent` can be lowered to a single
`extract_element` on that tensor. This avoids having to materialize the shape
in memory.

Differential Revision: https://reviews.llvm.org/D82645
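As an illustration only, the rewrite described above can be modeled on a toy IR in Python. The node classes (`ExtentTensor`, `FromExtentTensor`, `Size`, `GetExtent`, `SizeToIndex`, `IndexToSize`, `ExtractElement`) and the helpers `evaluate` and `lower_get_extent` are hypothetical. They mirror the structure of the `GetExtentFromExtentTensorConversion` pattern in the diff but are not MLIR's C++ or Python API. Both `!shape.size` and `index` are plain integers in this model, so the two casts evaluate as identities.

```python
from dataclasses import dataclass

# Hypothetical toy IR nodes: they mirror, but are not, the MLIR ops.

@dataclass(frozen=True)
class ExtentTensor:      # a tensor<?xindex> value
    extents: tuple

@dataclass(frozen=True)
class FromExtentTensor:  # shape.from_extent_tensor %extents
    tensor: ExtentTensor

@dataclass(frozen=True)
class Size:              # a !shape.size value
    value: int

@dataclass(frozen=True)
class GetExtent:         # shape.get_extent %shape, %idx
    shape: object
    idx: object

@dataclass(frozen=True)
class SizeToIndex:       # shape.size_to_index
    size: object

@dataclass(frozen=True)
class IndexToSize:       # shape.index_to_size
    index: object

@dataclass(frozen=True)
class ExtractElement:    # std.extract_element %tensor[%index]
    tensor: ExtentTensor
    index: object


def evaluate(node):
    """Interpret a toy IR tree; sizes and indices are both plain ints."""
    if isinstance(node, Size):
        return node.value
    if isinstance(node, ExtentTensor):
        return list(node.extents)
    if isinstance(node, FromExtentTensor):
        return evaluate(node.tensor)  # the shape value gets materialized here
    if isinstance(node, GetExtent):
        return evaluate(node.shape)[evaluate(node.idx)]
    if isinstance(node, (SizeToIndex, IndexToSize)):
        inner = node.size if isinstance(node, SizeToIndex) else node.index
        return evaluate(inner)        # casts are identities in this model
    if isinstance(node, ExtractElement):
        return evaluate(node.tensor)[evaluate(node.index)]
    raise TypeError(f"unknown node: {node!r}")


def lower_get_extent(op):
    """Rewrite get_extent(from_extent_tensor(%extents), %idx) into
    index_to_size(extract_element(%extents, size_to_index(%idx)))."""
    if isinstance(op, GetExtent) and isinstance(op.shape, FromExtentTensor):
        std_idx = SizeToIndex(op.idx)
        std_result = ExtractElement(op.shape.tensor, std_idx)
        return IndexToSize(std_result)
    return op  # no match: leave the op for other patterns


extents = ExtentTensor((2, 3, 5))
op = GetExtent(FromExtentTensor(extents), Size(1))
lowered = lower_get_extent(op)
assert evaluate(lowered) == evaluate(op) == 3
```

The lowered tree reads the extent straight from the tensor and never passes through a `FromExtentTensor` node. The log message claims the same benefit for the real pattern. Ops whose shape operand comes from any other producer are returned unchanged.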

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
index 473da36a84ec..154cf6a9e1f7 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -26,3 +26,13 @@ def GetExtentShapeOfConversion : Pat<
     (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
     [],
     (addBenefit 10)>;
+def GetExtentFromExtentTensorConversion : Pattern<
+    (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
+    [
+      (Shape_SizeToIndexOp:$std_idx $idx),
+      (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
+      (Shape_IndexToSizeOp $std_result)
+    ],
+    [],
+    (addBenefit 10)>;
+

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index f9daadd03196..28ef190d09eb 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -143,3 +143,18 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
   return %result : !shape.size
 }
 
+// -----
+
+// Express `get_extent` as `std.extract_element` when it relies directly on the
+// outcome of a `from_extent_tensor` operation.
+// CHECK-LABEL: @get_extent_from_extent_tensor
+// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
+func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
+                                    %idx : !shape.size) -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
+  // CHECK: return %[[RESULT]] : index
+  %shape = shape.from_extent_tensor %extents : tensor<?xindex>
+  %result = shape.get_extent %shape, %idx
+  return %result : !shape.size
+}
+
