[flang-commits] [flang] [flang] Do not inline SUM with invalid DIM argument. (PR #118911)
via flang-commits
flang-commits at lists.llvm.org
Thu Dec 5 18:11:33 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
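SUM operations whose constant DIM argument is out of range (less than 1 or greater than the rank of the ARRAY argument) might appear in dead code after constant propagation.
They do not have to be inlined, so the simplification pass now leaves them unconverted.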
---
Full diff: https://github.com/llvm/llvm-project/pull/118911.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+13-3)
- (modified) flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir (+18)
``````````diff
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0c34c8221aeda6..ace63a970db932 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
- assert(dimVal > 0 && "DIM must be present and a positive constant");
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
@@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
+ assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
+ "DIM must be present and a positive constant not exceeding "
+ "the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();
@@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
// would avoid creating a temporary for the elemental array expression.
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (mlir::Value dim = sum.getDim()) {
- if (fir::getIntIfConstant(dim)) {
+ if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
- return false;
+ fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+ hlfir::getFortranElementOrSequenceType(
+ sum.getArray().getType()));
+ if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
+ // Ignore SUMs with illegal DIM values.
+ // They may appear in dead code,
+ // and they do not have to be converted.
+ return false;
+ }
}
}
}
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 703b6673154f3f..313e54d5d0c4af 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
// CHECK: return
// CHECK: }
+
+// negative: invalid dim==0
+func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
+ %cst = arith.constant 0 : i32
+ %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+ return
+}
+// CHECK-LABEL: func.func @sum_invalid_dim0(
+// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+
+// negative: invalid dim>rank
+func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
+ %cst = arith.constant 3 : i32
+ %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+ return
+}
+// CHECK-LABEL: func.func @sum_invalid_dim_big(
+// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/118911
More information about the flang-commits mailing list