[flang-commits] [flang] 5a0722e - [flang] Update fir.dispatch operation

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Oct 6 09:12:06 PDT 2022


Author: Valentin Clement
Date: 2022-10-06T18:11:56+02:00
New Revision: 5a0722e0469894564dd9b0ee1d4f8c291d09c776

URL: https://github.com/llvm/llvm-project/commit/5a0722e0469894564dd9b0ee1d4f8c291d09c776
DIFF: https://github.com/llvm/llvm-project/commit/5a0722e0469894564dd9b0ee1d4f8c291d09c776.diff

LOG: [flang] Update fir.dispatch operation

Update the `fir.dispatch` operation to prepare
the lowering part. `nopass` and `pass_arg_pos` attributes
are added in the arguments list so accessors are generated
by MLIR tablegen. A verifier is added as well as some tests.

This patch is part of the implementation of the polymorphic
entities.
https://github.com/llvm/llvm-project/blob/main/flang/docs/PolymorphicEntities.md

Reviewed By: jeanPerier, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D135358

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    flang/test/Fir/Todo/dispatch.fir
    flang/test/Fir/fir-ops.fir
    flang/test/Fir/invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 5b835c381c03..3dd74db488fc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2327,22 +2327,30 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
 
   let description = [{
     Perform a dynamic dispatch on the method name via the dispatch table
-    associated with the first argument.  The attribute 'pass_arg_pos' can be
-    used to select a dispatch argument other than the first one.
+    associated with the first operand.  The attribute `pass_arg_pos` can be
+    used to select a dispatch operand other than the first one.  The absence of
+    `pass_arg_pos` attribute means nopass.
 
     ```mlir
-      %r = fir.dispatch methodA(%o) : (!fir.box<none>) -> i32
+      // fir.dispatch with no attribute.
+      %r = fir.dispatch "methodA"(%o) : (!fir.class<T>) -> i32
+
+      // fir.dispatch with the `pass_arg_pos` attribute.
+      %r = fir.dispatch "methodA"(%o, %o) : (!fir.class<T>, !fir.class<T>) -> i32 {pass_arg_pos = 0 : i32}
     ```
   }];
 
   let arguments = (ins
     StrAttr:$method,
-    fir_BoxType:$object,
-    Variadic<AnyType>:$args
+    fir_ClassType:$object,
+    Variadic<AnyType>:$args,
+    OptionalAttr<I32Attr>:$pass_arg_pos
   );
 
   let results = (outs Variadic<AnyType>);
 
+  let hasVerifier = 1;
+
   let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
@@ -2350,14 +2358,10 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
     operand_range getArgOperands() {
       return {arg_operand_begin(), arg_operand_end()};
     }
-    // operand[0] is the object (of box type)
+    // operand[0] is the object (of class type)
     operand_iterator arg_operand_begin() { return operand_begin() + 1; }
     operand_iterator arg_operand_end() { return operand_end(); }
-    static constexpr llvm::StringRef getPassArgAttrName() {
-      return "pass_arg_pos";
-    }
     static constexpr llvm::StringRef getMethodAttrNameStr() { return "method"; }
-    unsigned passArgPos();
   }];
 }
 

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 1c1e1a6e9c7f..782f50d1ff5f 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1038,6 +1038,20 @@ mlir::LogicalResult fir::CoordinateOp::verify() {
 // DispatchOp
 //===----------------------------------------------------------------------===//
 
+mlir::LogicalResult fir::DispatchOp::verify() {
+  // pass_arg_pos indexes getArgOperands() (the operands after the object) and
+  // is unsigned; test `>= size()` since `size() - 1` wraps when args is empty.
+  if (auto passArgPos = getPassArgPos()) {
+    if (*passArgPos >= getArgOperands().size())
+      return emitOpError(
+          "pass_arg_pos must be smaller than the number of operands");
+    // Operand pointed by pass_arg_pos must have polymorphic type.
+    if (!fir::isPolymorphicType(getArgOperands()[*passArgPos].getType()))
+      return emitOpError("pass_arg_pos must be a polymorphic operand");
+  }
+  return mlir::success();
+}
+
 mlir::FunctionType fir::DispatchOp::getFunctionType() {
   return mlir::FunctionType::get(getContext(), getOperandTypes(),
                                  getResultTypes());
@@ -1060,11 +1074,11 @@ mlir::ParseResult fir::DispatchOp::parse(mlir::OpAsmParser &parser,
                         parser.getBuilder().getStringAttr(calleeName));
   }
   if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(calleeType) ||
       parser.addTypesToList(calleeType.getResults(), result.types) ||
       parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc,
-                             result.operands))
+                             result.operands) ||
+      parser.parseOptionalAttrDict(result.attributes))
     return mlir::failure();
   return mlir::success();
 }
@@ -1079,6 +1093,9 @@ void fir::DispatchOp::print(mlir::OpAsmPrinter &p) {
   p << ") : ";
   p.printFunctionalType(getOperation()->getOperandTypes(),
                         getOperation()->getResultTypes());
+  p.printOptionalAttrDict(getOperation()->getAttrs(),
+                          {mlir::SymbolTable::getSymbolAttrName(),
+                           fir::DispatchOp::getMethodAttrNameStr()});
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Fir/Todo/dispatch.fir b/flang/test/Fir/Todo/dispatch.fir
index e5b90dad9f73..dcc24bd818a3 100644
--- a/flang/test/Fir/Todo/dispatch.fir
+++ b/flang/test/Fir/Todo/dispatch.fir
@@ -3,8 +3,8 @@
 // Test `fir.dispatch` conversion to llvm.
 // Not implemented yet.
 
-func.func @dispatch(%arg0: !fir.box<!fir.type<derived3{f:f32}>>) {
-// CHECK: not yet implemented: fir.dispatch codegen
-  %0 = fir.dispatch "method"(%arg0) : (!fir.box<!fir.type<derived3{f:f32}>>) -> i32
+func.func @dispatch(%arg0: !fir.class<!fir.type<derived3{f:f32}>>) {
+// CHECK: not yet implemented: fir.class type conversion
+  %0 = fir.dispatch "method"(%arg0) : (!fir.class<!fir.type<derived3{f:f32}>>) -> i32
   return
 }

diff  --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index cf895842753e..66e46e340ab6 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -114,14 +114,14 @@ func.func @instructions() {
   %25 = fir.insert_value %22, %cf1, ["f", !fir.type<derived{f:f32}>] : (!fir.type<derived{f:f32}>, f32) -> !fir.type<derived{f:f32}>
   %26 = fir.len_param_index f, !fir.type<derived3{f:f32}>
 
-// CHECK: [[VAL_31:%.*]] = fir.call @box3() : () -> !fir.box<!fir.type<derived3{f:f32}>>
-// CHECK: [[VAL_32:%.*]] = fir.dispatch "method"([[VAL_31]]) : (!fir.box<!fir.type<derived3{f:f32}>>) -> i32
+// CHECK: [[VAL_31:%.*]] = fir.call @box3() : () -> !fir.class<!fir.type<derived3{f:f32}>>
+// CHECK: [[VAL_32:%.*]] = fir.dispatch "method"([[VAL_31]]) : (!fir.class<!fir.type<derived3{f:f32}>>) -> i32
 // CHECK: [[VAL_33:%.*]] = fir.convert [[VAL_32]] : (i32) -> i64
 // CHECK: [[VAL_34:%.*]] = fir.gentypedesc !fir.type<x>
 // CHECK: fir.call @user_tdesc([[VAL_34]]) : (!fir.tdesc<!fir.type<x>>) -> ()
 // CHECK: [[VAL_35:%.*]] = fir.no_reassoc [[VAL_33]] : i64
-  %27 = fir.call @box3() : () -> !fir.box<!fir.type<derived3{f:f32}>>
-  %28 = fir.dispatch "method"(%27) : (!fir.box<!fir.type<derived3{f:f32}>>) -> i32
+  %27 = fir.call @box3() : () -> !fir.class<!fir.type<derived3{f:f32}>>
+  %28 = fir.dispatch "method"(%27) : (!fir.class<!fir.type<derived3{f:f32}>>) -> i32
   %29 = fir.convert %28 : (i32) -> i64
   %30 = fir.gentypedesc !fir.type<x>
   fir.call @user_tdesc(%30) : (!fir.tdesc<!fir.type<x>>) -> ()
@@ -309,12 +309,12 @@ func.func @bar_select_rank(%arg : i32, %arg2 : i32) -> i32 {
 
 // CHECK: ^bb5:
 // CHECK: [[VAL_99:%.*]] = arith.constant 0 : i32
-// CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.box<!fir.type<derived3{f:f32}>>
-// CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.box<!fir.type<derived3{f:f32}>>) -> ()
+// CHECK: [[VAL_100:%.*]] = fir.call @get_method_box() : () -> !fir.class<!fir.type<derived3{f:f32}>>
+// CHECK: fir.dispatch "method"([[VAL_100]]) : (!fir.class<!fir.type<derived3{f:f32}>>) -> ()
 ^bb5 :
   %zero = arith.constant 0 : i32
-  %7 = fir.call @get_method_box() : () -> !fir.box<!fir.type<derived3{f:f32}>>
-  fir.dispatch method(%7) : (!fir.box<!fir.type<derived3{f:f32}>>) -> ()
+  %7 = fir.call @get_method_box() : () -> !fir.class<!fir.type<derived3{f:f32}>>
+  fir.dispatch method(%7) : (!fir.class<!fir.type<derived3{f:f32}>>) -> ()
 
 // CHECK: return [[VAL_99]] : i32
 // CHECK: }
@@ -805,3 +805,17 @@ func.func @array_amend_ops(%a : !fir.ref<!fir.array<?x?xf32>>) {
   // CHECK: %{{.*}} = fir.array_amend %{{.*}}, %{{.*}} : (!fir.array<?x?xf32>, !fir.ref<f32>) -> !fir.array<?x?xf32>
   return
 }
+// Parse/print round-trip of fir.dispatch with and without pass_arg_pos.
+func.func private @dispatch(%arg0: !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, %arg1: i32) -> () {
+  // CHECK-LABEL: func.func private @dispatch(
+  // CHECK-SAME: %[[CLASS:.*]]: !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, %[[INTARG:.*]]: i32)
+  fir.dispatch "proc1"(%arg0, %arg0) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 0 : i32}
+  // CHECK: fir.dispatch "proc1"(%[[CLASS]], %[[CLASS]]) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 0 : i32}
+  // No pass_arg_pos here; the extra {nopass} attribute must be printed back as-is.
+  fir.dispatch "proc2"(%arg0) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {nopass}
+  // CHECK: fir.dispatch "proc2"(%[[CLASS]]) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {nopass}
+  // pass_arg_pos indexes the operands after the object: (%arg1, %arg0)[1] is %arg0.
+  fir.dispatch "proc3"(%arg0, %arg1, %arg0) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, i32, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 1 : i32}
+  // CHECK: fir.dispatch "proc3"(%[[CLASS]], %[[INTARG]], %[[CLASS]]) : (!fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>, i32, !fir.class<!fir.type<dispatch_derived1{a:i32,b:i32}>>) -> () {pass_arg_pos = 1 : i32}
+  return
+}

diff  --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index ee384fc4a5af..6bf0612887f6 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -756,3 +756,19 @@ func.func @foo(%arg0: !fir.ref<!fir.array<30x!fir.type<t{c:!fir.array<20xi32>}>>
   return
 }
 func.func private @ifoo(!fir.ref<f32>) -> i32
+
+// -----
+
+func.func private @dispatch(%arg0: !fir.class<!fir.type<derived{a:i32,b:i32}>>) -> () {
+  // expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be smaller than the number of operands}}
+  fir.dispatch "proc1"(%arg0, %arg0) : (!fir.class<!fir.type<derived{a:i32,b:i32}>>, !fir.class<!fir.type<derived{a:i32,b:i32}>>) -> () {pass_arg_pos = 1 : i32}
+  return
+}
+
+// -----
+
+func.func private @dispatch(%arg0: !fir.class<!fir.type<derived{a:i32,b:i32}>>, %arg1: i32) -> () {
+  // expected-error@+1 {{'fir.dispatch' op pass_arg_pos must be a polymorphic operand}}
+  fir.dispatch "proc1"(%arg0, %arg0, %arg1) : (!fir.class<!fir.type<derived{a:i32,b:i32}>>, !fir.class<!fir.type<derived{a:i32,b:i32}>>, i32) -> () {pass_arg_pos = 1 : i32}
+  return
+}


        


More information about the flang-commits mailing list