[Mlir-commits] [mlir] 8577a09 - [MLIR][Shape] Fix lowering of `shape.get_extent`
Frederik Gossen
llvmlistbot at llvm.org
Tue Jun 30 01:36:09 PDT 2020
Author: Frederik Gossen
Date: 2020-06-30T08:35:24Z
New Revision: 8577a090f5f04e18d72bb2dd387e60082e4da0ca
URL: https://github.com/llvm/llvm-project/commit/8577a090f5f04e18d72bb2dd387e60082e4da0ca
DIFF: https://github.com/llvm/llvm-project/commit/8577a090f5f04e18d72bb2dd387e60082e4da0ca.diff
LOG: [MLIR][Shape] Fix lowering of `shape.get_extent`
The declarative (TableGen) conversion patterns caused crashes in the ASan (AddressSanitizer) configuration.
Reimplementing the lowering as a non-declarative C++ conversion pattern circumvents this.
Differential Revision: https://reviews.llvm.org/D82797
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 5fd9be0bd73a..7ebcb397349d 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,6 +90,29 @@ class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
}
};
+class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
+ using OpConversionPattern<GetExtentOp>::OpConversionPattern;
+ // Lowers a GetExtentOp to a DimOp or an ExtractElementOp; always succeeds.
+ LogicalResult
+ matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ GetExtentOp::Adaptor transformed(operands);
+
+ // Derive shape extent directly from shape origin if possible.
+ // This circumvents the necessity to materialize the shape in memory.
+ if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+ rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+ transformed.dim());
+ return success();
+ }
+ // Otherwise extract the index-typed extent from the shape operand at `dim`.
+ rewriter.replaceOpWithNewOp<ExtractElementOp>(
+ op, rewriter.getIndexType(), transformed.shape(),
+ ValueRange{transformed.dim()});
+ return success();
+ }
+};
+
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
@@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns(
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
ConstSizeOpConverter,
+ GetExtentOpConverter,
RankOpConverter,
ShapeOfOpConversion>(ctx);
// clang-format on
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
index 154cf6a9e1f7..a1335487f5ab 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat<
(Shape_SizeToIndexOp $arg),
(replaceWithValue $arg)>;
-// Derive shape extent directly from shape origin if possible.
-// This circumvents the necessity to materialize the shape in memory.
-def GetExtentShapeOfConversion : Pat<
- (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
- (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
- [],
- (addBenefit 10)>;
-def GetExtentFromExtentTensorConversion : Pattern<
- (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
- [
- (Shape_SizeToIndexOp:$std_idx $idx),
- (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
- (Shape_IndexToSizeOp $std_result)
- ],
- [],
- (addBenefit 10)>;
-
More information about the Mlir-commits
mailing list