[Mlir-commits] [mlir] b6b9d3e - [MLIR][Shape] Remove type conversion from lowering to standard
Frederik Gossen
llvmlistbot at llvm.org
Wed Jul 29 03:48:22 PDT 2020
Author: Frederik Gossen
Date: 2020-07-29T10:48:05Z
New Revision: b6b9d3ea85cc158a5230b4b75147d299bfc372df
URL: https://github.com/llvm/llvm-project/commit/b6b9d3ea85cc158a5230b4b75147d299bfc372df
DIFF: https://github.com/llvm/llvm-project/commit/b6b9d3ea85cc158a5230b4b75147d299bfc372df.diff
LOG: [MLIR][Shape] Remove type conversion from lowering to standard
Because the lowering now operates directly on indices and extent tensors, the
type conversion is no longer needed for the supported cases.
Differential Revision: https://reviews.llvm.org/D84442
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index efeaa18e17c1..4deaa8cd2df3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -219,25 +219,6 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
return success();
}
-namespace {
-/// Type conversions.
-class ShapeTypeConverter : public TypeConverter {
-public:
- using TypeConverter::convertType;
-
- ShapeTypeConverter(MLIRContext *ctx) {
- // Add default pass-through conversion.
- addConversion([&](Type type) { return type; });
-
- addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
- addConversion([ctx](ShapeType type) {
- return RankedTensorType::get({ShapedType::kDynamicSize},
- IndexType::get(ctx));
- });
- }
-};
-} // namespace
-
namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
@@ -248,23 +229,15 @@ class ConvertShapeToStandardPass
} // namespace
void ConvertShapeToStandardPass::runOnOperation() {
- // Setup type conversion.
- MLIRContext &ctx = getContext();
- ShapeTypeConverter typeConverter(&ctx);
-
// Setup target legality.
+ MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
- target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
- target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- return typeConverter.isSignatureLegal(op.getType()) &&
- typeConverter.isLegal(&op.getBody());
- });
+ target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.
OwningRewritePatternList patterns;
populateShapeToStandardConversionPatterns(patterns, &ctx);
- populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
// Apply conversion.
auto module = getOperation();
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index b94b24599351..0e30cc2bdf56 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -1,40 +1,11 @@
// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s
-// Convert `size` to `index` type.
-// CHECK-LABEL: @size_id
-// CHECK-SAME: (%[[SIZE:.*]]: index)
-func @size_id(%size : !shape.size) -> !shape.size {
- // CHECK: return %[[SIZE]] : index
- return %size : !shape.size
-}
-
-// -----
-
-// Convert `shape` to `tensor<?xindex>` type.
-// CHECK-LABEL: @shape_id
-// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>)
-func @shape_id(%shape : !shape.shape) -> !shape.shape {
- // CHECK: return %[[SHAPE]] : tensor<?xindex>
- return %shape : !shape.shape
-}
-
-// -----
-
-// Lower binary ops.
-// CHECK-LABEL: @binary_ops
-// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
-func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
- // CHECK: addi %[[LHS]], %[[RHS]] : index
- %sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
- return
-}
-
-// -----
-
// Lower binary ops.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func @binary_ops(%lhs : index, %rhs : index) {
+ // CHECK: addi %[[LHS]], %[[RHS]] : index
+ %sum = shape.add %lhs, %rhs : index, index -> index
// CHECK: muli %[[LHS]], %[[RHS]] : index
%product = shape.mul %lhs, %rhs : index, index -> index
return
More information about the Mlir-commits mailing list