[flang-commits] [flang] 1ea66ee - [flang] Adapt target rewrite for fir.dispatch operation

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Nov 28 08:36:09 PST 2022


Author: Valentin Clement
Date: 2022-11-28T17:36:03+01:00
New Revision: 1ea66eefec8dcb1c7edbc47489489250dd0f1996

URL: https://github.com/llvm/llvm-project/commit/1ea66eefec8dcb1c7edbc47489489250dd0f1996
DIFF: https://github.com/llvm/llvm-project/commit/1ea66eefec8dcb1c7edbc47489489250dd0f1996.diff

LOG: [flang] Adapt target rewrite for fir.dispatch operation

Handle rewriting of fir.dispatch operations that take complex arguments
or return a complex value.

Handling of sret will be done in a separate patch.

Reviewed By: jeanPerier, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D138820

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
    flang/test/Fir/target-rewrite-complex.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 9bf51cc6ee1a..1ad2526bd04b 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -209,6 +209,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
         newOpers.push_back(callOp.getOperand(0));
         dropFront = 1;
       }
+    } else {
+      dropFront = 1; // First operand is the polymorphic object.
     }
 
     // Determine the rewrite function, `wrap`, for the result value.
@@ -231,6 +233,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
     llvm::SmallVector<mlir::Type> trailingInTys;
     llvm::SmallVector<mlir::Value> trailingOpers;
+    unsigned passArgShift = 0;
     for (auto e : llvm::enumerate(
              llvm::zip(fnTy.getInputs().drop_front(dropFront),
                        callOp.getOperands().drop_front(dropFront)))) {
@@ -314,6 +317,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             }
           })
           .Default([&](mlir::Type ty) {
+            if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
+              if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index)
+                passArgShift = newOpers.size() - *callOp.getPassArgPos();
+            }
             newInTys.push_back(ty);
             newOpers.push_back(oper);
           });
@@ -338,8 +345,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       else
         replaceOp(callOp, newCall.getResults());
     } else {
-      // A is fir::DispatchOp
-      TODO(loc, "dispatch not implemented");
+      fir::DispatchOp dispatchOp = rewriter->create<A>(
+          loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
+          callOp.getOperands()[0], newOpers,
+          rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift));
+      if (wrap)
+        replaceOp(callOp, (*wrap)(dispatchOp.getOperation()));
+      else
+        replaceOp(callOp, dispatchOp.getResults());
     }
   }
 

diff  --git a/flang/test/Fir/target-rewrite-complex.fir b/flang/test/Fir/target-rewrite-complex.fir
index bf7f618fc490..d2ae44075af2 100644
--- a/flang/test/Fir/target-rewrite-complex.fir
+++ b/flang/test/Fir/target-rewrite-complex.fir
@@ -122,12 +122,12 @@ func.func @returncomplex8() -> !fir.complex<8> {
 func.func private @paramcomplex4(!fir.complex<4>) -> ()
 
 // Test that we rewrite calls to functions that return or accept complex<4>.
-// I32-LABEL: func @callcomplex4()
-// X64-LABEL: func @callcomplex4()
-// AARCH64-LABEL: func @callcomplex4()
-// PPC-LABEL: func @callcomplex4()
-// SPARCV9-LABEL: func @callcomplex4()
-func.func @callcomplex4() {
+// I32-LABEL: func @callcomplex4
+// X64-LABEL: func @callcomplex4
+// AARCH64-LABEL: func @callcomplex4
+// PPC-LABEL: func @callcomplex4
+// SPARCV9-LABEL: func @callcomplex4
+func.func @callcomplex4(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {
 
   // I32: [[RES:%[0-9A-Za-z]+]] = fir.call @returncomplex4() : () -> i64
   // X64: [[RES:%[0-9A-Za-z]+]] = fir.call @returncomplex4() : () -> !fir.vector<2:!fir.real<4>>
@@ -181,6 +181,69 @@ func.func @callcomplex4() {
   // SPARCV9: [[B:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [1 : i32] : (!fir.complex<4>) -> !fir.real<4>
   // SPARCV9: fir.call @paramcomplex4([[A]], [[B]]) : (!fir.real<4>, !fir.real<4>) -> ()
   fir.call @paramcomplex4(%1) : (!fir.complex<4>) -> ()
+
+  // I32: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> i64 {pass_arg_pos = 0 : i32}
+  // X64: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.vector<2:!fir.real<4>> {pass_arg_pos = 0 : i32}
+  // AARCH64: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> tuple<!fir.real<4>, !fir.real<4>> {pass_arg_pos = 0 : i32}
+  // PPC: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> tuple<!fir.real<4>, !fir.real<4>> {pass_arg_pos = 0 : i32}
+  // SPARCV9: [[RES:%[0-9A-Za-z]+]] = fir.dispatch "ret_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> tuple<!fir.real<4>, !fir.real<4>> {pass_arg_pos = 0 : i32}
+  %2 = fir.dispatch "ret_complex"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.complex<4> {pass_arg_pos = 0 : i32}
+
+  // I32: [[ADDRI64:%[0-9A-Za-z]+]] = fir.alloca i64
+  // I32: fir.store [[RES]] to [[ADDRI64]] : !fir.ref<i64>
+  // I32: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRI64]] : (!fir.ref<i64>) -> !fir.ref<!fir.complex<4>>
+  // I32: [[C:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref<!fir.complex<4>>
+  // I32: [[ADDRC2:%[0-9A-Za-z]+]] = fir.alloca !fir.complex<4>
+  // I32: fir.store [[C]] to [[ADDRC2]] : !fir.ref<!fir.complex<4>>
+  // I32: [[T:%[0-9A-Za-z]+]] = fir.convert [[ADDRC2]] : (!fir.ref<!fir.complex<4>>) -> !fir.ref<tuple<!fir.real<4>, !fir.real<4>>>
+  // I32: fir.dispatch "with_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, [[T]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>, !fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) {pass_arg_pos = 0 : i32}
+
+  // X64: [[ADDRV:%[0-9A-Za-z]+]] = fir.alloca !fir.vector<2:!fir.real<4>>
+  // X64: fir.store [[RES]] to [[ADDRV]] : !fir.ref<!fir.vector<2:!fir.real<4>>>
+  // X64: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRV]] : (!fir.ref<!fir.vector<2:!fir.real<4>>>) -> !fir.ref<!fir.complex<4>>
+  // X64: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref<!fir.complex<4>>
+  // X64: [[ADDRV2:%[0-9A-Za-z]+]] = fir.alloca !fir.vector<2:!fir.real<4>>
+  // X64: [[ADDRC2:%[0-9A-Za-z]+]] = fir.convert [[ADDRV2]] : (!fir.ref<!fir.vector<2:!fir.real<4>>>) -> !fir.ref<!fir.complex<4>>
+  // X64: fir.store [[V]] to [[ADDRC2]] : !fir.ref<!fir.complex<4>>
+  // X64: [[VRELOADED:%[0-9A-Za-z]+]] = fir.load [[ADDRV2]] : !fir.ref<!fir.vector<2:!fir.real<4>>>
+  // X64: fir.dispatch "with_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, [[VRELOADED]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>, !fir.vector<2:!fir.real<4>>) {pass_arg_pos = 0 : i32}
+
+  // AARCH64: [[ADDRT:%[0-9A-Za-z]+]] = fir.alloca tuple<!fir.real<4>, !fir.real<4>>
+  // AARCH64: fir.store [[RES]] to [[ADDRT]] : !fir.ref<tuple<!fir.real<4>, !fir.real<4>>>
+  // AARCH64: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRT]] : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> !fir.ref<!fir.complex<4>>
+  // AARCH64: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref<!fir.complex<4>>
+  // AARCH64: [[ADDRARR:%[0-9A-Za-z]+]] = fir.alloca !fir.array<2x!fir.real<4>>
+  // AARCH64: [[ADDRC2:%[0-9A-Za-z]+]] = fir.convert [[ADDRARR]] : (!fir.ref<!fir.array<2x!fir.real<4>>>) -> !fir.ref<!fir.complex<4>>
+  // AARCH64: fir.store [[V]] to [[ADDRC2]] : !fir.ref<!fir.complex<4>>
+  // AARCH64: [[ARR:%[0-9A-Za-z]+]] = fir.load [[ADDRARR]] : !fir.ref<!fir.array<2x!fir.real<4>>>
+  // AARCH64: fir.dispatch "with_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, [[ARR]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>, !fir.array<2x!fir.real<4>>) {pass_arg_pos = 0 : i32}
+
+  // PPC: [[ADDRT:%[0-9A-Za-z]+]] = fir.alloca tuple<!fir.real<4>, !fir.real<4>>
+  // PPC: fir.store [[RES]] to [[ADDRT]] : !fir.ref<tuple<!fir.real<4>, !fir.real<4>>>
+  // PPC: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRT]] : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> !fir.ref<!fir.complex<4>>
+  // PPC: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref<!fir.complex<4>>
+  // PPC: [[A:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [0 : i32] : (!fir.complex<4>) -> !fir.real<4>
+  // PPC: [[B:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [1 : i32] : (!fir.complex<4>) -> !fir.real<4>
+  // PPC: fir.dispatch "with_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, [[A]], [[B]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>, !fir.real<4>, !fir.real<4>) {pass_arg_pos = 0 : i32}
+
+  // SPARCV9: [[ADDRT:%[0-9A-Za-z]+]] = fir.alloca tuple<!fir.real<4>, !fir.real<4>>
+  // SPARCV9: fir.store [[RES]] to [[ADDRT]] : !fir.ref<tuple<!fir.real<4>, !fir.real<4>>>
+  // SPARCV9: [[ADDRC:%[0-9A-Za-z]+]] = fir.convert [[ADDRT]] : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> !fir.ref<!fir.complex<4>>
+  // SPARCV9: [[V:%[0-9A-Za-z]+]] = fir.load [[ADDRC]] : !fir.ref<!fir.complex<4>>
+  // SPARCV9: [[A:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [0 : i32] : (!fir.complex<4>) -> !fir.real<4>
+  // SPARCV9: [[B:%[0-9A-Za-z]+]] = fir.extract_value [[V]], [1 : i32] : (!fir.complex<4>) -> !fir.real<4>
+  // SPARCV9: fir.dispatch "with_complex"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, [[A]], [[B]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>, !fir.real<4>, !fir.real<4>) {pass_arg_pos = 0 : i32}
+
+  fir.dispatch "with_complex"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0, %2 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>, !fir.complex<4>) {pass_arg_pos = 0 : i32}
+
+
+  // I32: fir.dispatch "with_complex2"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, %{{.*}} : !fir.ref<tuple<!fir.real<4>, !fir.real<4>>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // X64: fir.dispatch "with_complex2"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, %{{.*}} : !fir.vector<2:!fir.real<4>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // AARCH64: fir.dispatch "with_complex2"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, %{{.*}} : !fir.array<2x!fir.real<4>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // PPC: fir.dispatch "with_complex2"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, %{{.*}}, %{{.*}} : !fir.real<4>, !fir.real<4>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 2 : i32}
+  // SPARCV9: fir.dispatch "with_complex2"(%{{.*}} : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%{{.*}}, %{{.*}}, %{{.*}} : !fir.real<4>, !fir.real<4>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 2 : i32}
+  fir.dispatch "with_complex2"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%2, %arg0 : !fir.complex<4>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  
   return
 }
 


        


More information about the flang-commits mailing list