[Mlir-commits] [mlir] d876e32 - [MLIR][Shape] Lower `shape.get_extent` to `extract_element` when possible
Frederik Gossen
llvmlistbot at llvm.org
Mon Jun 29 01:39:40 PDT 2020
Author: Frederik Gossen
Date: 2020-06-29T08:39:22Z
New Revision: d876e3202af3057cc180d7540d0de8b20234f114
URL: https://github.com/llvm/llvm-project/commit/d876e3202af3057cc180d7540d0de8b20234f114
DIFF: https://github.com/llvm/llvm-project/commit/d876e3202af3057cc180d7540d0de8b20234f114.diff
LOG: [MLIR][Shape] Lower `shape.get_extent` to `extract_element` when possible
When a shape originates from an extent tensor (i.e., it is the result of
`shape.from_extent_tensor`), the operation `shape.get_extent` can be lowered
directly to `std.extract_element` on that tensor.
This avoids having to materialize the shape in memory.
Differential Revision: https://reviews.llvm.org/D82645
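To illustrate the rewrite outside of TableGen, here is a minimal Python sketch over a hypothetical toy IR. The `Arg`/`Op` classes and the `lower_get_extent`, `evaluate` and `op_names` functions are invented for illustration and are not MLIR APIs. The sketch mirrors the matching and replacement done by the `GetExtentFromExtentTensorConversion` pattern, and it checks that the lowered form yields the same extent without first copying the extents into a separate shape value.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass(frozen=True)
class Arg:
    """A function argument such as %extents or %idx (hypothetical toy IR)."""
    name: str


@dataclass(frozen=True)
class Op:
    """A single-result operation: a name plus operand expressions."""
    name: str
    operands: Tuple["Expr", ...]


Expr = Union[Arg, Op]


def lower_get_extent(expr: Expr) -> Expr:
    """Bottom-up rewrite mirroring GetExtentFromExtentTensorConversion:
    get_extent(from_extent_tensor(e), i)
      -> index_to_size(extract_element(e, size_to_index(i)))."""
    if isinstance(expr, Arg):
        return expr
    expr = Op(expr.name, tuple(lower_get_extent(o) for o in expr.operands))
    if expr.name == "shape.get_extent":
        shape, idx = expr.operands
        if isinstance(shape, Op) and shape.name == "shape.from_extent_tensor":
            (extents,) = shape.operands
            std_idx = Op("shape.size_to_index", (idx,))
            element = Op("std.extract_element", (extents, std_idx))
            return Op("shape.index_to_size", (element,))
    return expr


def evaluate(expr: Expr, env: Dict[str, object]) -> object:
    """Interpret the toy IR; from_extent_tensor copies the extents to model
    materializing a shape value."""
    if isinstance(expr, Arg):
        return env[expr.name]
    vals = [evaluate(o, env) for o in expr.operands]
    if expr.name == "shape.from_extent_tensor":
        return list(vals[0])  # the copy that the lowering avoids
    if expr.name in ("shape.get_extent", "std.extract_element"):
        return vals[0][vals[1]]
    if expr.name in ("shape.size_to_index", "shape.index_to_size"):
        return vals[0]  # casts between !shape.size and index
    raise ValueError(f"unknown op {expr.name}")


def op_names(expr: Expr) -> List[str]:
    """Collect every op name in an expression tree (pre-order)."""
    if isinstance(expr, Arg):
        return []
    names = [expr.name]
    for o in expr.operands:
        names.extend(op_names(o))
    return names


program = Op("shape.get_extent",
             (Op("shape.from_extent_tensor", (Arg("extents"),)), Arg("idx")))
lowered = lower_get_extent(program)
env = {"extents": [2, 3], "idx": 1}
print(op_names(lowered))
print(evaluate(program, env), evaluate(lowered, env))
```

The toy evaluator treats `size_to_index`/`index_to_size` as identity casts. In the real conversion these casts appear to vanish once `!shape.size` is converted to `index`, which is what the test's CHECK lines expect: only `extract_element` remains.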
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
index 473da36a84ec..154cf6a9e1f7 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -26,3 +26,13 @@ def GetExtentShapeOfConversion : Pat<
(Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
[],
(addBenefit 10)>;
+def GetExtentFromExtentTensorConversion : Pattern<
+ (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
+ [
+ (Shape_SizeToIndexOp:$std_idx $idx),
+ (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
+ (Shape_IndexToSizeOp $std_result)
+ ],
+ [],
+ (addBenefit 10)>;
+
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index f9daadd03196..28ef190d09eb 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -143,3 +143,18 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
return %result : !shape.size
}
+// -----
+
+// Express `get_extent` as `std.extract_element` when it relies directly on the
+// outcome of a `from_extent_tensor` operation.
+// CHECK-LABEL: @get_extent_from_extent_tensor
+// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
+func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
+ %idx : !shape.size) -> !shape.size {
+ // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
+ // CHECK: return %[[RESULT]] : index
+ %shape = shape.from_extent_tensor %extents : tensor<?xindex>
+ %result = shape.get_extent %shape, %idx
+ return %result : !shape.size
+}
+