[Mlir-commits] [mlir] cd1212d - [mlir] Introduced CallOp Dialect Conversion

River Riddle llvmlistbot at llvm.org
Wed Mar 18 20:07:56 PDT 2020


Author: Rob Suderman
Date: 2020-03-18T20:07:38-07:00
New Revision: cd1212deffbe110c94d1e6e80f95633653947a8a

URL: https://github.com/llvm/llvm-project/commit/cd1212deffbe110c94d1e6e80f95633653947a8a
DIFF: https://github.com/llvm/llvm-project/commit/cd1212deffbe110c94d1e6e80f95633653947a8a.diff

LOG: [mlir] Introduced CallOp Dialect Conversion

Summary:
Utility to perform CallOp dialect conversion, specifically to handle cases where
an argument or result type has changed and the corresponding CallOp needs to be
updated.

Differential Revision: https://reviews.llvm.org/D76326

Added: 
    mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h
    mlir/lib/Conversion/StandardToStandard/CMakeLists.txt
    mlir/lib/Conversion/StandardToStandard/StandardToStandard.cpp

Modified: 
    mlir/lib/Conversion/CMakeLists.txt
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/TestDialect/TestPatterns.cpp
    mlir/test/lib/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h b/mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h
new file mode 100644
index 000000000000..a384d7c22166
--- /dev/null
+++ b/mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h
@@ -0,0 +1,31 @@
+//===- StandardToStandard.h - Std intra-dialect conversion  -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for lowering within the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSTANDARD_STANDARDTOSTANDARD_H_
+#define MLIR_CONVERSION_STANDARDTOSTANDARD_STANDARDTOSTANDARD_H_
+
+namespace mlir {
+
+// Forward declarations.
+class MLIRContext;
+class OwningRewritePatternList;
+class TypeConverter;
+
+/// Adds a pattern to `patterns` that converts the operand and result types
+/// of a CallOp via `converter`, which must outlive the registered pattern.
+void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx,
+                                         TypeConverter &converter);
+
+} // end namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSTANDARD_STANDARDTOSTANDARD_H_

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 4634345cf43e..2f1826a1e299 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -10,5 +10,6 @@ add_subdirectory(LoopsToGPU)
 add_subdirectory(LoopToStandard)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
+add_subdirectory(StandardToStandard)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToLoops)

diff  --git a/mlir/lib/Conversion/StandardToStandard/CMakeLists.txt b/mlir/lib/Conversion/StandardToStandard/CMakeLists.txt
new file mode 100644
index 000000000000..e1bc42a746ee
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToStandard/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_conversion_library(MLIRStandardToStandard
+  StandardToStandard.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/StandardToStandard
+  )
+target_link_libraries(MLIRStandardToStandard
+  PUBLIC
+  MLIRIR
+  MLIRPass # NOTE(review): StandardToStandard.cpp includes no Pass header; confirm it is needed.
+  MLIRStandardOps
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/StandardToStandard/StandardToStandard.cpp b/mlir/lib/Conversion/StandardToStandard/StandardToStandard.cpp
new file mode 100644
index 000000000000..0138cf897563
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToStandard/StandardToStandard.cpp
@@ -0,0 +1,49 @@
+//===- StandardToStandard.cpp - Std intra-dialect lowering ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToStandard/StandardToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+// Converts the operand and result types of the Standard dialect's CallOp;
+// meant to be used together with the FuncOp signature conversion pattern.
+struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
+  CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
+      : OpConversionPattern(ctx), converter(converter) {}
+
+  /// Replaces `callOp` with a CallOp whose result types have been converted.
+  PatternMatchResult
+  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FunctionType type = callOp.getCalleeType();
+
+    // Convert result types; reject 1:N/1:0 conversions, which change arity.
+    SmallVector<Type, 1> convertedResults;
+    if (failed(converter.convertTypes(type.getResults(), convertedResults)) ||
+        convertedResults.size() != type.getNumResults())
+      return matchFailure();
+    // Rebuild the call with the converted result types; `operands` holds the
+    // remapped operand values provided by the dialect conversion framework.
+    rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.callee(),
+                                        convertedResults, operands);
+    return matchSuccess();
+  }
+
+  /// Type converter for the result types; not owned, must outlive the pattern.
+  TypeConverter &converter;
+};
+} // end anonymous namespace
+
+void mlir::populateCallOpTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  patterns.insert<CallOpSignatureConversion>(ctx, converter);
+}

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 38f87dd2302c..dd8330626551 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -23,6 +23,13 @@ func @remap_input_1_to_1(%arg0: i64) {
   "test.invalid"(%arg0) : (i64) -> ()
 }
 
+// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+func @remap_call_1_to_1(%arg0: i64) { // The call's i64 operand must be rewritten to f64.
+  // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
+  call @remap_input_1_to_1(%arg0) : (i64) -> ()
+  return
+}
+
 // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
 func @remap_input_1_to_N(%arg0: f32) -> f32 {
  // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()

diff  --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
index c7235b8cb3a5..0b73f09c1943 100644
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "mlir/Conversion/StandardToStandard/StandardToStandard.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -381,6 +382,8 @@ struct TestLegalizePatternDriver
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
+    mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
+                                              converter);
 
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index bc737a0a119f..61d1443869a9 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -45,6 +45,7 @@ target_link_libraries(MLIRTestTransforms
   MLIRLoopOps
   MLIRGPU
   MLIRPass
+  MLIRStandardToStandard
   MLIRTestDialect
   MLIRTransformUtils
   MLIRVectorToLoops


        


More information about the Mlir-commits mailing list