[flang-commits] [flang] 9471637 - [flang][hlfir] Add hlfir.dot_product intrinsic

Jacob Crawley via flang-commits flang-commits at lists.llvm.org
Wed Jun 7 04:45:55 PDT 2023

Author: Jacob Crawley
Date: 2023-06-07T11:42:14Z
New Revision: 9471637f3c7c67b2c6a5badac9df5cd48b605dc7

URL: https://github.com/llvm/llvm-project/commit/9471637f3c7c67b2c6a5badac9df5cd48b605dc7
DIFF: https://github.com/llvm/llvm-project/commit/9471637f3c7c67b2c6a5badac9df5cd48b605dc7.diff

LOG: [flang][hlfir] Add hlfir.dot_product intrinsic

Adds a new HLFIR operation for the DOT_PRODUCT intrinsic, following
the design set out in flang/docs/HighLevel.md. This patch includes all
the changes needed to create the new HLFIR operation and to lower it to
a FIR runtime call.

Differential Revision: https://reviews.llvm.org/D152252




diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 142a70c639127..2dd85a2c5c181 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -430,6 +430,29 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
   let hasVerifier = 1;
+def hlfir_DotProductOp : hlfir_Op<"dot_product",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "DOT_PRODUCT transformational intrinsic";
+  let description = [{
+    Dot product of two vectors (rank-1 arrays). Both arguments must be of
+    numerical type, or both must be of logical type. The result is a scalar
+    of numerical type in the first case and of logical type in the second.
+  }];
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+  let results = (outs AnyFortranValue);
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
 def hlfir_MatmulOp : hlfir_Op<"matmul",
     [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
   let summary = "MATMUL transformational intrinsic";

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 66af19b94e78d..604291cdbec6d 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -1488,6 +1488,15 @@ genHLFIRIntrinsicRefCore(PreparedActualArguments &loweredActuals,
     return buildReductionIntrinsic(loweredActuals, loc, builder, callContext,
                                    buildAllOperation, false);
+  if (intrinsicName == "dot_product") {
+    // The two vector arguments are taken from actual-argument positions 0
+    // and 1; the result type is derived from the statement's result type.
+    llvm::SmallVector<mlir::Value> actuals = getOperandVector(loweredActuals);
+    mlir::Value vectorA = actuals[0];
+    mlir::Value vectorB = actuals[1];
+    mlir::Type resultTy = computeResultType(vectorA, *callContext.resultType);
+    auto dotProductOp = builder.create<hlfir::DotProductOp>(loc, resultTy,
+                                                            vectorA, vectorB);
+    return {hlfir::EntityWithAttributes{dotProductOp.getResult()}};
+  }
   // TODO add hlfir operations for other transformational intrinsics here

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index adf8b72993e4c..c094b66338512 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -669,6 +669,53 @@ mlir::LogicalResult hlfir::SumOp::verify() {
   return verifyNumericalReductionOp<hlfir::SumOp *>(this);
+// DotProductOp
+// Verify that both operands are rank-1 arrays, that their extents agree
+// whenever both are known at compile time, that they are either both logical
+// or both non-logical, and that the result is a scalar whose logical or
+// numerical category matches the operands.
+mlir::LogicalResult hlfir::DotProductOp::verify() {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(lhs.getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(rhs.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  std::size_t lhsRank = lhsShape.size();
+  std::size_t rhsRank = rhsShape.size();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  mlir::Type resultTy = getResult().getType();
+  if ((lhsRank != 1) || (rhsRank != 1))
+    return emitOpError("both arrays must have rank 1");
+  // An extent printed as '?' is only known at runtime. A size mismatch can
+  // only be diagnosed statically when both extents are constants; pairing a
+  // constant extent with an unknown one (e.g. expr<2xi32> and expr<?xi32>)
+  // must be accepted.
+  const int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  int64_t lhsSize = lhsShape[0];
+  int64_t rhsSize = rhsShape[0];
+  if ((lhsSize != unknownExtent) && (rhsSize != unknownExtent) &&
+      (lhsSize != rhsSize))
+    return emitOpError("both arrays must have the same size");
+  // Operands are either both logical (logical result) or both non-logical
+  // (numerical result).
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(resultTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+  if (!hlfir::isFortranScalarNumericalType(resultTy) &&
+      !mlir::isa<fir::LogicalType>(resultTy))
+    return emitOpError(
+        "the result must be of scalar numerical or logical type");
+  return mlir::success();
 // MatmulOp

diff  --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 0ffb2ac9ca0cb..ac7eafc862fd6 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -277,6 +277,38 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
+struct DotProductOpConversion
+    : public HlfirIntrinsicConversion<hlfir::DotProductOp> {
+  using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
+  /// Replace hlfir.dot_product with the code produced by the generic
+  /// "dot_product" intrinsic call generator, then hand the produced value to
+  /// processReturnValue to finish the rewrite.
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::DotProductOp dotProduct,
+                  mlir::PatternRewriter &rewriter) const override {
+    fir::KindMapping kindMap{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMap};
+    mlir::Location loc = dotProduct->getLoc();
+    // Each vector operand is requested with its own type as desired type.
+    mlir::Value vectorA = dotProduct.getLhs();
+    mlir::Value vectorB = dotProduct.getRhs();
+    llvm::SmallVector<IntrinsicArgument, 2> inArgs{
+        {vectorA, vectorA.getType()}, {vectorB, vectorB.getType()}};
+    auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product");
+    llvm::SmallVector<fir::ExtendedValue, 2> loweredArgs =
+        lowerArguments(dotProduct, inArgs, rewriter, argLowering);
+    // The intrinsic generator is asked for the element type of the op result.
+    mlir::Type resultEleTy =
+        hlfir::getFortranElementType(dotProduct.getType());
+    auto [dotExv, mustBeFreed] = fir::genIntrinsicCall(
+        builder, loc, "dot_product", resultEleTy, loweredArgs);
+    processReturnValue(dotProduct, dotExv, mustBeFreed, builder, rewriter);
+    return mlir::success();
+  }
 class TransposeOpConversion
     : public HlfirIntrinsicConversion<hlfir::TransposeOp> {
   using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
@@ -356,14 +388,15 @@ class LowerHLFIRIntrinsics
     mlir::RewritePatternSet patterns(context);
     patterns.insert<MatmulOpConversion, MatmulTransposeOpConversion,
                     AllOpConversion, AnyOpConversion, SumOpConversion,
-                    ProductOpConversion, TransposeOpConversion>(context);
+                    ProductOpConversion, TransposeOpConversion,
+                    DotProductOpConversion>(context);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
                            mlir::func::FuncDialect, fir::FIROpsDialect,
     target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
                         hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
-                        hlfir::AllOp>();
+                        hlfir::AllOp, hlfir::DotProductOp>();
         [](mlir::Operation *) { return true; });
     if (mlir::failed(

diff  --git a/flang/test/HLFIR/dot_product-lowering.fir b/flang/test/HLFIR/dot_product-lowering.fir
new file mode 100644
index 0000000000000..efc113087fea0
--- /dev/null
+++ b/flang/test/HLFIR/dot_product-lowering.fir
@@ -0,0 +1,80 @@
+// Test hlfir.dot_product operation lowering to fir runtime call
+// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s
+// Integer assumed-shape vectors: the op lowers to a call to
+// _FortranADotProductInteger4 taking both boxes converted to !fir.box<none>.
+func.func @_QPdot_product1(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<i32> {fir.bindc_name = "res"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product1Elhs"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product1Eres"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product1Erhs"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+  %3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
+  hlfir.assign %3 to %1#0 : i32, !fir.ref<i32>
+  return
+// CHECK-LABEL: func.func @_QPdot_product1(
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+// CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+// CHECK-DAG:     %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+// CHECK-DAG:     %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranADotProductInteger4(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-NEXT:    hlfir.assign %[[NONE]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// Logical vectors: the op lowers to _FortranADotProductLogical, whose i1
+// result is assigned directly to the !fir.logical<4> variable.
+func.func @_QPdot_product2(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "lhs"}, %arg1: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<!fir.logical<4>> {fir.bindc_name = "res"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFdot_product2Elhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
+  %1:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product2Eres"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFdot_product2Erhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
+  %3 = hlfir.dot_product %0#0 %2#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+  hlfir.assign %3 to %1#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+  return
+// CHECK-LABEL: func.func @_QPdot_product2(
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.ref<!fir.logical<4>> {fir.bindc_name = "res"}
+// CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+// CHECK-DAG:     %[[LHS_ARG:.*]] = fir.convert %[[LHS_VAR]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
+// CHECK-DAG:     %[[RHS_ARG:.*]] = fir.convert %[[RHS_VAR]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranADotProductLogical(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-NEXT:    hlfir.assign %[[NONE]] to %[[RES_VAR]]#0 : i1, !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// Explicit-shape (5-element) vectors passed by reference: each operand is
+// first boxed with fir.embox before the _FortranADotProductInteger4 call.
+ func.func @_QPdot_product3(%arg0: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "lhs"}, %arg1: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "rhs"}, %arg2: !fir.ref<i32> {fir.bindc_name = "res"}) {
+  %c5 = arith.constant 5 : index
+  %0 = fir.shape %c5 : (index) -> !fir.shape<1>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFdot_product3Elhs"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+  %2:2 = hlfir.declare %arg2 {uniq_name = "_QFdot_product3Eres"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %c5_0 = arith.constant 5 : index
+  %3 = fir.shape %c5_0 : (index) -> !fir.shape<1>
+  %4:2 = hlfir.declare %arg1(%3) {uniq_name = "_QFdot_product3Erhs"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+  %5 = hlfir.dot_product %1#0 %4#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>) -> i32
+  hlfir.assign %5 to %2#0 : i32, !fir.ref<i32>
+  return
+// CHECK-LABEL: func.func @_QPdot_product3(
+// CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "lhs"}
+// CHECK:           %[[ARG1:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "rhs"}
+// CHECK:           %[[ARG2:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+// CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[ARG0]]
+// CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[ARG1]]
+// CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
+// CHECK-DAG:     %[[LHS_BOX:.*]] = fir.embox %[[LHS_VAR]]#1
+// CHECK-DAG:     %[[RHS_BOX:.*]] = fir.embox %[[RHS_VAR]]#1
+// CHECK-DAG:     %[[LHS_ARG:.*]] = fir.convert %[[LHS_BOX]] : (!fir.box<!fir.array<5xi32>>) -> !fir.box<none>
+// CHECK-DAG:     %[[RHS_ARG:.*]] = fir.convert %[[RHS_BOX]] : (!fir.box<!fir.array<5xi32>>) -> !fir.box<none>
+// CHECK:         %[[NONE:.*]] = fir.call @_FortranADotProductInteger4(%[[LHS_ARG]], %[[RHS_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]])
+// CHECK-NEXT:    hlfir.assign %[[NONE]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
\ No newline at end of file

diff  --git a/flang/test/HLFIR/dot_product.fir b/flang/test/HLFIR/dot_product.fir
new file mode 100644
index 0000000000000..78b293eeff5c8
--- /dev/null
+++ b/flang/test/HLFIR/dot_product.fir
@@ -0,0 +1,51 @@
+// Test hlfir.dot_product operation parse, verify (no errors), and unparse
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+// arguments are expressions of known shape
+func.func @dot_product0(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // No fastmath attribute is written, so the default (none) applies and the
+  // printed form carries no attribute dictionary.
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> i32
+  return
+// CHECK-LABEL: func.func @dot_product0
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2xi32>
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> i32
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// arguments are expressions of unknown (dynamic) extent
+func.func @dot_product1(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  // Both extents are '?' (not known statically); the op must still verify
+  // with a scalar i32 result.
+  %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> i32
+  return
+// CHECK-LABEL: func.func @dot_product1
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?xi32>
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> i32
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// arguments are boxed arrays
+func.func @dot_product2(%arg0: !fir.box<!fir.array<2xi32>>, %arg1: !fir.box<!fir.array<2xi32>>) {
+  // Boxed arrays are accepted as operands, not only !hlfir.expr values.
+  %res = hlfir.dot_product %arg0 %arg1 : (!fir.box<!fir.array<2xi32>>, !fir.box<!fir.array<2xi32>>) -> i32
+  return
+// CHECK-LABEL: func.func @dot_product2
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<2xi32>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<2xi32>>
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<2xi32>>, !fir.box<!fir.array<2xi32>>) -> i32
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// arguments are logical
+func.func @dot_product3(%arg0: !fir.box<!fir.array<2x!fir.logical<4>>>, %arg1: !fir.box<!fir.array<2x!fir.logical<4>>>) {
+  // Logical operands require a logical result (here !fir.logical<4>).
+  %res = hlfir.dot_product %arg0 %arg1 : (!fir.box<!fir.array<2x!fir.logical<4>>>, !fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
+  return
+// CHECK-LABEL: func.func @dot_product3
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.dot_product %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<2x!fir.logical<4>>>, !fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
\ No newline at end of file

diff  --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index 8dc5679346bc1..01bccdf80428b 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -508,6 +508,41 @@ func.func @bad_matmul8(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+// -----
+func.func @bad_dot_product1(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op both arrays must have rank 1}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x3xi32>) -> i32
+  return
+// -----
+func.func @bad_dot_product2(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<3xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op both arrays must have the same size}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<3xi32>) -> i32
+  return
+// -----
+func.func @bad_dot_product3(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x!fir.logical<4>>) {
+  // expected-error@+1 {{'hlfir.dot_product' op if one array is logical, so should the other be}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x!fir.logical<4>>) -> i32
+  return
+// -----
+func.func @bad_dot_product4(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op the result type should be a logical only if the argument types are logical}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> !fir.logical<4>
+  return
+// -----
+func.func @bad_dot_product5(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.dot_product' op the result must be of scalar numerical or logical type}}
+  %0 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<i32>
+  return
 // -----
 func.func @bad_transpose1(%arg0: !hlfir.expr<2xi32>) {
   // expected-error at +1 {{'hlfir.transpose' op input and output arrays should have rank 2}}

diff  --git a/flang/test/Lower/HLFIR/dot_product.f90 b/flang/test/Lower/HLFIR/dot_product.f90
new file mode 100644
index 0000000000000..94ee46d63a288
--- /dev/null
+++ b/flang/test/Lower/HLFIR/dot_product.f90
@@ -0,0 +1,53 @@
+! Test lowering of DOT_PRODUCT intrinsic to HLFIR
+! RUN: bbc -emit-hlfir -o - %s 2>&1 | FileCheck %s
+! dot product with numerical arguments
+subroutine dot_product1(lhs, rhs, res)
+  integer lhs(:), rhs(:), res
+  res = DOT_PRODUCT(lhs,rhs)
+end subroutine
+! Assumed-shape integer vectors: the call becomes one hlfir.dot_product on the
+! declared boxes, carrying the contract fastmath flag, assigned to res.
+! CHECK-LABEL: func.func @_QPdot_product1
+! CHECK:           %[[LHS:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "lhs"}
+! CHECK:           %[[RHS:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+! CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.dot_product %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>) -> i32
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+! CHECK-NEXT:    return
+! dot product with logical arguments
+subroutine dot_product2(lhs, rhs, res)
+  logical lhs(:), rhs(:), res
+  res = DOT_PRODUCT(lhs,rhs)
+end subroutine
+! Logical vectors: the op is typed with a !fir.logical<4> result, which is
+! assigned directly to res.
+! CHECK-LABEL: func.func @_QPdot_product2
+! CHECK:           %[[LHS:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "lhs"}
+! CHECK:           %[[RHS:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.ref<!fir.logical<4>> {fir.bindc_name = "res"}
+! CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.dot_product %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+! CHECK-NEXT:    return
+! arguments are of known shape
+subroutine dot_product3(lhs, rhs, res)
+  integer lhs(5), rhs(5), res
+  res = DOT_PRODUCT(lhs,rhs)
+end subroutine
+! Explicit-shape vectors: the op is applied directly to the declared
+! !fir.ref<!fir.array<5xi32>> variables, without boxing at this stage.
+! CHECK-LABEL: func.func @_QPdot_product3
+! CHECK:           %[[LHS:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "lhs"}
+! CHECK:           %[[RHS:.*]]: !fir.ref<!fir.array<5xi32>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.ref<i32> {fir.bindc_name = "res"}
+! CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.dot_product %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>) -> i32
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : i32, !fir.ref<i32>
+! CHECK-NEXT:    return


More information about the flang-commits mailing list