[Mlir-commits] [mlir] 4baf18d - [MLIR][Shape] Clean up shape to standard lowering

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 01:56:07 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T08:55:50Z
New Revision: 4baf18dba26c387ca673f0ed97541ba476480688

URL: https://github.com/llvm/llvm-project/commit/4baf18dba26c387ca673f0ed97541ba476480688
DIFF: https://github.com/llvm/llvm-project/commit/4baf18dba26c387ca673f0ed97541ba476480688.diff

LOG: [MLIR][Shape] Clean up shape to standard lowering

Put only class declarations in anonymous namespaces.

Differential Revision: https://reviews.llvm.org/D84424
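
For context, a minimal sketch of the convention this change applies (illustrative only; the class and function names below are hypothetical and not taken from the MLIR sources): keep only the class declaration inside the anonymous namespace and define its member functions after the namespace block. The class still gets internal linkage, while the namespace block itself stays small.

    namespace {
    // Only the declaration lives in the anonymous namespace.
    class Doubler {
    public:
      int apply(int x) const;
    };
    } // namespace

    // Out-of-line definition follows the (now closed) anonymous namespace.
    int Doubler::apply(int x) const { return 2 * x; }

    int main() { return Doubler().apply(21) == 42 ? 0 : 1; }

This follows the LLVM coding standard's guidance to keep anonymous namespaces as small as possible and to use them only for class declarations.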

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index ae3874d0cb4d..5e3a60d74506 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -18,27 +18,34 @@ using namespace mlir;
 using namespace mlir::shape;
 
 namespace {
-
 /// Generated conversion patterns.
 #include "ShapeToStandardPatterns.inc"
+} // namespace
 
 /// Conversion patterns.
+namespace {
 class AnyOpConversion : public OpConversionPattern<AnyOp> {
 public:
   using OpConversionPattern<AnyOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    AnyOp::Adaptor transformed(operands);
-
-    // Replace `any` with its first operand.
-    // Any operand would be a valid substitution.
-    rewriter.replaceOp(op, {transformed.inputs().front()});
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
+} // namespace
+
+LogicalResult
+AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  AnyOp::Adaptor transformed(operands);
+
+  // Replace `any` with its first operand.
+  // Any operand would be a valid substitution.
+  rewriter.replaceOp(op, {transformed.inputs().front()});
+  return success();
+}
 
+namespace {
 template <typename SrcOpTy, typename DstOpTy>
 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
 public:
@@ -53,98 +60,122 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
     return success();
   }
 };
+} // namespace
 
+namespace {
 class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
 public:
   using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    ShapeOfOp::Adaptor transformed(operands);
-    auto loc = op.getLoc();
-    auto tensorVal = transformed.arg();
-    auto tensorTy = tensorVal.getType();
-
-    // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
-    // found in the corresponding pass.
-    if (tensorTy.isa<UnrankedTensorType>())
-      return failure();
-
-    // Build values for individual dimensions.
-    SmallVector<Value, 8> dimValues;
-    auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
-    int64_t rank = rankedTensorTy.getRank();
-    for (int64_t i = 0; i < rank; i++) {
-      if (rankedTensorTy.isDynamicDim(i)) {
-        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
-        dimValues.push_back(dimVal);
-      } else {
-        int64_t dim = rankedTensorTy.getDimSize(i);
-        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
-        dimValues.push_back(dimVal);
-      }
-    }
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
 
-    // Materialize extent tensor.
-    Value staticExtentTensor =
-        rewriter.create<TensorFromElementsOp>(loc, dimValues);
-    rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
-                                              op.getType());
-    return success();
+LogicalResult ShapeOfOpConversion::matchAndRewrite(
+    ShapeOfOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto loc = op.getLoc();
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
+  // found in the corresponding pass.
+  if (tensorTy.isa<UnrankedTensorType>())
+    return failure();
+
+  // Build values for individual dimensions.
+  SmallVector<Value, 8> dimValues;
+  auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+  int64_t rank = rankedTensorTy.getRank();
+  for (int64_t i = 0; i < rank; i++) {
+    if (rankedTensorTy.isDynamicDim(i)) {
+      auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+      dimValues.push_back(dimVal);
+    } else {
+      int64_t dim = rankedTensorTy.getDimSize(i);
+      auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+      dimValues.push_back(dimVal);
+    }
   }
-};
 
+  // Materialize extent tensor.
+  Value staticExtentTensor =
+      rewriter.create<TensorFromElementsOp>(loc, dimValues);
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+                                            op.getType());
+  return success();
+}
+
+namespace {
 class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
 public:
   using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
-                                                 op.value().getSExtValue());
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
+} // namespace
 
+LogicalResult ConstSizeOpConverter::matchAndRewrite(
+    ConstSizeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+                                               op.value().getSExtValue());
+  return success();
+}
+
+namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    GetExtentOp::Adaptor transformed(operands);
-
-    // Derive shape extent directly from shape origin if possible.
-    // This circumvents the necessity to materialize the shape in memory.
-    if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
-      rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
-                                         transformed.dim());
-      return success();
-    }
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
 
-    rewriter.replaceOpWithNewOp<ExtractElementOp>(
-        op, rewriter.getIndexType(), transformed.shape(),
-        ValueRange{transformed.dim()});
+LogicalResult GetExtentOpConverter::matchAndRewrite(
+    GetExtentOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  GetExtentOp::Adaptor transformed(operands);
+
+  // Derive shape extent directly from shape origin if possible.
+  // This circumvents the necessity to materialize the shape in memory.
+  if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
+    rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
     return success();
   }
-};
 
+  rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
+                                                transformed.shape(),
+                                                ValueRange{transformed.dim()});
+  return success();
+}
+
+namespace {
 class RankOpConverter : public OpConversionPattern<shape::RankOp> {
 public:
   using OpConversionPattern<shape::RankOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    shape::RankOp::Adaptor transformed(operands);
-    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
-                                       0);
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
+} // namespace
+
+LogicalResult
+RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
+  shape::RankOp::Adaptor transformed(operands);
+  rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(), 0);
+  return success();
+}
 
+namespace {
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -161,39 +192,42 @@ class ShapeTypeConverter : public TypeConverter {
     });
   }
 };
+} // namespace
 
+namespace {
 /// Conversion pass.
 class ConvertShapeToStandardPass
     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
 
-  void runOnOperation() override {
-    // Setup type conversion.
-    MLIRContext &ctx = getContext();
-    ShapeTypeConverter typeConverter(&ctx);
-
-    // Setup target legality.
-    ConversionTarget target(ctx);
-    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getType()) &&
-             typeConverter.isLegal(&op.getBody());
-    });
-
-    // Setup conversion patterns.
-    OwningRewritePatternList patterns;
-    populateShapeToStandardConversionPatterns(patterns, &ctx);
-    populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
-
-    // Apply conversion.
-    auto module = getOperation();
-    if (failed(applyFullConversion(module, target, patterns)))
-      signalPassFailure();
-  }
+  void runOnOperation() override;
 };
-
 } // namespace
 
+void ConvertShapeToStandardPass::runOnOperation() {
+  // Setup type conversion.
+  MLIRContext &ctx = getContext();
+  ShapeTypeConverter typeConverter(&ctx);
+
+  // Setup target legality.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return typeConverter.isSignatureLegal(op.getType()) &&
+           typeConverter.isLegal(&op.getBody());
+  });
+
+  // Setup conversion patterns.
+  OwningRewritePatternList patterns;
+  populateShapeToStandardConversionPatterns(patterns, &ctx);
+  populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
+
+  // Apply conversion.
+  auto module = getOperation();
+  if (failed(applyFullConversion(module, target, patterns)))
+    signalPassFailure();
+}
+
 void mlir::populateShapeToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   populateWithGenerated(ctx, &patterns);

