[flang-commits] [flang] 69193c6 - [flang] Generate DOT_PRODUCT runtime call based on the result type.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Aug 31 15:27:19 PDT 2022


Author: Slava Zakharin
Date: 2022-08-31T15:20:12-07:00
New Revision: 69193c6cd73bb43c4c1970914b962e8e568dd092

URL: https://github.com/llvm/llvm-project/commit/69193c6cd73bb43c4c1970914b962e8e568dd092
DIFF: https://github.com/llvm/llvm-project/commit/69193c6cd73bb43c4c1970914b962e8e568dd092.diff

LOG: [flang] Generate DOT_PRODUCT runtime call based on the result type.

We used to select the runtime function based on the first argument's
type, which is incorrect when the two arguments have different types
(e.g. INTEGER and REAL). The selection is now based on the result type.

Differential Revision: https://reviews.llvm.org/D133032

Added: 
    

Modified: 
    flang/lib/Lower/IntrinsicCall.cpp
    flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
    flang/test/Lower/Intrinsics/dot_product.f90
    flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index a8ebe10522968..c5a1121985257 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -231,18 +231,18 @@ genDotProd(FN func, mlir::Type resultType, fir::FirOpBuilder &builder,
   // Handle required vector arguments
   mlir::Value vectorA = fir::getBase(args[0]);
   mlir::Value vectorB = fir::getBase(args[1]);
+  // Result type is used for picking appropriate runtime function.
+  mlir::Type eleTy = resultType;
 
-  mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(vectorA.getType())
-                         .cast<fir::SequenceType>()
-                         .getEleTy();
   if (fir::isa_complex(eleTy)) {
     mlir::Value result = builder.createTemporary(loc, eleTy);
     func(builder, loc, vectorA, vectorB, result);
     return builder.create<fir::LoadOp>(loc, result);
   }
 
-  auto resultBox = builder.create<fir::AbsentOp>(
-      loc, fir::BoxType::get(builder.getI1Type()));
+  // This operation is only used to pass the result type
+  // information to the DotProduct generator.
+  auto resultBox = builder.create<fir::AbsentOp>(loc, fir::BoxType::get(eleTy));
   return func(builder, loc, vectorA, vectorB, resultBox);
 }
 

diff  --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index 7c6d187f59f31..0fa035246b4cf 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -799,9 +799,10 @@ mlir::Value fir::runtime::genDotProduct(fir::FirOpBuilder &builder,
                                         mlir::Value vectorBBox,
                                         mlir::Value resultBox) {
   mlir::func::FuncOp func;
-  auto ty = vectorABox.getType();
-  auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
-  auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
+  // For complex data types, resultBox is !fir.ref<!fir.complex<N>>,
+  // otherwise it is !fir.box<T>.
+  auto ty = resultBox.getType();
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
 
   if (eleTy.isF16() || eleTy.isBF16())
     TODO(loc, "half-precision DOTPRODUCT");

diff  --git a/flang/test/Lower/Intrinsics/dot_product.f90 b/flang/test/Lower/Intrinsics/dot_product.f90
index 3b4c77a853b58..42843dc5115ae 100644
--- a/flang/test/Lower/Intrinsics/dot_product.f90
+++ b/flang/test/Lower/Intrinsics/dot_product.f90
@@ -245,3 +245,46 @@ subroutine dot_prod_logical (x, y, z)
   ! CHECK-DAG: %[[res:.*]] = fir.call @_FortranADotProductLogical(%[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> i1
   z = dot_product(x,y)
 end subroutine
+
+! CHECK-LABEL: dot_product_mixed_int_real
+! CHECK-SAME: %[[x:arg0]]: !fir.box<!fir.array<?xi32>>
+! CHECK-SAME: %[[y:arg1]]: !fir.box<!fir.array<?xf32>>
+! CHECK-SAME: %[[z:arg2]]: !fir.box<!fir.array<?xf32>>
+subroutine dot_product_mixed_int_real(x, y, z)
+  integer, dimension(1:) :: x
+  real, dimension(1:) :: y, z
+  ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[res:.*]] = fir.call @_FortranADotProductReal4(%[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f32
+  z = dot_product(x,y)
+end subroutine
+
+! CHECK-LABEL: dot_product_mixed_int_complex
+! CHECK-SAME: %[[x:arg0]]: !fir.box<!fir.array<?xi32>>
+! CHECK-SAME: %[[y:arg1]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+! CHECK-SAME: %[[z:arg2]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+subroutine dot_product_mixed_int_complex(x, y, z)
+  integer, dimension(1:) :: x
+  complex, dimension(1:) :: y, z
+  ! CHECK-DAG: %[[res:.*]] = fir.alloca !fir.complex<4>
+  ! CHECK-DAG: %[[res_conv:.*]] = fir.convert %[[res]] : (!fir.ref<!fir.complex<4>>) -> !fir.ref<complex<f32>>
+  ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<none>
+  ! CHECK-DAG: fir.call @_FortranACppDotProductComplex4(%[[res_conv]], %[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+  z = dot_product(x,y)
+end subroutine
+
+! CHECK-LABEL: dot_product_mixed_real_complex
+! CHECK-SAME: %[[x:arg0]]: !fir.box<!fir.array<?xf32>>
+! CHECK-SAME: %[[y:arg1]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+! CHECK-SAME: %[[z:arg2]]: !fir.box<!fir.array<?x!fir.complex<4>>>
+subroutine dot_product_mixed_real_complex(x, y, z)
+  real, dimension(1:) :: x
+  complex, dimension(1:) :: y, z
+  ! CHECK-DAG: %[[res:.*]] = fir.alloca !fir.complex<4>
+  ! CHECK-DAG: %[[res_conv:.*]] = fir.convert %[[res]] : (!fir.ref<!fir.complex<4>>) -> !fir.ref<complex<f32>>
+  ! CHECK-DAG: %[[x_conv:.*]] = fir.convert %[[x]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+  ! CHECK-DAG: %[[y_conv:.*]] = fir.convert %[[y]] : (!fir.box<!fir.array<?x!fir.complex<4>>>) -> !fir.box<none>
+  ! CHECK-DAG: fir.call @_FortranACppDotProductComplex4(%[[res_conv]], %[[x_conv]], %[[y_conv]], %{{[0-9]+}}, %{{.*}}) : (!fir.ref<complex<f32>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
+  z = dot_product(x,y)
+end subroutine

diff  --git a/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp b/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp
index daa2082c8d9ff..fea28d0c7e663 100644
--- a/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp
+++ b/flang/unittests/Optimizer/Builder/Runtime/ReductionTest.cpp
@@ -202,7 +202,8 @@ void testGenDotProduct(
   mlir::Type refSeqTy = fir::ReferenceType::get(seqTy);
   mlir::Value a = builder.create<fir::UndefOp>(loc, refSeqTy);
   mlir::Value b = builder.create<fir::UndefOp>(loc, refSeqTy);
-  mlir::Value result = builder.create<fir::UndefOp>(loc, seqTy);
+  mlir::Value result =
+      builder.create<fir::UndefOp>(loc, fir::ReferenceType::get(eleTy));
   mlir::Value prod = fir::runtime::genDotProduct(builder, loc, a, b, result);
   if (fir::isa_complex(eleTy))
     checkCallOpFromResultBox(result, fctName, 3);


        


More information about the flang-commits mailing list