[flang-commits] [flang] 74adc3e - [flang][hlfir] fix missing conversion in transpose simplification
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Wed Jun 21 09:59:29 PDT 2023
Author: Tom Eccles
Date: 2023-06-21T16:54:58Z
New Revision: 74adc3e0ebfb42a48f02d8d7094d6848a37a21f5
URL: https://github.com/llvm/llvm-project/commit/74adc3e0ebfb42a48f02d8d7094d6848a37a21f5
DIFF: https://github.com/llvm/llvm-project/commit/74adc3e0ebfb42a48f02d8d7094d6848a37a21f5.diff
LOG: [flang][hlfir] fix missing conversion in transpose simplification
It seems that simply replacing the operation did not replace all of its uses
when the type of the expression differs before and after this pass (due
to differing shape information). Now the shape information is always
kept the same: the generated hlfir.elemental is given exactly the result
type of the hlfir.transpose it replaces.
This fixes https://github.com/llvm/llvm-project/issues/63399
Differential Revision: https://reviews.llvm.org/D153333
Added:
Modified:
flang/include/flang/Optimizer/Builder/HLFIRTools.h
flang/lib/Optimizer/Builder/HLFIRTools.cpp
flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
flang/test/HLFIR/simplify-hlfir-intrinsics.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index d20cdf76a1e73..9dceee4b37b4f 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -363,11 +363,13 @@ using ElementalKernelGenerator = std::function<hlfir::Entity(
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
/// Generate an hlfir.elementalOp given call back to generate the element
/// value at for each iteration.
+/// If exprType is specified, this will be the return type of the elemental op
hlfir::ElementalOp genElementalOp(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
- const ElementalKernelGenerator &genKernel);
+ const ElementalKernelGenerator &genKernel,
+ mlir::Type exprType = mlir::Type{});
/// Structure to describe a loop nest.
struct LoopNest {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index aad7b72ba01f7..a11c235510fe9 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -722,12 +722,12 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
isPolymorphic);
}
-hlfir::ElementalOp
-hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::Type elementType, mlir::Value shape,
- mlir::ValueRange typeParams,
- const ElementalKernelGenerator &genKernel) {
- mlir::Type exprType = getArrayExprType(elementType, shape, false);
+hlfir::ElementalOp hlfir::genElementalOp(
+ mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
+ mlir::Value shape, mlir::ValueRange typeParams,
+ const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
+ if (!exprType)
+ exprType = getArrayExprType(elementType, shape, false);
auto elementalOp =
builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
auto insertPt = builder.saveInsertionPoint();
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 21df162265cda..f1c8d68960600 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -59,9 +59,16 @@ class TransposeAsElementalConversion
return val;
};
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
- loc, builder, elementType, resultShape, typeParams, genKernel);
+ loc, builder, elementType, resultShape, typeParams, genKernel,
+ transpose.getResult().getType());
- rewriter.replaceOp(transpose, elementalOp.getResult());
+ // it wouldn't be safe to replace block arguments with a different
+ // hlfir.expr type. Types can differ due to differing amounts of shape
+ // information
+ assert(elementalOp.getResult().getType() ==
+ transpose.getResult().getType());
+
+ rewriter.replaceOp(transpose, elementalOp);
return mlir::success();
}
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
index e75b6fba885ba..eac89a6423921 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
@@ -93,3 +93,94 @@ func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
// CHECK: }
// CHECK: return
// CHECK: }
+
+// expr with multiple uses
+func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) {
+ %0 = hlfir.transpose %arg0 : (!hlfir.expr<2x2xf32>) -> !hlfir.expr<2x2xf32>
+ %1 = hlfir.shape_of %0 : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
+ %2 = hlfir.elemental %1 : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+ ^bb0(%arg2: index, %arg3: index):
+ %3 = hlfir.apply %0, %arg2, %arg3 : (!hlfir.expr<2x2xf32>, index, index) -> f32
+ %4 = math.cos %3 fastmath<contract> : f32
+ hlfir.yield_element %4 : f32
+ }
+ hlfir.assign %2 to %arg1 realloc : !hlfir.expr<2x2xf32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+ hlfir.destroy %2 : !hlfir.expr<2x2xf32>
+ hlfir.destroy %0 : !hlfir.expr<2x2xf32>
+ return
+}
+// CHECK-LABEL: func.func @transpose4(
+// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
+// CHECK-SAME: %[[ARG1:.*]]:
+// CHECK: %[[SHAPE0:.*]] = fir.shape
+// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK: %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
+// CHECK: hlfir.yield_element %[[ELE]] : f32
+// CHECK: }
+// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]] : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
+// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK: %[[ELE:.*]] = hlfir.apply %[[TRANSPOSE]], %[[I]], %[[J]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
+// CHECK: %[[COS_ELE:.*]] = math.cos %[[ELE]] fastmath<contract> : f32
+// CHECK: hlfir.yield_element %[[COS_ELE]] : f32
+// CHECK: }
+// CHECK: hlfir.assign %[[COS]] to %[[ARG1]] realloc
+// CHECK: hlfir.destroy %[[COS]] : !hlfir.expr<2x2xf32>
+// CHECK: hlfir.destroy %[[TRANSPOSE]] : !hlfir.expr<2x2xf32>
+// CHECK: return
+// CHECK: }
+
+// regression test
+func.func @transpose5(%arg0: !fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>> {fir.host_assoc}) attributes {fir.internal_proc} {
+ %0 = fir.address_of(@_QFEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
+ %1:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>)
+ %c0_i32 = arith.constant 0 : i32
+ %2 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+ %3 = fir.load %2 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+ %4 = fir.box_addr %3 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
+ %c0 = arith.constant 0 : index
+ %5:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+ %c1 = arith.constant 1 : index
+ %6:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+ %7 = fir.shape %5#1, %6#1 : (index, index) -> !fir.shape<2>
+ %8:2 = hlfir.declare %4(%7) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
+ %c1_i32 = arith.constant 1 : i32
+ %9 = fir.coordinate_of %arg0, %c1_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+ %10 = fir.load %9 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+ %11 = fir.box_addr %10 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
+ %c0_0 = arith.constant 0 : index
+ %12:3 = fir.box_dims %10, %c0_0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+ %c1_1 = arith.constant 1 : index
+ %13:3 = fir.box_dims %10, %c1_1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+ %14 = fir.shape %12#1, %13#1 : (index, index) -> !fir.shape<2>
+ %15:2 = hlfir.declare %11(%14) {uniq_name = "_QFEc"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
+ %16 = hlfir.transpose %8#0 : (!fir.ref<!fir.array<2x2xf64>>) -> !hlfir.expr<2x2xf64>
+ %17 = hlfir.shape_of %16 : (!hlfir.expr<2x2xf64>) -> !fir.shape<2>
+ %18 = hlfir.elemental %17 : (!fir.shape<2>) -> !hlfir.expr<?x?xf64> {
+ ^bb0(%arg1: index, %arg2: index):
+ %19 = hlfir.apply %16, %arg1, %arg2 : (!hlfir.expr<2x2xf64>, index, index) -> f64
+ %20 = math.cos %19 fastmath<contract> : f64
+ hlfir.yield_element %20 : f64
+ }
+ hlfir.assign %18 to %1#0 realloc : !hlfir.expr<?x?xf64>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
+ hlfir.destroy %18 : !hlfir.expr<?x?xf64>
+ hlfir.destroy %16 : !hlfir.expr<2x2xf64>
+ return
+}
+// CHECK-LABEL: func.func @transpose5(
+// ...
+// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0:.*]]
+// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK: %[[ELE:.*]] = hlfir.designate %[[ARRAY:.*]] (%[[J]], %[[I]])
+// CHECK: %[[LOAD:.*]] = fir.load %[[ELE]]
+// CHECK: hlfir.yield_element %[[LOAD]]
+// CHECK: }
+// CHECK: %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]]
+// CHECK: %[[COS:.*]] = hlfir.elemental %[[SHAPE1]]
+// ...
+// CHECK: hlfir.assign %[[COS]] to %{{.*}} realloc
+// CHECK: hlfir.destroy %[[COS]]
+// CHECK: hlfir.destroy %[[TRANSPOSE]]
+// CHECK: return
+// CHECK: }
More information about the flang-commits mailing list