[Mlir-commits] [mlir] 4baf18d - [MLIR][Shape] Clean up shape to standard lowering
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 01:56:07 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T08:55:50Z
New Revision: 4baf18dba26c387ca673f0ed97541ba476480688
URL: https://github.com/llvm/llvm-project/commit/4baf18dba26c387ca673f0ed97541ba476480688
DIFF: https://github.com/llvm/llvm-project/commit/4baf18dba26c387ca673f0ed97541ba476480688.diff
LOG: [MLIR][Shape] Clean up shape to standard lowering
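Keep only class declarations inside anonymous namespaces; move the member function definitions (`matchAndRewrite`, `runOnOperation`) out of line to file scope.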
Differential Revision: https://reviews.llvm.org/D84424
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index ae3874d0cb4d..5e3a60d74506 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -18,27 +18,34 @@ using namespace mlir;
using namespace mlir::shape;
namespace {
-
/// Generated conversion patterns.
#include "ShapeToStandardPatterns.inc"
+} // namespace
/// Conversion patterns.
+namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
using OpConversionPattern<AnyOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- AnyOp::Adaptor transformed(operands);
-
- // Replace `any` with its first operand.
- // Any operand would be a valid substitution.
- rewriter.replaceOp(op, {transformed.inputs().front()});
- return success();
- }
+ ConversionPatternRewriter &rewriter) const override;
};
+} // namespace
+
+LogicalResult
+AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ AnyOp::Adaptor transformed(operands);
+
+ // Replace `any` with its first operand.
+ // Any operand would be a valid substitution.
+ rewriter.replaceOp(op, {transformed.inputs().front()});
+ return success();
+}
+namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
@@ -53,98 +60,122 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
return success();
}
};
+} // namespace
+namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- ShapeOfOp::Adaptor transformed(operands);
- auto loc = op.getLoc();
- auto tensorVal = transformed.arg();
- auto tensorTy = tensorVal.getType();
-
- // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
- // found in the corresponding pass.
- if (tensorTy.isa<UnrankedTensorType>())
- return failure();
-
- // Build values for individual dimensions.
- SmallVector<Value, 8> dimValues;
- auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
- int64_t rank = rankedTensorTy.getRank();
- for (int64_t i = 0; i < rank; i++) {
- if (rankedTensorTy.isDynamicDim(i)) {
- auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
- dimValues.push_back(dimVal);
- } else {
- int64_t dim = rankedTensorTy.getDimSize(i);
- auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
- dimValues.push_back(dimVal);
- }
- }
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
- // Materialize extent tensor.
- Value staticExtentTensor =
- rewriter.create<TensorFromElementsOp>(loc, dimValues);
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
- op.getType());
- return success();
+LogicalResult ShapeOfOpConversion::matchAndRewrite(
+ ShapeOfOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ ShapeOfOp::Adaptor transformed(operands);
+ auto loc = op.getLoc();
+ auto tensorVal = transformed.arg();
+ auto tensorTy = tensorVal.getType();
+
+ // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
+ // found in the corresponding pass.
+ if (tensorTy.isa<UnrankedTensorType>())
+ return failure();
+
+ // Build values for individual dimensions.
+ SmallVector<Value, 8> dimValues;
+ auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+ int64_t rank = rankedTensorTy.getRank();
+ for (int64_t i = 0; i < rank; i++) {
+ if (rankedTensorTy.isDynamicDim(i)) {
+ auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+ dimValues.push_back(dimVal);
+ } else {
+ int64_t dim = rankedTensorTy.getDimSize(i);
+ auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+ dimValues.push_back(dimVal);
+ }
}
-};
+ // Materialize extent tensor.
+ Value staticExtentTensor =
+ rewriter.create<TensorFromElementsOp>(loc, dimValues);
+ rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+ op.getType());
+ return success();
+}
+
+namespace {
class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
public:
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
- op.value().getSExtValue());
- return success();
- }
+ ConversionPatternRewriter &rewriter) const override;
};
+} // namespace
+LogicalResult ConstSizeOpConverter::matchAndRewrite(
+ ConstSizeOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+ op.value().getSExtValue());
+ return success();
+}
+
+namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- GetExtentOp::Adaptor transformed(operands);
-
- // Derive shape extent directly from shape origin if possible.
- // This circumvents the necessity to materialize the shape in memory.
- if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
- rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
- transformed.dim());
- return success();
- }
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
- rewriter.replaceOpWithNewOp<ExtractElementOp>(
- op, rewriter.getIndexType(), transformed.shape(),
- ValueRange{transformed.dim()});
+LogicalResult GetExtentOpConverter::matchAndRewrite(
+ GetExtentOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ GetExtentOp::Adaptor transformed(operands);
+
+ // Derive shape extent directly from shape origin if possible.
+ // This circumvents the necessity to materialize the shape in memory.
+ if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+ rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
return success();
}
-};
+ rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
+ transformed.shape(),
+ ValueRange{transformed.dim()});
+ return success();
+}
+
+namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- shape::RankOp::Adaptor transformed(operands);
- rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
- 0);
- return success();
- }
+ ConversionPatternRewriter &rewriter) const override;
};
+} // namespace
+
+LogicalResult
+RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ shape::RankOp::Adaptor transformed(operands);
+ rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(), 0);
+ return success();
+}
+namespace {
/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
@@ -161,39 +192,42 @@ class ShapeTypeConverter : public TypeConverter {
});
}
};
+} // namespace
+namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
- void runOnOperation() override {
- // Setup type conversion.
- MLIRContext &ctx = getContext();
- ShapeTypeConverter typeConverter(&ctx);
-
- // Setup target legality.
- ConversionTarget target(ctx);
- target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
- target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- return typeConverter.isSignatureLegal(op.getType()) &&
- typeConverter.isLegal(&op.getBody());
- });
-
- // Setup conversion patterns.
- OwningRewritePatternList patterns;
- populateShapeToStandardConversionPatterns(patterns, &ctx);
- populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
-
- // Apply conversion.
- auto module = getOperation();
- if (failed(applyFullConversion(module, target, patterns)))
- signalPassFailure();
- }
+ void runOnOperation() override;
};
-
} // namespace
+void ConvertShapeToStandardPass::runOnOperation() {
+ // Setup type conversion.
+ MLIRContext &ctx = getContext();
+ ShapeTypeConverter typeConverter(&ctx);
+
+ // Setup target legality.
+ ConversionTarget target(ctx);
+ target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+ target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getType()) &&
+ typeConverter.isLegal(&op.getBody());
+ });
+
+ // Setup conversion patterns.
+ OwningRewritePatternList patterns;
+ populateShapeToStandardConversionPatterns(patterns, &ctx);
+ populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
+
+ // Apply conversion.
+ auto module = getOperation();
+ if (failed(applyFullConversion(module, target, patterns)))
+ signalPassFailure();
+}
+
void mlir::populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
populateWithGenerated(ctx, &patterns);
More information about the Mlir-commits mailing list