[Mlir-commits] [mlir] cd32044 - [mlir][shape] Lower Shape `ConstSizeOp` to Standard `ConstantOp`.

Alexander Belyaev llvmlistbot at llvm.org
Mon Jun 15 01:42:26 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-15T10:42:05+02:00
New Revision: cd320446f463da2c3ad0b71a587941939afc0788

URL: https://github.com/llvm/llvm-project/commit/cd320446f463da2c3ad0b71a587941939afc0788
DIFF: https://github.com/llvm/llvm-project/commit/cd320446f463da2c3ad0b71a587941939afc0788.diff

LOG: [mlir][shape] Lower Shape `ConstSizeOp` to Standard `ConstantOp`.

Differential Revision: https://reviews.llvm.org/D81735

Added:
    (none)

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed:
    (none)


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 81293a028489..35a5b14268e2 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -36,61 +36,72 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
 };
 
 class FromExtentTensorOpConversion
-    : public OpConversionPattern<shape::FromExtentTensorOp> {
+    : public OpConversionPattern<FromExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
+  using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+    FromExtentTensorOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
 
-class IndexToSizeOpConversion
-    : public OpConversionPattern<shape::IndexToSizeOp> {
+class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
 public:
-  using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
+  using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::IndexToSizeOpOperandAdaptor transformed(operands);
+    IndexToSizeOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.arg());
     return success();
   }
 };
 
-class SizeToIndexOpConversion
-    : public OpConversionPattern<shape::SizeToIndexOp> {
+class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
 public:
-  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+  using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::SizeToIndexOpOperandAdaptor transformed(operands);
+    SizeToIndexOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.arg());
     return success();
   }
 };
 
 class ToExtentTensorOpConversion
-    : public OpConversionPattern<shape::ToExtentTensorOp> {
+    : public OpConversionPattern<ToExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+    ToExtentTensorOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
 
+class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
+public:
+  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+                                                 op.value().getSExtValue());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -100,8 +111,8 @@ class ShapeTypeConverter : public TypeConverter {
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
 
-    addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
-    addConversion([ctx](shape::ShapeType type) {
+    addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
+    addConversion([ctx](ShapeType type) {
       return RankedTensorType::get({ShapedType::kDynamicSize},
                                    IndexType::get(ctx));
     });
@@ -111,9 +122,7 @@ class ShapeTypeConverter : public TypeConverter {
 /// Conversion pass.
 class ConvertShapeToStandardPass
     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
-
   void runOnOperation() override {
-
     // Setup type conversion.
     MLIRContext &ctx = getContext();
     ShapeTypeConverter typeConverter(&ctx);
@@ -146,6 +155,7 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
+      ConstSizeOpConverter,
       FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
       SizeToIndexOpConversion,

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 7c7098d76afa..1caf0051f37b 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -75,3 +75,14 @@ func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
   // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
   return
 }
+
+// -----
+
+// Convert `const_size` to `constant` op.
+// CHECK-LABEL: @size_const
+func @size_const() -> !shape.size {
+  %c1 = shape.const_size 1
+  return %c1 : !shape.size
+}
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: return %[[C1]] : index


        


More information about the Mlir-commits mailing list