[Mlir-commits] [mlir] cd32044 - [mlir][shape] Lower Shape `ConstSizeOp` to Standard `ConstantOp`.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Jun 15 01:42:26 PDT 2020
Author: Alexander Belyaev
Date: 2020-06-15T10:42:05+02:00
New Revision: cd320446f463da2c3ad0b71a587941939afc0788
URL: https://github.com/llvm/llvm-project/commit/cd320446f463da2c3ad0b71a587941939afc0788
DIFF: https://github.com/llvm/llvm-project/commit/cd320446f463da2c3ad0b71a587941939afc0788.diff
LOG: [mlir][shape] Lower Shape `ConstSizeOp` to Standard `ConstantOp`.
Differential Revision: https://reviews.llvm.org/D81735
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 81293a028489..35a5b14268e2 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -36,61 +36,72 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
};
+/// Lowers `shape.from_extent_tensor` by replacing the op with its `input`
+/// operand, taken from the rewriter-supplied (converted) `operands`. No new op
+/// is created: the type converter maps `!shape.shape` to `tensor<?xindex>`, so
+/// the extent tensor itself stands in for the lowered shape value.
+/// NOTE(review): assumes `input` already has type `tensor<?xindex>` — confirm
+/// against the op definition.
class FromExtentTensorOpConversion
- : public OpConversionPattern<shape::FromExtentTensorOp> {
+ : public OpConversionPattern<FromExtentTensorOp> {
public:
- using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
+ using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
+ matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+ FromExtentTensorOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
};
-class IndexToSizeOpConversion
- : public OpConversionPattern<shape::IndexToSizeOp> {
+/// Lowers `shape.index_to_size` by replacing the op with its `arg` operand.
+/// The type converter maps `!shape.size` to `index`, so no new op is needed and
+/// the converted operand is forwarded unchanged.
+class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
public:
- using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
+ using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
+ matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- shape::IndexToSizeOpOperandAdaptor transformed(operands);
+ IndexToSizeOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.arg());
return success();
}
};
-class SizeToIndexOpConversion
- : public OpConversionPattern<shape::SizeToIndexOp> {
+/// Lowers `shape.size_to_index` by replacing the op with its `arg` operand.
+/// Once `!shape.size` is converted to `index` by the type converter, the
+/// operand already has the result type, so it is forwarded unchanged.
+class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
public:
- using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+ using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+ matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- shape::SizeToIndexOpOperandAdaptor transformed(operands);
+ SizeToIndexOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.arg());
return success();
}
};
+/// Lowers `shape.to_extent_tensor` by replacing the op with its `input`
+/// operand. The type converter maps `!shape.shape` to `tensor<?xindex>`, so the
+/// converted operand already is the extent tensor and is forwarded unchanged.
class ToExtentTensorOpConversion
- : public OpConversionPattern<shape::ToExtentTensorOp> {
+ : public OpConversionPattern<ToExtentTensorOp> {
public:
- using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+ using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+ matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+ ToExtentTensorOpOperandAdaptor transformed(operands);
rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
};
+/// Lowers `shape.const_size` to a standard `ConstantIndexOp` holding the same
+/// value, read via `op.value().getSExtValue()` (sign-extended to int64_t).
+/// NOTE(review): this pattern is named `...Converter`, whereas the sibling
+/// patterns in this file use the `...Conversion` suffix; consider renaming it
+/// (together with its registration in populateShapeToStandardConversionPatterns)
+/// for consistency.
+class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
+public:
+ using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // `operands` is unused: the constant is read from the op itself rather
+ // than from any converted operand.
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+ op.value().getSExtValue());
+ return success();
+ }
+};
+
/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
@@ -100,8 +111,8 @@ class ShapeTypeConverter : public TypeConverter {
// Add default pass-through conversion.
addConversion([&](Type type) { return type; });
- addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
- addConversion([ctx](shape::ShapeType type) {
+ addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
+ addConversion([ctx](ShapeType type) {
return RankedTensorType::get({ShapedType::kDynamicSize},
IndexType::get(ctx));
});
@@ -111,9 +122,7 @@ class ShapeTypeConverter : public TypeConverter {
/// Conversion pass.
class ConvertShapeToStandardPass
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
-
void runOnOperation() override {
-
// Setup type conversion.
MLIRContext &ctx = getContext();
ShapeTypeConverter typeConverter(&ctx);
@@ -146,6 +155,7 @@ void mlir::populateShapeToStandardConversionPatterns(
patterns.insert<
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
+ ConstSizeOpConverter,
FromExtentTensorOpConversion,
IndexToSizeOpConversion,
SizeToIndexOpConversion,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 7c7098d76afa..1caf0051f37b 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -75,3 +75,14 @@ func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
// CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
return
}
+
+// -----
+
+// Convert `shape.const_size` to a standard `constant` of type `index`; the
+// `!shape.size` return value is lowered to `index` as well.
+// CHECK-LABEL: @size_const
+func @size_const() -> !shape.size {
+ %c1 = shape.const_size 1
+ return %c1 : !shape.size
+}
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: return %[[C1]] : index
More information about the Mlir-commits mailing list