[flang-commits] [flang] 97d8972 - [flang][fir] Add the pre-code gen rewrite pass and codegen ops.

Eric Schweitz via flang-commits flang-commits at lists.llvm.org
Wed Mar 24 19:27:24 PDT 2021


Author: Eric Schweitz
Date: 2021-03-24T19:27:10-07:00
New Revision: 97d8972c9cd1295fe838b0d0d1be4cefe2dd0b1c

URL: https://github.com/llvm/llvm-project/commit/97d8972c9cd1295fe838b0d0d1be4cefe2dd0b1c
DIFF: https://github.com/llvm/llvm-project/commit/97d8972c9cd1295fe838b0d0d1be4cefe2dd0b1c.diff

LOG: [flang][fir] Add the pre-code gen rewrite pass and codegen ops.

Before the conversion to LLVM-IR dialect and ultimately LLVM IR, FIR is
partially rewritten into a codegen form.  This patch adds that pass, the
fircg dialect, and the small set of Ops in the fircg (sub) dialect.
Fircg is not part of the FIR dialect and should never be used outside of
the (closed) conversion to LLVM IR.

Authors: Eric Schweitz, Jean Perier, Rajan Walia, et al.

Differential Revision: https://reviews.llvm.org/D98063

Added: 
    flang/include/flang/Optimizer/CodeGen/CGOps.td
    flang/lib/Optimizer/CodeGen/CGOps.cpp
    flang/lib/Optimizer/CodeGen/CGOps.h
    flang/lib/Optimizer/CodeGen/PassDetail.h
    flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp
    flang/test/Fir/cg-ops.fir

Modified: 
    flang/include/flang/Optimizer/CodeGen/CGPasses.td
    flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
    flang/include/flang/Optimizer/Dialect/FIRDialect.h
    flang/include/flang/Optimizer/Support/InitFIR.h
    flang/lib/Optimizer/CMakeLists.txt
    flang/tools/fir-opt/fir-opt.cpp
    flang/tools/tco/tco.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/CodeGen/CGOps.td b/flang/include/flang/Optimizer/CodeGen/CGOps.td
new file mode 100644
index 000000000000..9ebda32825a6
--- /dev/null
+++ b/flang/include/flang/Optimizer/CodeGen/CGOps.td
@@ -0,0 +1,177 @@
+//===-- CGOps.td - FIR codegen operation definitions -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Definition of the fircg (FIR codegen) dialect and its operations
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_DIALECT_FIRCG_OPS
+#define FORTRAN_DIALECT_FIRCG_OPS
+
+include "mlir/IR/SymbolInterfaces.td"
+include "flang/Optimizer/Dialect/FIRTypes.td"
+
+def fircg_Dialect : Dialect {
+  let name = "fircg";
+  let cppNamespace = "::fir::cg";
+}
+
+// Base class for FIR CG operations.
+// All operations automatically get a prefix of "fircg.".
+class fircg_Op<string mnemonic, list<OpTrait> traits>
+  : Op<fircg_Dialect, mnemonic, traits>;
+
+// Extended embox operation.
+def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> {
+  let summary = "for internal conversion only";
+
+  let description = [{
+    Prior to lowering to LLVM IR dialect, a non-scalar non-trivial embox op will
+    be converted to an extended embox. This op will have the following sets of
+    arguments.
+
+       - memref: The memory reference being emboxed.
+       - shape: A vector that is the runtime shape of the underlying array.
+       - shift: A vector that is the runtime origin of the first element.
+         The default is a vector of the value 1.
+       - slice: A vector of triples that describe an array slice.
+       - subcomponent: A vector of indices for subobject slicing.
+       - LEN type parameters: A vector of runtime LEN type parameters that
+         describe and correspond to the elemental derived type.
+
+    The memref and shape arguments are mandatory. The rest are optional.
+  }];
+
+  let arguments = (ins
+    AnyReferenceLike:$memref,
+    Variadic<AnyIntegerType>:$shape,
+    Variadic<AnyIntegerType>:$shift,
+    Variadic<AnyIntegerType>:$slice,
+    Variadic<AnyCoordinateType>:$subcomponent,
+    Variadic<AnyIntegerType>:$lenParams
+  );
+  let results = (outs fir_BoxType);
+
+  let assemblyFormat = [{
+    $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)?
+      (`path` $subcomponent^)? (`typeparams` $lenParams^)? attr-dict
+      `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    // The rank of the entity being emboxed.
+    unsigned getRank() { return shape().size(); }
+
+    // The rank of the result. A slice op can reduce the rank.
+    unsigned getOutRank();
+
+    // Operand indices: memref is operand 0, so the shape operands start at 1.
+    unsigned shapeOffset() { return 1; }
+    unsigned shiftOffset() { return shapeOffset() + shape().size(); }
+    unsigned sliceOffset() { return shiftOffset() + shift().size(); }
+    unsigned subcomponentOffset() { return sliceOffset() + slice().size(); }
+    unsigned lenParamOffset() {
+      return subcomponentOffset() + subcomponent().size();
+    }
+  }];
+}
+
+// Extended rebox operation.
+def fircg_XReboxOp : fircg_Op<"ext_rebox", [AttrSizedOperandSegments]> {
+  let summary = "for internal conversion only";
+
+  let description = [{
+    Prior to lowering to LLVM IR dialect, a non-scalar non-trivial rebox op will
+    be converted to an extended rebox. This op will have the following sets of
+    arguments.
+
+       - box: The box being reboxed.
+       - shape: A vector that is the new runtime shape for the array.
+       - shift: A vector that is the new runtime origin of the first element.
+         The default is a vector of the value 1.
+       - slice: A vector of triples that describe an array slice.
+       - subcomponent: A vector of indices for subobject slicing.
+
+    The box argument is mandatory, the other arguments are optional.
+    There must not be both shape arguments and slice/subcomponent arguments.
+  }];
+
+  let arguments = (ins
+    fir_BoxType:$box,
+    Variadic<AnyIntegerType>:$shape,
+    Variadic<AnyIntegerType>:$shift,
+    Variadic<AnyIntegerType>:$slice,
+    Variadic<AnyCoordinateType>:$subcomponent
+  );
+  let results = (outs fir_BoxType);
+
+  let assemblyFormat = [{
+    $box (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)?
+      (`path` $subcomponent^)? attr-dict
+      `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    // The rank of the entity being reboxed.
+    unsigned getRank();
+    // The rank of the result box.
+    unsigned getOutRank();
+  }];
+}
+
+
+// Extended array coordinate operation.
+def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]> {
+  let summary = "for internal conversion only";
+
+  let description = [{
+    Prior to lowering to LLVM IR dialect, an array_coor op will be converted
+    to an extended array_coor. This op will have the following sets of
+    arguments.
+
+       - memref: The memory reference of the array's data. It can be a fir.box if
+         the underlying data is not contiguous.
+       - shape: A vector that is the runtime shape of the underlying array.
+       - shift: A vector that is the runtime origin of the first element.
+         The default is a vector of the value 1.
+       - slice: A vector of triples that describe an array slice.
+       - subcomponent: A vector of indices that describe subobject slicing.
+       - indices: A vector of runtime values that describe the coordinate of
+         the element of the array to be computed.
+       - LEN type parameters: A vector of runtime LEN type parameters that
+         describe and correspond to the elemental derived type.
+
+    The memref and indices arguments are mandatory.
+    The shape argument is mandatory if the memref is not a box, and should be
+    omitted otherwise. The rest of the arguments are optional.
+  }];
+
+  let arguments = (ins
+    AnyRefOrBox:$memref,
+    Variadic<AnyIntegerType>:$shape,
+    Variadic<AnyIntegerType>:$shift,
+    Variadic<AnyIntegerType>:$slice,
+    Variadic<AnyCoordinateType>:$subcomponent,
+    Variadic<AnyCoordinateType>:$indices,
+    Variadic<AnyIntegerType>:$lenParams
+  );
+  let results = (outs fir_ReferenceType);
+
+  let assemblyFormat = [{
+    $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)?
+      (`path` $subcomponent^)? `<`$indices`>` (`typeparams` $lenParams^)?
+      attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    unsigned getRank(); // rank of the array being addressed
+  }];
+}
+
+#endif // FORTRAN_DIALECT_FIRCG_OPS

diff  --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index 46442a281606..ffe829644d1a 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -11,18 +11,29 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef FLANG_OPTIMIZER_CODEGEN_PASSES
-#define FLANG_OPTIMIZER_CODEGEN_PASSES
+#ifndef FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES
+#define FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES
 
 include "mlir/Pass/PassBase.td"
 
-def CodeGenRewrite : Pass<"cg-rewrite", "mlir::ModuleOp"> {
+def CodeGenRewrite : Pass<"cg-rewrite"> {
   let summary = "Rewrite some FIR ops into their code-gen forms.";
   let description = [{
     Fuse specific subgraphs into single Ops for code generation.
   }];
   let constructor = "fir::createFirCodeGenRewritePass()";
-  let dependentDialects = ["fir::FIROpsDialect"];
+  // The rewrite patterns create fircg ops and std.constant ops (extents of
+  // statically shaped arrays); every dialect whose ops the pass may create
+  // must be listed so that it is loaded before the pass runs.
+  let dependentDialects = [
+    "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::BuiltinDialect",
+    "mlir::LLVM::LLVMDialect", "mlir::omp::OpenMPDialect",
+    "mlir::StandardOpsDialect"
+  ];
+  // NOTE(review): PreCGRewrite.cpp does not increment numDCE as written.
+  let statistics = [
+    Statistic<"numDCE", "num-dce'd", "Number of operations eliminated">
+  ];
 }
 
-#endif // FLANG_OPTIMIZER_CODEGEN_PASSES
+#endif // FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES

diff  --git a/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt b/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
index 8cbd772b30ab..3eda75190ba2 100644
--- a/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS CGOps.td) # fircg dialect op decls/defs
+mlir_tablegen(CGOps.h.inc -gen-op-decls)
+mlir_tablegen(CGOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(CGOpsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS CGPasses.td)
 mlir_tablegen(CGPasses.h.inc -gen-pass-decls -name OptCodeGen)

diff  --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.h b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
index 4bafb4ab7fb6..fb828716d45a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRDialect.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
@@ -40,6 +40,16 @@ class FIROpsDialect final : public mlir::Dialect {
   void registerTypes();
 };
 
+/// The FIR codegen dialect ("fircg") contains a small set of transient
+/// operations used exclusively during the conversion of FIR to LLVM IR.
+class FIRCodeGenDialect final : public mlir::Dialect {
+public:
+  explicit FIRCodeGenDialect(mlir::MLIRContext *ctx);
+  virtual ~FIRCodeGenDialect(); // out of line, to anchor the vtable
+
+  static llvm::StringRef getDialectNamespace() { return "fircg"; }
+};
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_DIALECT_FIRDIALECT_H

diff  --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index cb2dd4f4776b..194d42a41a1c 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -21,15 +21,16 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/Passes.h"
+#include "flang/Optimizer/CodeGen/CodeGen.h"
 
 namespace fir::support {
 
 // The definitive list of dialects used by flang.
 #define FLANG_DIALECT_LIST                                                     \
-  mlir::AffineDialect, FIROpsDialect, mlir::LLVM::LLVMDialect,                 \
-      mlir::acc::OpenACCDialect, mlir::omp::OpenMPDialect,                     \
-      mlir::scf::SCFDialect, mlir::StandardOpsDialect,                         \
-      mlir::vector::VectorDialect
+  mlir::AffineDialect, FIROpsDialect, FIRCodeGenDialect,                       \
+      mlir::LLVM::LLVMDialect, mlir::acc::OpenACCDialect,                      \
+      mlir::omp::OpenMPDialect, mlir::scf::SCFDialect,                         \
+      mlir::StandardOpsDialect, mlir::vector::VectorDialect
 
 /// Register all the dialects used by flang.
 inline void registerDialects(mlir::DialectRegistry &registry) {
@@ -45,7 +46,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
 
 /// Register the standard passes we use. This comes from registerAllPasses(),
 /// but is a smaller set since we aren't using many of the passes found there.
-inline void registerFIRPasses() {
+inline void registerMLIRPassesForFortranTools() {
   mlir::registerCanonicalizerPass();
   mlir::registerCSEPass();
   mlir::registerAffineLoopFusionPass();
@@ -69,6 +70,9 @@ inline void registerFIRPasses() {
   mlir::registerAffineDataCopyGenerationPass();
 
   mlir::registerConvertAffineToStandardPass();
+
+  // Flang codegen passes declared in CGPasses.td (e.g. cg-rewrite).
+  fir::registerOptCodeGenPasses();
 }
 
 } // namespace fir::support

diff  --git a/flang/lib/Optimizer/CMakeLists.txt b/flang/lib/Optimizer/CMakeLists.txt
index 0a7286339e2e..b83d6a079db6 100644
--- a/flang/lib/Optimizer/CMakeLists.txt
+++ b/flang/lib/Optimizer/CMakeLists.txt
@@ -10,11 +10,16 @@ add_flang_library(FIROptimizer
   Support/InternalNames.cpp
   Support/KindMapping.cpp
 
+  CodeGen/CGOps.cpp
+  CodeGen/PreCGRewrite.cpp
+
   Transforms/Inliner.cpp
 
   DEPENDS
   FIROpsIncGen
+  FIROptCodeGenPassIncGen
   FIROptTransformsPassIncGen
+  CGOpsIncGen
   ${dialect_libs}
 
   LINK_LIBS

diff  --git a/flang/lib/Optimizer/CodeGen/CGOps.cpp b/flang/lib/Optimizer/CodeGen/CGOps.cpp
new file mode 100644
index 000000000000..527066ec5ccd
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/CGOps.cpp
@@ -0,0 +1,73 @@
+//===-- CGOps.cpp -- FIR codegen operations -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "CGOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+
+/// FIR codegen dialect constructor. Registers all the ops generated from
+/// CGOps.td under the dialect's namespace ("fircg").
+fir::FIRCodeGenDialect::FIRCodeGenDialect(mlir::MLIRContext *ctx)
+    : mlir::Dialect(getDialectNamespace(), ctx,
+                    mlir::TypeID::get<FIRCodeGenDialect>()) {
+  addOperations<
+#define GET_OP_LIST
+#include "flang/Optimizer/CodeGen/CGOps.cpp.inc"
+      >();
+}
+
+// Anchor the class vtable to this compilation unit.
+fir::FIRCodeGenDialect::~FIRCodeGenDialect() {
+  // do nothing
+}
+
+#define GET_OP_CLASSES
+#include "flang/Optimizer/CodeGen/CGOps.cpp.inc"
+
+/// Rank of the fir.array wrapped by the box/reference type \p ty, or 0 when
+/// the wrapped element type is not a fir.array (or \p ty wraps nothing).
+static unsigned getBoxedSequenceRank(mlir::Type ty) {
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
+  if (auto seqTy = eleTy.dyn_cast_or_null<fir::SequenceType>())
+    return seqTy.getDimension();
+  return 0;
+}
+
+/// The rank of the result. Without a slice it is the rank of the entity;
+/// a slice may reduce it, but never down to a scalar.
+unsigned fir::cg::XEmboxOp::getOutRank() {
+  if (slice().empty())
+    return getRank();
+  auto outRank = fir::SliceOp::getOutputRank(slice());
+  assert(outRank >= 1 && "slice must not reduce the box to a scalar");
+  return outRank;
+}
+
+unsigned fir::cg::XReboxOp::getOutRank() {
+  return getBoxedSequenceRank(getType());
+}
+
+unsigned fir::cg::XReboxOp::getRank() {
+  return getBoxedSequenceRank(box().getType());
+}
+
+/// When memref is a box of a fir.array, the rank is that array's rank;
+/// otherwise it is the number of shape operands.
+unsigned fir::cg::XArrayCoorOp::getRank() {
+  auto memrefTy = memref().getType();
+  if (memrefTy.isa<fir::BoxType>())
+    if (auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(memrefTy)
+                         .dyn_cast<fir::SequenceType>())
+      return seqTy.getDimension();
+  return shape().size();
+}

diff  --git a/flang/lib/Optimizer/CodeGen/CGOps.h b/flang/lib/Optimizer/CodeGen/CGOps.h
new file mode 100644
index 000000000000..f5f552c63376
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/CGOps.h
@@ -0,0 +1,24 @@
+//===-- CGOps.h - FIR codegen operations ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef OPTIMIZER_CODEGEN_CGOPS_H
+#define OPTIMIZER_CODEGEN_CGOPS_H
+
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+using namespace mlir; // NOTE(review): header-level; leaks into all includers.
+
+#define GET_OP_CLASSES
+#include "flang/Optimizer/CodeGen/CGOps.h.inc"
+
+#endif // OPTIMIZER_CODEGEN_CGOPS_H

diff  --git a/flang/lib/Optimizer/CodeGen/PassDetail.h b/flang/lib/Optimizer/CodeGen/PassDetail.h
new file mode 100644
index 000000000000..f7030131beff
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/PassDetail.h
@@ -0,0 +1,26 @@
+//===- PassDetail.h - Optimizer code gen Pass class details -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef OPTIMIZER_CODEGEN_PASSDETAIL_H
+#define OPTIMIZER_CODEGEN_PASSDETAIL_H
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace fir {
+
+#define GEN_PASS_CLASSES
+#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
+
+} // namespace fir
+
+#endif // OPTIMIZER_CODEGEN_PASSDETAIL_H

diff  --git a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp
new file mode 100644
index 000000000000..eca417ae49b8
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp
@@ -0,0 +1,305 @@
+//===-- PreCGRewrite.cpp --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "CGOps.h"
+#include "PassDetail.h"
+#include "flang/Optimizer/CodeGen/CodeGen.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/FIRContext.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+//===----------------------------------------------------------------------===//
+// Codegen rewrite: rewriting of subgraphs of ops
+//===----------------------------------------------------------------------===//
+
+using namespace fir;
+
+#define DEBUG_TYPE "flang-codegen-rewrite"
+
+/// Append the extents of a fir.shape to \p vec.
+static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
+                          ShapeOp shape) {
+  vec.append(shape.extents().begin(), shape.extents().end());
+}
+
+/// Split the operands of a fir.shape_shift into two vectors. The operands
+/// are (origin, extent) pairs: origins are appended to \p shiftVec and
+/// extents to \p shapeVec.
+static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
+                                  llvm::SmallVectorImpl<mlir::Value> &shiftVec,
+                                  ShapeShiftOp shift) {
+  auto pairs = shift.pairs();
+  // An odd operand count would make the pairwise walk read past the end.
+  assert(pairs.size() % 2 == 0 &&
+         "fir.shape_shift operands must be (origin, extent) pairs");
+  for (unsigned i = 0, e = pairs.size(); i < e; i += 2) {
+    shiftVec.push_back(pairs[i]);
+    shapeVec.push_back(pairs[i + 1]);
+  }
+}
+
+/// Append the origins of a fir.shift to \p vec.
+static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
+                          ShiftOp shift) {
+  vec.append(shift.origins().begin(), shift.origins().end());
+}
+
+/// Decompose a shape-like value (fir.shape, fir.shape_shift or fir.shift) into
+/// its extent operands (\p shapeOpers) and origin operands (\p shiftOpers).
+/// Returns failure when \p shapeVal is not produced by one of those ops, which
+/// includes a block argument (it has no defining op).
+static mlir::LogicalResult
+populateShapeOperands(mlir::Value shapeVal,
+                      llvm::SmallVectorImpl<mlir::Value> &shapeOpers,
+                      llvm::SmallVectorImpl<mlir::Value> &shiftOpers) {
+  auto *shapeDef = shapeVal.getDefiningOp();
+  if (auto shapeOp = mlir::dyn_cast_or_null<ShapeOp>(shapeDef))
+    populateShape(shapeOpers, shapeOp);
+  else if (auto shiftOp = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeDef))
+    populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
+  else if (auto shiftOp = mlir::dyn_cast_or_null<ShiftOp>(shapeDef))
+    populateShift(shiftOpers, shiftOp);
+  else
+    return mlir::failure();
+  return mlir::success();
+}
+
+/// Append the triples and field indices of the fir.slice defining \p sliceVal
+/// to \p sliceOpers and \p subcompOpers. A null \p sliceVal (no slice) appends
+/// nothing. Returns failure when a slice is present but is not produced by a
+/// fir.slice: silently dropping it would change the meaning of the result.
+static mlir::LogicalResult
+populateSliceOperands(mlir::Value sliceVal,
+                      llvm::SmallVectorImpl<mlir::Value> &sliceOpers,
+                      llvm::SmallVectorImpl<mlir::Value> &subcompOpers) {
+  if (!sliceVal)
+    return mlir::success();
+  auto sliceOp = mlir::dyn_cast_or_null<SliceOp>(sliceVal.getDefiningOp());
+  if (!sliceOp)
+    return mlir::failure();
+  sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
+  subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
+  return mlir::success();
+}
+
+namespace {
+
+/// Convert fir.embox to the extended form where necessary.
+///
+/// The embox operation can take arguments that specify multidimensional array
+/// properties at runtime. These properties may be shared between distinct
+/// objects that have the same properties. Before we lower these small DAGs to
+/// LLVM-IR, we gather all the information into a single extended operation. For
+/// example,
+/// ```
+/// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
+/// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
+/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
+/// ```
+/// can be rewritten as
+/// ```
+/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>>
+/// ```
+class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(EmboxOp embox,
+                  mlir::PatternRewriter &rewriter) const override {
+    // An embox with an explicit shape operand is always converted.
+    if (auto shapeVal = embox.getShape())
+      return rewriteDynamicShape(embox, rewriter, shapeVal);
+    // Without a shape operand, only an array whose extents are all known at
+    // compile time can be converted: the extents become index constants.
+    if (auto boxTy = embox.getType().dyn_cast<BoxType>())
+      if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>())
+        if (seqTy.hasConstantShape())
+          return rewriteStaticShape(embox, rewriter, seqTy);
+    return mlir::failure();
+  }
+
+  /// Rewrite an embox of a constant-shape array that has no shape operand.
+  mlir::LogicalResult rewriteStaticShape(EmboxOp embox,
+                                         mlir::PatternRewriter &rewriter,
+                                         SequenceType seqTy) const {
+    auto loc = embox.getLoc();
+    // Collect the slice first so that a failure leaves no dangling constants.
+    llvm::SmallVector<mlir::Value> sliceOpers;
+    llvm::SmallVector<mlir::Value> subcompOpers;
+    if (mlir::failed(
+            populateSliceOperands(embox.getSlice(), sliceOpers, subcompOpers)))
+      return mlir::failure();
+    llvm::SmallVector<mlir::Value> shapeOpers;
+    auto idxTy = rewriter.getIndexType();
+    for (auto ext : seqTy.getShape()) {
+      auto iAttr = rewriter.getIndexAttr(ext);
+      auto extVal = rewriter.create<mlir::ConstantOp>(loc, idxTy, iAttr);
+      shapeOpers.push_back(extVal);
+    }
+    auto xbox = rewriter.create<cg::XEmboxOp>(
+        loc, embox.getType(), embox.memref(), shapeOpers, llvm::None,
+        sliceOpers, subcompOpers, embox.lenParams());
+    LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
+    rewriter.replaceOp(embox, xbox.getOperation()->getResults());
+    return mlir::success();
+  }
+
+  /// Rewrite an embox whose shape operand is a fir.shape or fir.shape_shift.
+  mlir::LogicalResult rewriteDynamicShape(EmboxOp embox,
+                                          mlir::PatternRewriter &rewriter,
+                                          mlir::Value shapeVal) const {
+    auto loc = embox.getLoc();
+    llvm::SmallVector<mlir::Value> shapeOpers;
+    llvm::SmallVector<mlir::Value> shiftOpers;
+    auto *shapeDef = shapeVal.getDefiningOp();
+    if (auto shapeOp = mlir::dyn_cast_or_null<ShapeOp>(shapeDef)) {
+      populateShape(shapeOpers, shapeOp);
+    } else if (auto shiftOp = mlir::dyn_cast_or_null<ShapeShiftOp>(shapeDef)) {
+      populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
+    } else {
+      // Only fir.shape and fir.shape_shift can be folded into ext_embox.
+      // Anything else (e.g. a block argument, which has no defining op) is
+      // reported as a match failure instead of being dereferenced.
+      return mlir::failure();
+    }
+    llvm::SmallVector<mlir::Value> sliceOpers;
+    llvm::SmallVector<mlir::Value> subcompOpers;
+    if (mlir::failed(
+            populateSliceOperands(embox.getSlice(), sliceOpers, subcompOpers)))
+      return mlir::failure();
+    auto xbox = rewriter.create<cg::XEmboxOp>(
+        loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers,
+        sliceOpers, subcompOpers, embox.lenParams());
+    LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
+    rewriter.replaceOp(embox, xbox.getOperation()->getResults());
+    return mlir::success();
+  }
+};
+
+/// Convert fir.rebox to the extended form where necessary.
+///
+/// For example,
+/// ```
+/// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>>
+/// ```
+/// converted to
+/// ```
+/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, index, index) -> !fir.box<!fir.array<?xi32>>
+/// ```
+class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(ReboxOp rebox,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = rebox.getLoc();
+    llvm::SmallVector<mlir::Value> shapeOpers;
+    llvm::SmallVector<mlir::Value> shiftOpers;
+    if (auto shapeVal = rebox.shape())
+      if (mlir::failed(populateShapeOperands(shapeVal, shapeOpers, shiftOpers)))
+        return mlir::failure();
+    llvm::SmallVector<mlir::Value> sliceOpers;
+    llvm::SmallVector<mlir::Value> subcompOpers;
+    if (mlir::failed(
+            populateSliceOperands(rebox.slice(), sliceOpers, subcompOpers)))
+      return mlir::failure();
+    auto xRebox = rewriter.create<cg::XReboxOp>(
+        loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers,
+        subcompOpers);
+    LLVM_DEBUG(llvm::dbgs()
+               << "rewriting " << rebox << " to " << xRebox << '\n');
+    rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
+    return mlir::success();
+  }
+};
+
+/// Convert all fir.array_coor to the extended form.
+///
+/// For example,
+/// ```
+///  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
+/// ```
+/// converted to
+/// ```
+/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6]<%39> : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> !fir.ref<i32>
+/// ```
+class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(ArrayCoorOp arrCoor,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = arrCoor.getLoc();
+    llvm::SmallVector<mlir::Value> shapeOpers;
+    llvm::SmallVector<mlir::Value> shiftOpers;
+    if (auto shapeVal = arrCoor.shape())
+      if (mlir::failed(populateShapeOperands(shapeVal, shapeOpers, shiftOpers)))
+        return mlir::failure();
+    llvm::SmallVector<mlir::Value> sliceOpers;
+    llvm::SmallVector<mlir::Value> subcompOpers;
+    if (mlir::failed(
+            populateSliceOperands(arrCoor.slice(), sliceOpers, subcompOpers)))
+      return mlir::failure();
+    auto xArrCoor = rewriter.create<cg::XArrayCoorOp>(
+        loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers,
+        sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams());
+    LLVM_DEBUG(llvm::dbgs()
+               << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
+    rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
+    return mlir::success();
+  }
+};
+
+/// Pass that rewrites fir.embox, fir.rebox and fir.array_coor into their
+/// fircg extended forms ahead of the conversion to the LLVM IR dialect.
+class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> {
+public:
+  void runOnOperation() override final {
+    auto op = getOperation();
+    auto &context = getContext();
+    mlir::ConversionTarget target(context);
+    target.addLegalDialect<FIROpsDialect, FIRCodeGenDialect,
+                           mlir::StandardOpsDialect>();
+    // fir.array_coor and fir.rebox are always rewritten.
+    target.addIllegalOp<ArrayCoorOp>();
+    target.addIllegalOp<ReboxOp>();
+    // fir.embox must be rewritten when it has a shape operand or boxes a
+    // fir.array; any other embox is left as is.
+    target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) {
+      return !(embox.getShape() ||
+               embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>());
+    });
+    mlir::OwningRewritePatternList patterns;
+    patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
+        &context);
+    if (mlir::failed(
+            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
+      // Report at the location of the op the pass ran on, not UnknownLoc.
+      mlir::emitError(op->getLoc(),
+                      "error in running the pre-codegen conversions");
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
+  return std::make_unique<CodeGenRewrite>();
+}

diff  --git a/flang/test/Fir/cg-ops.fir b/flang/test/Fir/cg-ops.fir
new file mode 100644
index 000000000000..a138313eef94
--- /dev/null
+++ b/flang/test/Fir/cg-ops.fir
@@ -0,0 +1,41 @@
+// RUN: fir-opt --pass-pipeline="func(cg-rewrite),fir.global(cg-rewrite),cse" %s | FileCheck %s
+
+// CHECK-LABEL: func @codegen(
+// CHECK-SAME: %[[arg:.*]]: !fir
+func @codegen(%addr : !fir.ref<!fir.array<?xi32>>) {
+  // CHECK: %[[zero:.*]] = constant 0 : index
+  %0 = constant 0 : index
+  %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1>
+  %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1>
+  // CHECK: %[[box:.*]] = fircg.ext_embox %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>>
+  %3 = fir.embox %addr (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
+  // CHECK: fircg.ext_array_coor %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]]<%[[zero]]> : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> !fir.ref<i32>
+  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
+  // CHECK: fircg.ext_rebox %[[box]](%[[zero]]) origin %[[zero]] : (!fir.box<!fir.array<?xi32>>, index, index) -> !fir.box<!fir.array<?xi32>>
+  %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>>
+  return
+}
+
+// CHECK-LABEL: fir.global @box_global
+fir.global @box_global : !fir.box<!fir.array<?xi32>> {
+  // CHECK: %[[arr:.*]] = fir.zero_bits !fir.ref
+  %arr = fir.zero_bits !fir.ref<!fir.array<?xi32>>
+  // CHECK: %[[zero:.*]] = constant 0 : index
+  %0 = constant 0 : index
+  %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1>
+  %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1>
+  // CHECK: fircg.ext_embox %[[arr]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>>
+  %3 = fir.embox %arr (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
+  fir.has_value %3 : !fir.box<!fir.array<?xi32>>
+}
+
+// An embox of a constant-shape array without a shape operand: the extents
+// are materialized as index constants.
+// CHECK-LABEL: func @codegen_static(
+// CHECK-SAME: %[[arg:.*]]: !fir
+func @codegen_static(%addr : !fir.ref<!fir.array<10xi32>>) {
+  // CHECK: %[[ten:.*]] = constant 10 : index
+  // CHECK: fircg.ext_embox %[[arg]](%[[ten]]) : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.box<!fir.array<10xi32>>
+  %0 = fir.embox %addr : (!fir.ref<!fir.array<10xi32>>) -> !fir.box<!fir.array<10xi32>>
+  return
+}

diff  --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp
index b2d383c06682..b66294339f1a 100644
--- a/flang/tools/fir-opt/fir-opt.cpp
+++ b/flang/tools/fir-opt/fir-opt.cpp
@@ -17,9 +17,9 @@
 using namespace mlir;
 
 int main(int argc, char **argv) {
-  fir::support::registerFIRPasses();
+  fir::support::registerMLIRPassesForFortranTools(); // incl. cg-rewrite
   DialectRegistry registry;
   fir::support::registerDialects(registry);
   return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
-      registry, /*preloadDialectsInContext*/ false));
+      registry, /*preloadDialectsInContext=*/false));
 }

diff  --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index a67b1453fc28..62e31fe47ed1 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -106,7 +106,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
 }
 
 int main(int argc, char **argv) {
-  fir::support::registerFIRPasses();
+  fir::support::registerMLIRPassesForFortranTools();
   [[maybe_unused]] InitLLVM y(argc, argv);
   mlir::registerPassManagerCLOptions();
   mlir::PassPipelineCLParser passPipe("", "Compiler passes to run");


        


More information about the flang-commits mailing list