[Mlir-commits] [mlir] 2ea7fb7 - [MLIR] Add ComplexToStandard conversion pass.

Adrian Kuegel llvmlistbot at llvm.org
Wed Apr 28 05:18:17 PDT 2021


Author: Adrian Kuegel
Date: 2021-04-28T14:17:46+02:00
New Revision: 2ea7fb7b1c045a7d60fcccf3df3ebb26aa3699e5

URL: https://github.com/llvm/llvm-project/commit/2ea7fb7b1c045a7d60fcccf3df3ebb26aa3699e5
DIFF: https://github.com/llvm/llvm-project/commit/2ea7fb7b1c045a7d60fcccf3df3ebb26aa3699e5.diff

LOG: [MLIR] Add ComplexToStandard conversion pass.

So far, only a conversion for complex::AbsOp is done, but more will be added.

Differential Revision: https://reviews.llvm.org/D101442

Added: 
    mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
    mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Conversion/ComplexToStandard/full-conversion.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
new file mode 100644
index 0000000000000..285881b4d2c8c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ComplexToStandard/ComplexToStandard.h
@@ -0,0 +1,29 @@
+//===- ComplexToStandard.h - Utils to convert from the complex dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
+#define MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
+
+#include <memory>
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class FuncOp;
+class RewritePatternSet;
+template <typename T> class OperationPass;
+
+/// Adds patterns that expand Complex ops (so far only complex.abs) into
+/// Standard/Math arithmetic on their real and imaginary parts.
+void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns);
+
+/// Creates a function pass that applies the above patterns to each FuncOp.
+std::unique_ptr<OperationPass<FuncOp>> createConvertComplexToStandardPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 64de7c962beed..c8f8baa87863e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index eb940d3414049..9ca99e9171392 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -99,6 +99,20 @@ def ConvertComplexToLLVM : Pass<"convert-complex-to-llvm", "ModuleOp"> {
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexToStandard
+//===----------------------------------------------------------------------===//
+
+def ConvertComplexToStandard : FunctionPass<"convert-complex-to-standard"> {
+  let summary = "Convert Complex dialect to standard dialect";
+  let constructor = "mlir::createConvertComplexToStandardPass()";
+  // The lowering may create ops from each of these dialects (complex.re/im,
+  // std mulf/addf, math.sqrt), so all of them are declared as dependencies.
+  let dependentDialects = [
+    "complex::ComplexDialect", "math::MathDialect", "StandardOpsDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUCommon
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 60dbab0a04432..1cf9d304151e0 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(AffineToStandard)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(ComplexToLLVM)
+add_subdirectory(ComplexToStandard)
 add_subdirectory(GPUCommon)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)

diff  --git a/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt
new file mode 100644
index 0000000000000..ba1f04a8f7102
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToStandard/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRComplexToStandard
+  ComplexToStandard.cpp
+  # Location of this library's public header (ComplexToStandard.h).
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToStandard
+  # Presumably generates the pass base class used by the .cpp; verify.
+  DEPENDS
+  MLIRConversionPassIncGen
+  # Dialect and conversion-framework libraries the lowering code includes.
+  LINK_LIBS PUBLIC
+  MLIRComplex
+  MLIRIR
+  MLIRMath
+  MLIRStandard
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
new file mode 100644
index 0000000000000..15fa25441db62
--- /dev/null
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -0,0 +1,77 @@
+//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
+
+#include <memory>
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites `complex.abs` as sqrt(re * re + im * im) using std/math ops.
+struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
+  using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::AbsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type elementType = op.getType();
+    Value arg = complex::AbsOp::Adaptor(operands).complex();
+    // Extract both components of the complex operand, then build |z|.
+    Value re = rewriter.create<complex::ReOp>(loc, elementType, arg);
+    Value im = rewriter.create<complex::ImOp>(loc, elementType, arg);
+    // NOTE(review): squaring may overflow/underflow for extreme magnitudes.
+    Value reSquared = rewriter.create<MulFOp>(loc, re, re);
+    Value imSquared = rewriter.create<MulFOp>(loc, im, im);
+    Value sumOfSquares = rewriter.create<AddFOp>(loc, reSquared, imSquared);
+    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sumOfSquares);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds this file's lowering patterns (currently only AbsOpConversion).
+void mlir::populateComplexToStandardConversionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<AbsOpConversion>(patterns.getContext());
+}
+
+namespace {
+/// Function pass that applies the Complex-to-Standard patterns to a FuncOp.
+struct ConvertComplexToStandardPass
+    : public ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateComplexToStandardConversionPatterns(patterns);
+
+    // Only complex.abs is illegal; std, math and other complex ops stay
+    // legal, so the partial conversion leaves everything else untouched.
+    ConversionTarget target(getContext());
+    target.addLegalDialect<StandardOpsDialect, math::MathDialect,
+                           complex::ComplexDialect>();
+    target.addIllegalOp<complex::AbsOp>();
+    FuncOp function = getFunction();
+    if (failed(applyPartialConversion(function, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Returns a new ConvertComplexToStandardPass instance.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertComplexToStandardPass() {
+  return std::make_unique<ConvertComplexToStandardPass>();
+}

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index ad9224a436ebb..7ff8c903348e5 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -36,6 +36,10 @@ namespace NVVM {
 class NVVMDialect;
 } // end namespace NVVM
 
+namespace math {
+class MathDialect;
+} // end namespace math
+
 namespace memref {
 class MemRefDialect;
 } // end namespace memref

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
new file mode 100644
index 0000000000000..788d42557883b
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-complex-to-standard | FileCheck %s
+
+// CHECK-LABEL: func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_abs(%arg: complex<f32>) -> f32 {
+  %abs = complex.abs %arg: complex<f32>
+  return %abs : f32
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// The two squarings are independent, so their relative order is not pinned.
+// CHECK-DAG: %[[REAL_SQ:.*]] = mulf %[[REAL]], %[[REAL]] : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: return %[[NORM]] : f32

diff  --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
new file mode 100644
index 0000000000000..2fd46b4d02264
--- /dev/null
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-complex-to-standard -convert-complex-to-llvm -convert-std-to-llvm | FileCheck %s
+
+// CHECK-LABEL: llvm.func @complex_abs
+// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
+func @complex_abs(%arg: complex<f32>) -> f32 {
+  %abs = complex.abs %arg: complex<f32>
+  return %abs : f32
+}
+// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
+// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// The two squarings are independent, so their relative order is not pinned.
+// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]]  : f32
+// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]]  : f32
+// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]]  : f32
+// CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: llvm.return %[[NORM]] : f32


        


More information about the Mlir-commits mailing list