[Mlir-commits] [mlir] 3713314 - [MLIR] Shape to standard dialect lowering
Frederik Gossen
llvmlistbot at llvm.org
Wed Jun 3 09:17:29 PDT 2020
Author: Frederik Gossen
Date: 2020-06-03T16:17:03Z
New Revision: 3713314bfae3dc9a793c152e8a698a2ca1444f6c
URL: https://github.com/llvm/llvm-project/commit/3713314bfae3dc9a793c152e8a698a2ca1444f6c
DIFF: https://github.com/llvm/llvm-project/commit/3713314bfae3dc9a793c152e8a698a2ca1444f6c.diff
LOG: [MLIR] Shape to standard dialect lowering
Add a new pass to lower operations from the `shape` to the `std` dialect.
The conversion applies only to the `size_to_index` and `index_to_size`
operations and affected types.
Other patterns will be added as needed.
Differential Revision: https://reviews.llvm.org/D81091
Added:
mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f83e913e484f..4bcfd8d34aa2 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,16 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
let constructor = "mlir::createParallelLoopToGpuPass()";
}
+//===----------------------------------------------------------------------===//
+// ShapeToStandard
+//===----------------------------------------------------------------------===//
+// Pass operating on `ModuleOp`; its constructor is declared in ShapeToStandard.h.
+def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
+ let summary = "Convert operations from the shape dialect into the standard "
+ "dialect";
+ let constructor = "mlir::createConvertShapeToStandardPass()";
+}
+
//===----------------------------------------------------------------------===//
// StandardToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
new file mode 100644
index 000000000000..6b5ce8bfd274
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -0,0 +1,28 @@
+//===- ShapeToStandard.h - Conversion utils from shape to std dialect -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
+#define MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+/// Populates `patterns` with the shape-to-standard conversion patterns.
+void populateShapeToStandardConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+/// Creates the `ModuleOp` pass registered as `convert-shape-to-std`.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 95f9ce1c4e1f..c0c79d1c69ef 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -25,6 +25,7 @@
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8b70e6523106..c99dceec31cc 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToStandard)
+add_subdirectory(ShapeToStandard)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
new file mode 100644
index 000000000000..8750c331859e
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRShapeToStandard
+ ShapeToStandard.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToStandard
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+ # NOTE(review): the .cpp includes no EDSC header but uses StandardOpsDialect; MLIREDSC may be unneeded and the std dialect library is not listed - confirm.
+ LINK_LIBS PUBLIC
+ MLIREDSC
+ MLIRIR
+ MLIRShape
+ MLIRPass
+ MLIRSCF
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
new file mode 100644
index 000000000000..0083ebdfa21b
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -0,0 +1,106 @@
+//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace {
+
+/// Conversion pattern that replaces `shape.size_to_index` with its operand.
+class SizeToIndexOpConversion
+ : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+ using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+ /// Replaces `op` with the `arg` value read from `operands`; always succeeds.
+ LogicalResult
+ matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ shape::SizeToIndexOpOperandAdaptor transformed(operands);
+ rewriter.replaceOp(op.getOperation(), transformed.arg());
+ return success();
+ }
+};
+
+class IndexToSizeOpConversion
+ : public OpConversionPattern<shape::IndexToSizeOp> {
+public:
+ using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
+ /// Replaces `shape.index_to_size` with its `arg` operand; always succeeds.
+ LogicalResult
+ matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ shape::IndexToSizeOpOperandAdaptor transformed(operands);
+ rewriter.replaceOp(op.getOperation(), transformed.arg());
+ return success();
+ }
+};
+
+/// Type converter that maps `shape::SizeType` to `IndexType`.
+class ShapeTypeConverter : public TypeConverter {
+public:
+  using TypeConverter::convertType;
+  /// Registers a pass-through fallback and the `shape::SizeType` rule.
+  explicit ShapeTypeConverter(MLIRContext *ctx) {
+    // NOTE(review): presumably the later-added rule is tried first; confirm.
+    addConversion([](Type type) { return type; });
+    addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
+  }
+};
+
+/// Module pass that applies a full conversion from the shape to the std dialect.
+class ConvertShapeToStandardPass
+ : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
+ /// Runs the conversion on `getOperation()`; signals pass failure if it fails.
+ void runOnOperation() override {
+
+ // Set up type conversion via the file-local `ShapeTypeConverter`.
+ MLIRContext &ctx = getContext();
+ ShapeTypeConverter typeConverter(&ctx);
+
+ // Set up target legality: scf/std dialects and module/return ops are legal.
+ ConversionTarget target(ctx);
+ target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+ target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getType());
+ }); // A `FuncOp` is legal only when the converter accepts its signature.
+
+ // Collect the shape patterns and the `FuncOp` signature conversion pattern.
+ OwningRewritePatternList patterns;
+ populateShapeToStandardConversionPatterns(patterns, &ctx);
+ populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);
+
+ // Apply a full conversion; on failure, mark the pass as failed.
+ auto module = getOperation();
+ if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+/// Appends the `index_to_size` and `size_to_index` conversion patterns of
+/// this file to `patterns`; each pattern is constructed with `ctx`.
+void populateShapeToStandardConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  // One call per pattern keeps later additions to single-line changes.
+  patterns.insert<IndexToSizeOpConversion>(ctx);
+  patterns.insert<SizeToIndexOpConversion>(ctx);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass() {
+ return std::make_unique<ConvertShapeToStandardPass>(); // Factory named as the pass constructor in Passes.td.
+}
+
+} // namespace mlir
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
new file mode 100644
index 000000000000..c27b408dfef1
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s --dump-input-on-failure
+
+// `!shape.size` function arguments and returned values are converted to `index`.
+// CHECK-LABEL: @size_id
+// CHECK-SAME: (%[[SIZE:.*]]: index)
+func @size_id(%size : !shape.size) -> !shape.size {
+ // CHECK: return %[[SIZE]] : index
+ return %size : !shape.size
+}
+
+// -----
+
+// `shape.size_to_index` is lowered to a no-op: its operand is returned directly.
+// CHECK-LABEL: @size_to_index
+// CHECK-SAME: (%[[SIZE:.*]]: index) -> index
+func @size_to_index(%size : !shape.size) -> index {
+ // CHECK-NEXT: return %[[SIZE]] : index
+ %index = shape.size_to_index %size
+ return %index : index
+}
+
+// -----
+
+// `shape.index_to_size` is lowered to a no-op: its operand is returned directly.
+// CHECK-LABEL: @index_to_size
+// CHECK-SAME: (%[[INDEX:.*]]: index) -> index
+func @index_to_size(%index : index) -> !shape.size {
+ // CHECK-NEXT: return %[[INDEX]] : index
+ %size = shape.index_to_size %index
+ return %size : !shape.size
+}
More information about the Mlir-commits
mailing list