[Mlir-commits] [mlir] 3713314 - [MLIR] Shape to standard dialect lowering

Frederik Gossen llvmlistbot at llvm.org
Wed Jun 3 09:17:29 PDT 2020


Author: Frederik Gossen
Date: 2020-06-03T16:17:03Z
New Revision: 3713314bfae3dc9a793c152e8a698a2ca1444f6c

URL: https://github.com/llvm/llvm-project/commit/3713314bfae3dc9a793c152e8a698a2ca1444f6c
DIFF: https://github.com/llvm/llvm-project/commit/3713314bfae3dc9a793c152e8a698a2ca1444f6c.diff

LOG: [MLIR] Shape to standard dialect lowering

Add a new pass to lower operations from the `shape` to the `std` dialect.
The conversion applies only to the `size_to_index` and `index_to_size`
operations and affected types.
Other patterns will be added as needed.

Differential Revision: https://reviews.llvm.org/D81091

Added: 
    mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
    mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f83e913e484f..4bcfd8d34aa2 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,16 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
   let constructor = "mlir::createParallelLoopToGpuPass()";
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeToStandard
+//===----------------------------------------------------------------------===//
+
+def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
+  let summary = "Convert operations from the shape dialect into the standard "
+                "dialect";
+  let description = [{
+    Lowers operations of the `shape` dialect to the `std` dialect. Currently
+    this covers `shape.size_to_index` and `shape.index_to_size`, which both
+    become no-ops, and converts the `!shape.size` type to `index`, including
+    in function signatures.
+  }];
+  let constructor = "mlir::createConvertShapeToStandardPass()";
+}
+
 //===----------------------------------------------------------------------===//
 // StandardToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
new file mode 100644
index 000000000000..6b5ce8bfd274
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -0,0 +1,28 @@
+//===- ShapeToStandard.h - Conversion utils from shape to std dialect -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
+#define MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+
+/// Appends to `patterns` the conversion patterns that lower operations of the
+/// `shape` dialect to the `std` dialect (currently `size_to_index` and
+/// `index_to_size`).
+void populateShapeToStandardConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+/// Creates a pass, anchored on a `ModuleOp`, that converts `shape` dialect
+/// operations and the `!shape.size` type into the `std` dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 95f9ce1c4e1f..c0c79d1c69ef 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -25,6 +25,7 @@
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8b70e6523106..c99dceec31cc 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToStandard)
+add_subdirectory(ShapeToStandard)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
 add_subdirectory(VectorToLLVM)

diff  --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
new file mode 100644
index 000000000000..8750c331859e
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Library implementing the `convert-shape-to-std` conversion pass.
+# NOTE(review): ShapeToStandard.cpp uses std dialect symbols directly
+# (StandardOpsDialect, ReturnOp) but MLIRStandardOps is not listed in
+# LINK_LIBS; presumably it arrives transitively (e.g. via MLIRSCF) -- verify.
+add_mlir_conversion_library(MLIRShapeToStandard
+  ShapeToStandard.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToStandard
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIREDSC
+  MLIRIR
+  MLIRShape
+  MLIRPass
+  MLIRSCF
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
new file mode 100644
index 000000000000..0083ebdfa21b
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -0,0 +1,106 @@
+//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace {
+
+/// Conversion patterns.
+/// Lowers `shape.size_to_index` to a no-op. Once `!shape.size` has been
+/// converted to `index`, the converted operand already is the desired result.
+class SizeToIndexOpConversion
+    : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Forward the already type-converted `arg` operand in place of the result.
+    Value convertedArg = shape::SizeToIndexOpOperandAdaptor(operands).arg();
+    rewriter.replaceOp(op.getOperation(), convertedArg);
+    return success();
+  }
+};
+
+/// Lowers `shape.index_to_size` to a no-op. After type conversion both the
+/// operand and the result are `index`, so uses of the result are rewired to
+/// the converted operand.
+class IndexToSizeOpConversion
+    : public OpConversionPattern<shape::IndexToSizeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::IndexToSizeOpOperandAdaptor adaptor(operands);
+    Operation *sizeOp = op.getOperation();
+    rewriter.replaceOp(sizeOp, adaptor.arg());
+    return success();
+  }
+};
+
+/// Type conversions.
+/// Type converter for the shape-to-std lowering: `!shape.size` becomes
+/// `index`, every other type is passed through unchanged.
+class ShapeTypeConverter : public TypeConverter {
+public:
+  using TypeConverter::convertType;
+
+  explicit ShapeTypeConverter(MLIRContext *ctx) {
+    // Default pass-through conversion. Capture nothing: the callback is stored
+    // in the converter and outlives this constructor.
+    addConversion([](Type type) { return type; });
+    // MLIR's TypeConverter tries the most recently added conversion first, so
+    // this more specific rule takes precedence over the pass-through above.
+    addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
+  }
+};
+
+/// Conversion pass.
+/// Module pass that lowers `shape` dialect operations, and the `!shape.size`
+/// type wherever it appears in function signatures, to the `std` dialect.
+class ConvertShapeToStandardPass
+    : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    ShapeTypeConverter converter(&context);
+
+    // Patterns for the shape ops themselves, plus rewriting of function
+    // signatures through the type converter.
+    OwningRewritePatternList patterns;
+    populateShapeToStandardConversionPatterns(patterns, &context);
+    populateFuncOpTypeConversionPattern(patterns, &context, converter);
+
+    // `scf` and `std` ops and the module structure are always legal; a
+    // function is legal only once its signature no longer needs conversion.
+    ConversionTarget target(context);
+    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
+    target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
+    target.addDynamicallyLegalOp<FuncOp>([&converter](FuncOp func) {
+      return converter.isSignatureLegal(func.getType());
+    });
+
+    // Full conversion: every op must end up legal, otherwise the pass fails.
+    if (failed(applyFullConversion(getOperation(), target, patterns,
+                                   &converter)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Adds to `patterns` the patterns that lower `shape` dialect operations to
+/// `std`. Only `index_to_size` and `size_to_index` are covered so far.
+void populateShapeToStandardConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<IndexToSizeOpConversion, SizeToIndexOpConversion>(ctx);
+}
+
+/// Factory named as the pass constructor in Passes.td: returns a fresh
+/// instance of the shape-to-std conversion pass.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass() {
+  std::unique_ptr<OperationPass<ModuleOp>> pass =
+      std::make_unique<ConvertShapeToStandardPass>();
+  return pass;
+}
+
+} // namespace mlir

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
new file mode 100644
index 000000000000..c27b408dfef1
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s --dump-input-on-failure
+
+// Convert `size` to `index` type, both in the function signature (argument
+// and result) and in uses of the argument.
+// CHECK-LABEL: @size_id
+// CHECK-SAME: (%[[SIZE:.*]]: index) -> index
+func @size_id(%size : !shape.size) -> !shape.size {
+  // CHECK-NEXT: return %[[SIZE]] : index
+  return %size : !shape.size
+}
+
+// -----
+
+// Lower `size_to_index` conversion to no-op.
+// The `!shape.size` argument is converted to `index`, so the function simply
+// returns its argument and the `shape.size_to_index` op disappears.
+// CHECK-LABEL: @size_to_index
+// CHECK-SAME: (%[[SIZE:.*]]: index) -> index
+func @size_to_index(%size : !shape.size) -> index {
+  // CHECK-NEXT: return %[[SIZE]] : index
+  %index = shape.size_to_index %size
+  return %index : index
+}
+
+// -----
+
+// Lower `index_to_size` conversion to no-op.
+// The `!shape.size` result type is converted to `index`, so the function
+// returns its `index` argument directly.
+// CHECK-LABEL: @index_to_size
+// CHECK-SAME: (%[[INDEX:.*]]: index) -> index
+func @index_to_size(%index : index) -> !shape.size {
+  // CHECK-NEXT: return %[[INDEX]] : index
+  %size = shape.index_to_size %index
+  return %size : !shape.size
+}


        


More information about the Mlir-commits mailing list