[flang-commits] [flang] 56eda98 - [flang] Handle mixed types in DOT_PRODUCT simplification.
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Mon Aug 15 09:10:45 PDT 2022
Author: Slava Zakharin
Date: 2022-08-15T09:03:38-07:00
New Revision: 56eda98f0cbd9b1db40a5b52b8f57a05b8bb4588
URL: https://github.com/llvm/llvm-project/commit/56eda98f0cbd9b1db40a5b52b8f57a05b8bb4588
DIFF: https://github.com/llvm/llvm-project/commit/56eda98f0cbd9b1db40a5b52b8f57a05b8bb4588.diff
LOG: [flang] Handle mixed types in DOT_PRODUCT simplification.
Fortran runtime supports mixed types by casting the loaded values
to the result type, so DOT_PRODUCT simplification has to do the same.
Differential Revision: https://reviews.llvm.org/D131726
Added:
Modified:
flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
flang/test/Transforms/simplifyintrinsics.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 4f5f7ad4d571b..e0bc108634d68 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -31,11 +31,14 @@
#include "flang/Optimizer/Support/FIRContext.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "flang-simplify-intrinsics"
@@ -159,8 +162,13 @@ static mlir::FunctionType genFortranADotType(fir::FirOpBuilder &builder,
/// with signature provided by \p funcOp. The caller is responsible
/// for saving/restoring the original insertion point of \p builder.
/// \p funcOp is expected to be empty on entry to this function.
+/// \p arg1ElementTy and \p arg2ElementTy specify the element types
+/// of the underlying array objects - they are used to generate proper
+/// element accesses; loaded elements are converted to the result type.
static void genFortranADotBody(fir::FirOpBuilder &builder,
- mlir::func::FuncOp &funcOp) {
+ mlir::func::FuncOp &funcOp,
+ mlir::Type arg1ElementTy,
+ mlir::Type arg2ElementTy) {
// function FortranADotProduct<T>_simplified(arr1, arr2)
// T, dimension(:) :: arr1, arr2
// T product = 0
@@ -171,14 +179,15 @@ static void genFortranADotBody(fir::FirOpBuilder &builder,
// FortranADotProduct<T>_simplified = product
// end function FortranADotProduct<T>_simplified
auto loc = mlir::UnknownLoc::get(builder.getContext());
- mlir::Type elementType = funcOp.getResultTypes()[0];
+ mlir::Type resultElementType = funcOp.getResultTypes()[0];
builder.setInsertionPointToEnd(funcOp.addEntryBlock());
mlir::IndexType idxTy = builder.getIndexType();
- mlir::Value zero = elementType.isa<mlir::FloatType>()
- ? builder.createRealConstant(loc, elementType, 0.0)
- : builder.createIntegerConstant(loc, elementType, 0);
+ mlir::Value zero =
+ resultElementType.isa<mlir::FloatType>()
+ ? builder.createRealConstant(loc, resultElementType, 0.0)
+ : builder.createIntegerConstant(loc, resultElementType, 0);
mlir::Block::BlockArgListType args = funcOp.front().getArguments();
mlir::Value arg1 = args[0];
@@ -187,10 +196,12 @@ static void genFortranADotBody(fir::FirOpBuilder &builder,
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
- mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
- mlir::Type boxArrTy = fir::BoxType::get(arrTy);
- mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy, arg1);
- mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy, arg2);
+ mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
+ mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
+ mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
+ mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
+ mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
+ mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
// This version takes the loop trip count from the first argument.
// If the first argument's box has unknown (at compilation time)
// extent, then it may be better to take the extent from the second
@@ -216,19 +227,25 @@ static void genFortranADotBody(fir::FirOpBuilder &builder,
mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loop.getBody());
- mlir::Type eleRefTy = builder.getRefType(elementType);
+ mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
mlir::Value index = loop.getInductionVar();
mlir::Value addr1 =
- builder.create<fir::CoordinateOp>(loc, eleRefTy, array1, index);
+ builder.create<fir::CoordinateOp>(loc, eleRef1Ty, array1, index);
mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
+ // Convert to the result type, as the runtime does for mixed-type arguments.
+ elem1 = builder.create<fir::ConvertOp>(loc, resultElementType, elem1);
+
+ mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
mlir::Value addr2 =
- builder.create<fir::CoordinateOp>(loc, eleRefTy, array2, index);
+ builder.create<fir::CoordinateOp>(loc, eleRef2Ty, array2, index);
mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
+ // Convert to the result type (same reasoning as for elem1).
+ elem2 = builder.create<fir::ConvertOp>(loc, resultElementType, elem2);
- if (elementType.isa<mlir::FloatType>())
+ if (resultElementType.isa<mlir::FloatType>())
sumVal = builder.create<mlir::arith::AddFOp>(
loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
- else if (elementType.isa<mlir::IntegerType>())
+ else if (resultElementType.isa<mlir::IntegerType>())
sumVal = builder.create<mlir::arith::AddIOp>(
loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
else
@@ -317,6 +334,29 @@ static unsigned getDimCount(mlir::Value val) {
return 0;
}
+/// Given the call operation's box argument \p val, discover
+/// the element type of the underlying array object.
+/// We expect that the argument is a result of a chain of fir.convert
+/// operations between box types, ending at !fir.box<none>; the chain
+/// is walked backwards until a box with a non-none element type is found.
+/// \returns the element type or llvm::None if the type cannot
+/// be reliably found (e.g. \p val is a block argument, is not produced
+/// by fir.convert, or a convert's source is not a !fir.box).
+static llvm::Optional<mlir::Type> getArgElementType(mlir::Value val) {
+  do {
+    // getDefiningOp() is null for block arguments, so a plain isa<>
+    // would assert here; treat that case as "type unknown".
+    mlir::Operation *defOp = val.getDefiningOp();
+    if (!mlir::isa_and_nonnull<fir::ConvertOp>(defOp))
+      return llvm::None;
+    val = defOp->getOperand(0);
+    // The convert is expected to go from one box type to another;
+    // bail out instead of asserting in cast<> if the source is not a box.
+    auto boxType = val.getType().dyn_cast<fir::BoxType>();
+    if (!boxType)
+      return llvm::None;
+    mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(boxType);
+    if (!elementType.isa<mlir::NoneType>())
+      return elementType;
+  } while (true);
+}
+
void SimplifyIntrinsicsPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::ModuleOp module = getOperation();
@@ -380,11 +420,42 @@ void SimplifyIntrinsicsPass::runOnOperation() {
if (!type.isa<mlir::FloatType>() && !type.isa<mlir::IntegerType>())
return;
+ // Try to find the element types of the boxed arguments.
+ auto arg1Type = getArgElementType(v1);
+ auto arg2Type = getArgElementType(v2);
+
+ // Leave the runtime call untouched if either element type is unknown.
+ if (!arg1Type || !arg2Type)
+ return;
+
+ // Support only floating point and integer arguments
+ // now (e.g. logical is skipped here).
+ if (!arg1Type->isa<mlir::FloatType>() &&
+ !arg1Type->isa<mlir::IntegerType>())
+ return;
+ if (!arg2Type->isa<mlir::FloatType>() &&
+ !arg2Type->isa<mlir::IntegerType>())
+ return;
+
auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
return genFortranADotType(builder, type);
};
+ auto bodyGenerator = [&arg1Type,
+ &arg2Type](fir::FirOpBuilder &builder,
+ mlir::func::FuncOp &funcOp) {
+ genFortranADotBody(builder, funcOp, *arg1Type, *arg2Type);
+ };
+
+ // Suffix the function name with the element types
+ // of the arguments: the generated body depends on them, so
+ // differently-typed calls must not share one function.
+ std::string typedFuncName(funcName);
+ llvm::raw_string_ostream nameOS(typedFuncName);
+ nameOS << "_";
+ arg1Type->print(nameOS);
+ nameOS << "_";
+ arg2Type->print(nameOS);
+
mlir::func::FuncOp newFunc = getOrCreateFunction(
- builder, funcName, typeGenerator, genFortranADotBody);
+ builder, typedFuncName, typeGenerator, bodyGenerator);
auto newCall = builder.create<fir::CallOp>(loc, newFunc,
mlir::ValueRange{v1, v2});
call->replaceAllUsesWith(newCall.getResults());
diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index 78df5d57f91d9..c2315fca02d10 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -344,15 +344,15 @@ fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
// CHECK: %[[RESLOC:.*]] = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
// CHECK: %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
// CHECK: %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
-// CHECK: %[[RES:.*]] = fir.call @_FortranADotProductReal4_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> f32
+// CHECK: %[[RES:.*]] = fir.call @_FortranADotProductReal4_f32_f32_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> f32
// CHECK: fir.store %[[RES]] to %[[RESLOC]] : !fir.ref<f32>
// CHECK: %[[RET:.*]] = fir.load %[[RESLOC]] : !fir.ref<f32>
// CHECK: return %[[RET]] : f32
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranADotProductReal4_simplified(
-// CHECK-SAME: %[[A:.*]]: !fir.box<none>,
-// CHECK-SAME: %[[B:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-LABEL: func.func private @_FortranADotProductReal4_f32_f32_simplified(
+// CHECK-SAME: %[[A:.*]]: !fir.box<none>,
+// CHECK-SAME: %[[B:.*]]: !fir.box<none>) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[FZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[IZERO:.*]] = arith.constant 0 : index
// CHECK: %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
@@ -363,9 +363,11 @@ fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
// CHECK: %[[RES:.*]] = fir.do_loop %[[IDX:.*]] = %[[IZERO]] to %[[LEN]] step %[[IONE]] iter_args(%[[SUM:.*]] = %[[FZERO]]) -> (f32) {
// CHECK: %[[ALOC:.*]] = fir.coordinate_of %[[ACAST]], %[[IDX]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: %[[AVAL:.*]] = fir.load %[[ALOC]] : !fir.ref<f32>
+// CHECK: %[[AVALCAST:.*]] = fir.convert %[[AVAL]] : (f32) -> f32
// CHECK: %[[BLOC:.*]] = fir.coordinate_of %[[BCAST]], %[[IDX]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: %[[BVAL:.*]] = fir.load %[[BLOC]] : !fir.ref<f32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[AVAL]], %[[BVAL]] : f32
+// CHECK: %[[BVALCAST:.*]] = fir.convert %[[BVAL]] : (f32) -> f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[AVALCAST]], %[[BVALCAST]] : f32
// CHECK: %[[NEWSUM:.*]] = arith.addf %[[MUL]], %[[SUM]] : f32
// CHECK: fir.result %[[NEWSUM]] : f32
// CHECK: }
@@ -479,15 +481,15 @@ fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
// CHECK: %[[RESLOC:.*]] = fir.alloca i32 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
// CHECK: %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
// CHECK: %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<!fir.array<?xi32>>) -> !fir.box<none>
-// CHECK: %[[RES:.*]] = fir.call @_FortranADotProductInteger4_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> i32
+// CHECK: %[[RES:.*]] = fir.call @_FortranADotProductInteger4_i32_i32_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> i32
// CHECK: fir.store %[[RES]] to %[[RESLOC]] : !fir.ref<i32>
// CHECK: %[[RET:.*]] = fir.load %[[RESLOC]] : !fir.ref<i32>
// CHECK: return %[[RET]] : i32
// CHECK: }
-// CHECK-LABEL: func.func private @_FortranADotProductInteger4_simplified(
-// CHECK-SAME: %[[A:.*]]: !fir.box<none>,
-// CHECK-SAME: %[[B:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-LABEL: func.func private @_FortranADotProductInteger4_i32_i32_simplified(
+// CHECK-SAME: %[[A:.*]]: !fir.box<none>,
+// CHECK-SAME: %[[B:.*]]: !fir.box<none>) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
// CHECK: %[[I32ZERO:.*]] = arith.constant 0 : i32
// CHECK: %[[IZERO:.*]] = arith.constant 0 : index
// CHECK: %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
@@ -498,9 +500,11 @@ fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
// CHECK: %[[RES:.*]] = fir.do_loop %[[IDX:.*]] = %[[IZERO]] to %[[LEN]] step %[[IONE]] iter_args(%[[SUM:.*]] = %[[I32ZERO]]) -> (i32) {
// CHECK: %[[ALOC:.*]] = fir.coordinate_of %[[ACAST]], %[[IDX]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[AVAL:.*]] = fir.load %[[ALOC]] : !fir.ref<i32>
+// CHECK: %[[AVALCAST:.*]] = fir.convert %[[AVAL]] : (i32) -> i32
// CHECK: %[[BLOC:.*]] = fir.coordinate_of %[[BCAST]], %[[IDX]] : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
// CHECK: %[[BVAL:.*]] = fir.load %[[BLOC]] : !fir.ref<i32>
-// CHECK: %[[MUL:.*]] = arith.muli %[[AVAL]], %[[BVAL]] : i32
+// CHECK: %[[BVALCAST:.*]] = fir.convert %[[BVAL]] : (i32) -> i32
+// CHECK: %[[MUL:.*]] = arith.muli %[[AVALCAST]], %[[BVALCAST]] : i32
// CHECK: %[[NEWSUM:.*]] = arith.addi %[[MUL]], %[[SUM]] : i32
// CHECK: fir.result %[[NEWSUM]] : i32
// CHECK: }
@@ -587,3 +591,63 @@ fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
// CHECK-SAME: %[[A:.*]]: !fir.box<!fir.array<?xi64>> {fir.bindc_name = "a"},
// CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xi64>> {fir.bindc_name = "b"}) -> i64 {
// CHECK-NOT: call{{.*}}_FortranADotProductInteger8(
+
+// -----
+
+// Test mixed types, e.g. when _FortranADotProductReal8 is called
+// with <?xf64> and <?xf32> arguments. The loaded elements must be converted
+// to the result type REAL(8) before the computations.
+
+func.func @dot_f64_f32(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f64 {
+ %0 = fir.alloca f64 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
+ %1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref<!fir.char<1,10>>
+ %c3_i32 = arith.constant 3 : i32
+ %2 = fir.convert %arg0 : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+ %3 = fir.convert %arg1 : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+ %4 = fir.convert %1 : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<i8>
+ %5 = fir.call @_FortranADotProductReal8(%2, %3, %4, %c3_i32) : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f64
+ fir.store %5 to %0 : !fir.ref<f64>
+ %6 = fir.load %0 : !fir.ref<f64>
+ return %6 : f64
+}
+// Declaration of the REAL(8) runtime entry point called by @dot_f64_f32 above.
+func.func private @_FortranADotProductReal8(!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> f64 attributes {fir.runtime}
+fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> {
+ %0 = fir.string_lit "./dot.f90\00"(10) : !fir.char<1,10>
+ fir.has_value %0 : !fir.char<1,10>
+}
+
+// CHECK-LABEL: func.func @dot_f64_f32(
+// CHECK-SAME: %[[A:.*]]: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"},
+// CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "b"}) -> f64 {
+// CHECK: %[[RESLOC:.*]] = fir.alloca f64 {bindc_name = "dot", uniq_name = "_QFdotEdot"}
+// CHECK: %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<!fir.array<?xf64>>) -> !fir.box<none>
+// CHECK: %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<!fir.array<?xf32>>) -> !fir.box<none>
+// CHECK: %[[RES:.*]] = fir.call @_FortranADotProductReal8_f64_f32_simplified(%[[ACAST]], %[[BCAST]]) : (!fir.box<none>, !fir.box<none>) -> f64
+// CHECK: fir.store %[[RES]] to %[[RESLOC]] : !fir.ref<f64>
+// CHECK: %[[RET:.*]] = fir.load %[[RESLOC]] : !fir.ref<f64>
+// CHECK: return %[[RET]] : f64
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_FortranADotProductReal8_f64_f32_simplified(
+// CHECK-SAME: %[[A:.*]]: !fir.box<none>,
+// CHECK-SAME: %[[B:.*]]: !fir.box<none>) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[FZERO:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[IZERO:.*]] = arith.constant 0 : index
+// CHECK: %[[ACAST:.*]] = fir.convert %[[A]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf64>>
+// CHECK: %[[BCAST:.*]] = fir.convert %[[B]] : (!fir.box<none>) -> !fir.box<!fir.array<?xf32>>
+// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ACAST]], %[[IZERO]] : (!fir.box<!fir.array<?xf64>>, index) -> (index, index, index)
+// CHECK: %[[IONE:.*]] = arith.constant 1 : index
+// CHECK: %[[LEN:.*]] = arith.subi %[[DIMS]]#1, %[[IONE]] : index
+// CHECK: %[[RES:.*]] = fir.do_loop %[[IDX:.*]] = %[[IZERO]] to %[[LEN]] step %[[IONE]] iter_args(%[[SUM:.*]] = %[[FZERO]]) -> (f64) {
+// CHECK: %[[ALOC:.*]] = fir.coordinate_of %[[ACAST]], %[[IDX]] : (!fir.box<!fir.array<?xf64>>, index) -> !fir.ref<f64>
+// CHECK: %[[AVAL:.*]] = fir.load %[[ALOC]] : !fir.ref<f64>
+// CHECK: %[[AVALCAST:.*]] = fir.convert %[[AVAL]] : (f64) -> f64
+// CHECK: %[[BLOC:.*]] = fir.coordinate_of %[[BCAST]], %[[IDX]] : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK: %[[BVAL:.*]] = fir.load %[[BLOC]] : !fir.ref<f32>
+// CHECK: %[[BVALCAST:.*]] = fir.convert %[[BVAL]] : (f32) -> f64
+// CHECK: %[[MUL:.*]] = arith.mulf %[[AVALCAST]], %[[BVALCAST]] : f64
+// CHECK: %[[NEWSUM:.*]] = arith.addf %[[MUL]], %[[SUM]] : f64
+// CHECK: fir.result %[[NEWSUM]] : f64
+// CHECK: }
+// CHECK: return %[[RES]] : f64
+// CHECK: }
More information about the flang-commits
mailing list