[flang-commits] [flang] afb34cf - [flang] Handle dispatch op in abstract result pass

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Nov 28 07:49:59 PST 2022


Author: Valentin Clement
Date: 2022-11-28T16:49:51+01:00
New Revision: afb34cf3077a38007fcebe17dc384532207283fa

URL: https://github.com/llvm/llvm-project/commit/afb34cf3077a38007fcebe17dc384532207283fa
DIFF: https://github.com/llvm/llvm-project/commit/afb34cf3077a38007fcebe17dc384532207283fa.diff

LOG: [flang] Handle dispatch op in abstract result pass

Update the call conversion pattern to support the fir.dispatch
operation as well. The first operand of a fir.dispatch op is always the
polymorphic object. The pass_arg_pos attribute needs to be shifted when
the result is added as an argument.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D138799

Added: 
    

Modified: 
    flang/lib/Optimizer/Dialect/FIRType.cpp
    flang/lib/Optimizer/Transforms/AbstractResult.cpp
    flang/test/Fir/abstract-results.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index eb9b9afae90e3..89a806c0474aa 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -960,7 +960,7 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) {
   if (ty.getNumResults() == 0)
     return false;
   auto resultType = ty.getResult(0);
-  return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
+  return resultType.isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>();
 }
 
 /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error

diff  --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index dcc6e902fd84c..df00c17863d90 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -28,6 +28,8 @@ namespace fir {
 
 #define DEBUG_TYPE "flang-abstract-result-opt"
 
+using namespace mlir;
+
 namespace fir {
 namespace {
 
@@ -40,7 +42,7 @@ static mlir::Type getResultArgumentType(mlir::Type resultType,
               return fir::BoxType::get(type);
             return fir::ReferenceType::get(type);
           })
-      .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
+      .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
         return fir::ReferenceType::get(type);
       })
       .Default([](mlir::Type) -> mlir::Type {
@@ -75,16 +77,18 @@ static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
          shouldBoxResult;
 }
 
-class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
+template <typename Op>
+class CallConversion : public mlir::OpRewritePattern<Op> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-  CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
-      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+      : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
+
   mlir::LogicalResult
-  matchAndRewrite(fir::CallOp callOp,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto loc = callOp.getLoc();
-    auto result = callOp->getResult(0);
+  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto result = op->getResult(0);
     if (!result.hasOneUse()) {
       mlir::emitError(loc,
                       "calls with abstract result must have exactly one user");
@@ -109,50 +113,74 @@ class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
     // TODO: This should be generalized for derived types, and it is
     // architecture and OS dependent.
     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
-    fir::CallOp newCallOp;
+    Op newOp;
     if (isResultBuiltinCPtr) {
-      auto recTy = result.getType().dyn_cast<fir::RecordType>();
+      auto recTy = result.getType().template dyn_cast<fir::RecordType>();
       newResultTypes.emplace_back(recTy.getTypeList()[0].second);
     }
-    if (callOp.getCallee()) {
+
+    // fir::CallOp specific handling.
+    if constexpr (std::is_same_v<Op, fir::CallOp>) {
+      if (op.getCallee()) {
+        llvm::SmallVector<mlir::Value> newOperands;
+        if (!isResultBuiltinCPtr)
+          newOperands.emplace_back(arg);
+        newOperands.append(op.getOperands().begin(), op.getOperands().end());
+        newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
+                                             newResultTypes, newOperands);
+      } else {
+        // Indirect calls.
+        llvm::SmallVector<mlir::Type> newInputTypes;
+        if (!isResultBuiltinCPtr)
+          newInputTypes.emplace_back(argType);
+        for (auto operand : op.getOperands().drop_front())
+          newInputTypes.push_back(operand.getType());
+        auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
+                                                 newResultTypes);
+
+        llvm::SmallVector<mlir::Value> newOperands;
+        newOperands.push_back(
+            rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
+        if (!isResultBuiltinCPtr)
+          newOperands.push_back(arg);
+        newOperands.append(op.getOperands().begin() + 1,
+                           op.getOperands().end());
+        newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
+                                             newResultTypes, newOperands);
+      }
+    }
+
+    // fir::DispatchOp specific handling.
+    if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
       llvm::SmallVector<mlir::Value> newOperands;
       if (!isResultBuiltinCPtr)
         newOperands.emplace_back(arg);
-      newOperands.append(callOp.getOperands().begin(),
-                         callOp.getOperands().end());
-      newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
-                                               newResultTypes, newOperands);
-    } else {
-      // Indirect calls.
-      llvm::SmallVector<mlir::Type> newInputTypes;
-      if (!isResultBuiltinCPtr)
-        newInputTypes.emplace_back(argType);
-      for (auto operand : callOp.getOperands().drop_front())
-        newInputTypes.push_back(operand.getType());
-      auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
-                                               newInputTypes, newResultTypes);
+      unsigned passArgShift = newOperands.size();
+      newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
 
-      llvm::SmallVector<mlir::Value> newOperands;
-      newOperands.push_back(rewriter.create<fir::ConvertOp>(
-          loc, newFuncTy, callOp.getOperand(0)));
-      if (!isResultBuiltinCPtr)
-        newOperands.push_back(arg);
-      newOperands.append(callOp.getOperands().begin() + 1,
-                         callOp.getOperands().end());
-      newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
-                                               newResultTypes, newOperands);
+      fir::DispatchOp newDispatchOp;
+      if (op.getPassArgPos())
+        newOp = rewriter.create<fir::DispatchOp>(
+            loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+            op.getOperands()[0], newOperands,
+            rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
+      else
+        newOp = rewriter.create<fir::DispatchOp>(
+            loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+            op.getOperands()[0], newOperands, nullptr);
     }
+
     if (isResultBuiltinCPtr) {
       mlir::Value save = saveResult.getMemref();
-      auto module = callOp->getParentOfType<mlir::ModuleOp>();
+      auto module = op->template getParentOfType<mlir::ModuleOp>();
       fir::KindMapping kindMap = fir::getKindMapping(module);
       FirOpBuilder builder(rewriter, kindMap);
       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
           builder, loc, save, result.getType());
-      rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
+      rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
     }
-    callOp->dropAllReferences();
-    rewriter.eraseOp(callOp);
+    op->dropAllReferences();
+    rewriter.eraseOp(op);
     return mlir::success();
   }
 
@@ -289,17 +317,11 @@ class AbstractResultOptTemplate : public PassBase<Pass> {
       return true;
     });
     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
-      if (dispatch->getNumResults() != 1)
-        return true;
-      auto resultType = dispatch->getResult(0).getType();
-      if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
-        TODO(dispatch.getLoc(), "dispatchOp with abstract results");
-        return false;
-      }
-      return true;
+      return !hasAbstractResult(dispatch.getFunctionType());
     });
 
-    patterns.insert<CallOpConversion>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
     patterns.insert<SaveResultOpConversion>(context);
     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
     if (mlir::failed(

diff  --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 14c59a6569744..374c0d18753bb 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -244,6 +244,25 @@ func.func @_QPtest_return_cptr() {
   // FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64>
 }
 
+// FUNC-REF-LABEL: func @dispatch(
+// FUNC-REF-SAME:    %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+// FUNC-BOX-LABEL: func @dispatch(
+// FUNC-BOX-SAME:    %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+func.func @dispatch(%arg0: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) {
+  %buffer = fir.alloca !fir.type<t{x:f32}>
+  %res = fir.dispatch "ret_array"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.type<t{x:f32}> {pass_arg_pos = 0 : i32}
+  fir.save_result %res to %buffer : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>
+  return
+  // FUNC-REF: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+  // FUNC-REF: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[buffer]], %[[ARG0]] : !fir.ref<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // FUNC-REF-NOT: fir.save_result
+
+  // FUNC-BOX: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+  // FUNC-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref<!fir.type<t{x:f32}>>) -> !fir.box<!fir.type<t{x:f32}>>
+  // FUNC-BOX: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[box]], %[[ARG0]] : !fir.box<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // FUNC-BOX-NOT: fir.save_result
+}
+
 // ------------------------ Test fir.address_of rewrite ------------------------
 
 func.func private @takesfuncarray((i32) -> !fir.array<?xf32>)


        


More information about the flang-commits mailing list