[Mlir-commits] [mlir] 8577a09 - [MLIR][Shape] Fix lowering of `shape.get_extent`

Frederik Gossen llvmlistbot at llvm.org
Tue Jun 30 01:36:09 PDT 2020


Author: Frederik Gossen
Date: 2020-06-30T08:35:24Z
New Revision: 8577a090f5f04e18d72bb2dd387e60082e4da0ca

URL: https://github.com/llvm/llvm-project/commit/8577a090f5f04e18d72bb2dd387e60082e4da0ca
DIFF: https://github.com/llvm/llvm-project/commit/8577a090f5f04e18d72bb2dd387e60082e4da0ca.diff

LOG: [MLIR][Shape] Fix lowering of `shape.get_extent`

The declarative (DRR) conversion patterns for `shape.get_extent` caused crashes in the ASan configuration.
Implementing the lowering as a non-declarative C++ `OpConversionPattern` avoids these crashes.

Differential Revision: https://reviews.llvm.org/D82797

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 5fd9be0bd73a..7ebcb397349d 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,6 +90,29 @@ class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
   }
 };
 
+class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { // Lowers shape.get_extent to a DimOp (fast path) or an ExtractElementOp (general case).
+  using OpConversionPattern<GetExtentOp>::OpConversionPattern; // NOTE(review): no `public:` here, unlike RankOpConverter; inherited ctors keep the base class's access, so construction still works.
+
+  LogicalResult // Always succeeds: both paths below replace `op`.
+  matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    GetExtentOp::Adaptor transformed(operands); // Named accessors (shape, dim) over the remapped operands handed in by the conversion driver.
+
+    // Fast path: if the shape was produced by shape_of, query the extent of
+    // the source tensor directly instead of materializing the shape value.
+    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) { // Looks at the original (unconverted) operand, not transformed.shape().
+      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+                                         transformed.dim()); // Extent of shape_of's argument at the (converted) index.
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<ExtractElementOp>( // General case: read the extent out of the converted shape value.
+        op, rewriter.getIndexType(), transformed.shape(),
+        ValueRange{transformed.dim()}); // NOTE(review): assumes the type converter turns the shape into a 1-D tensor of index — confirm.
+    return success();
+  }
+};
+
 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
 public:
   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
@@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      GetExtentOpConverter,
       RankOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
index 154cf6a9e1f7..a1335487f5ab 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat<
     (Shape_SizeToIndexOp $arg),
     (replaceWithValue $arg)>;
 
-// Derive shape extent directly from shape origin if possible.
-// This circumvents the necessity to materialize the shape in memory.
-def GetExtentShapeOfConversion : Pat< // Removed by this commit (DRR patterns crashed under ASan); replaced by the DimOp path of GetExtentOpConverter.
-    (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx), // Matches get_extent whose shape operand comes from shape_of.
-    (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))), // Rewrites to dim on shape_of's argument, converting size <-> index around it.
-    [],
-    (addBenefit 10)>;
-def GetExtentFromExtentTensorConversion : Pattern< // Removed by this commit (DRR patterns crashed under ASan); replaced by the ExtractElementOp path of GetExtentOpConverter.
-    (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx), // Matches get_extent whose shape comes from from_extent_tensor.
-    [
-      (Shape_SizeToIndexOp:$std_idx $idx), // Convert the size index to a std index.
-      (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)), // Wrap the single index in a ValueRange for the builder.
-      (Shape_IndexToSizeOp $std_result) // Last result replaces the matched op.
-    ],
-    [],
-    (addBenefit 10)>;
-


        


More information about the Mlir-commits mailing list