[flang-commits] [flang] [Flang] Allow Intrinsic simplification with min/maxloc dim and scalar result (PR #81619)

via flang-commits flang-commits at lists.llvm.org
Tue Feb 13 08:27:17 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

This makes an adjustment to the existing fir minloc/maxloc generation code to handle functions with a dim=1 that produce a scalar result. This should allow us to get the same benefits as the existing generated minmax reductions.

This is a recommit of #<!-- -->76194 with an extra alteration to the end of genRuntimeMinMaxlocBody to make sure we convert the output array to the correct type (a `box<heap<i32>>`, not `box<heap<array<1xi32>>>`) to prevent writing the wrong type of box into it. This still allocates the data as an `array<1xi32>` and converts it to an `i32`, on the assumption that this is safe. An alternative would be to allocate the data as an `i32` and change more of the accesses to it throughout genRuntimeMinMaxlocBody.

---
Full diff: https://github.com/llvm/llvm-project/pull/81619.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (+33-19) 
- (modified) flang/test/Transforms/simplifyintrinsics.fir (+64-7) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 86343e23c6e5db..1e1aa2a2de7c75 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -692,7 +692,7 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
                                     unsigned rank, int maskRank,
                                     mlir::Type elementType,
                                     mlir::Type maskElemType,
-                                    mlir::Type resultElemTy) {
+                                    mlir::Type resultElemTy, bool isDim) {
   auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
                       mlir::Type elementType) {
     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
@@ -877,16 +877,27 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
                             maskElemType, resultArr, maskRank == 0);
 
   // Store newly created output array to the reference passed in
-  fir::SequenceType::Shape resultShape(1, rank);
-  mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
-  mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
-  mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
-  mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
-  mlir::Value outputArr = builder.create<fir::ConvertOp>(
-      loc, outputRefTy, funcOp.front().getArgument(0));
-
-  // Store nearly created array to output array
-  builder.create<fir::StoreOp>(loc, resultArr, outputArr);
+  if (isDim) {
+    mlir::Type resultBoxTy =
+        fir::BoxType::get(fir::HeapType::get(resultElemTy));
+    mlir::Value outputArr = builder.create<fir::ConvertOp>(
+        loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0));
+    mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
+        loc, fir::HeapType::get(resultElemTy), resultArrInit);
+    mlir::Value resultBox =
+        builder.create<fir::EmboxOp>(loc, resultBoxTy, resultArrScalar);
+    builder.create<fir::StoreOp>(loc, resultBox, outputArr);
+  } else {
+    fir::SequenceType::Shape resultShape(1, rank);
+    mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
+    mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
+    mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
+    mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
+    mlir::Value outputArr = builder.create<fir::ConvertOp>(
+        loc, outputRefTy, funcOp.front().getArgument(0));
+    builder.create<fir::StoreOp>(loc, resultArr, outputArr);
+  }
+
   builder.create<mlir::func::ReturnOp>(loc);
 }
 
@@ -1165,11 +1176,14 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
 
   mlir::Operation::operand_range args = call.getArgs();
 
-  mlir::Value back = args[6];
+  mlir::SymbolRefAttr callee = call.getCalleeAttr();
+  mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
+  bool isDim = funcNameBase.ends_with("Dim");
+  mlir::Value back = args[isDim ? 7 : 6];
   if (isTrueOrNotConstant(back))
     return;
 
-  mlir::Value mask = args[5];
+  mlir::Value mask = args[isDim ? 6 : 5];
   mlir::Value maskDef = findMaskDef(mask);
 
   // maskDef is set to NULL when the defining op is not one we accept.
@@ -1178,10 +1192,8 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
   if (maskDef == NULL)
     return;
 
-  mlir::SymbolRefAttr callee = call.getCalleeAttr();
-  mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
   unsigned rank = getDimCount(args[1]);
-  if (funcNameBase.ends_with("Dim") || !(rank > 0))
+  if ((isDim && rank != 1) || !(rank > 0))
     return;
 
   fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)};
@@ -1222,22 +1234,24 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction(
 
   llvm::raw_string_ostream nameOS(funcName);
   outType.print(nameOS);
+  if (isDim)
+    nameOS << '_' << inputType;
   nameOS << '_' << fmfString;
 
   auto typeGenerator = [rank](fir::FirOpBuilder &builder) {
     return genRuntimeMinlocType(builder, rank);
   };
   auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType,
-                        isMax](fir::FirOpBuilder &builder,
+                        isMax, isDim](fir::FirOpBuilder &builder,
                                mlir::func::FuncOp &funcOp) {
     genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType,
-                            logicalElemType, outType);
+                            logicalElemType, outType, isDim);
   };
 
   mlir::func::FuncOp newFunc =
       getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator);
   builder.create<fir::CallOp>(loc, newFunc,
-                              mlir::ValueRange{args[0], args[1], args[5]});
+                              mlir::ValueRange{args[0], args[1], mask});
   call->dropAllReferences();
   call->erase();
 }
diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir
index cd059cc797a3f4..4eabddc2f5ca32 100644
--- a/flang/test/Transforms/simplifyintrinsics.fir
+++ b/flang/test/Transforms/simplifyintrinsics.fir
@@ -2116,13 +2116,13 @@ func.func @_QPtestminloc_doesntwork1d_back(%arg0: !fir.ref<!fir.array<10xi32>> {
 // CHECK-NOT:         fir.call @_FortranAMinlocInteger4x1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
 
 // -----
-// Check Minloc is not simplified when DIM arg is set
+// Check Minloc is simplified when DIM arg is set so long as the result is scalar
 
-func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
+func.func @_QPtestminloc_1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
   %0 = fir.alloca !fir.box<!fir.heap<i32>>
   %c10 = arith.constant 10 : index
   %c1 = arith.constant 1 : index
-  %1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_doesntwork1d_dim", uniq_name = "_QFtestminloc_doesntwork1d_dimEtestminloc_doesntwork1d_dim"}
+  %1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_1d_dim", uniq_name = "_QFtestminloc_1d_dimEtestminloc_1d_dim"}
   %2 = fir.shape %c1 : (index) -> !fir.shape<1>
   %3 = fir.array_load %1(%2) : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>) -> !fir.array<1xi32>
   %4 = fir.shape %c10 : (index) -> !fir.shape<1>
@@ -2157,11 +2157,68 @@ func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref<!fir.array<10xi32>> {f
   %21 = fir.load %1 : !fir.ref<!fir.array<1xi32>>
   return %21 : !fir.array<1xi32>
 }
-// CHECK-LABEL:   func.func @_QPtestminloc_doesntwork1d_dim(
+// CHECK-LABEL:   func.func @_QPtestminloc_1d_dim(
 // CHECK-SAME:                                             %[[ARR:.*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "a"}) -> !fir.array<1xi32> {
-// CHECK-NOT:         fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
-// CHECK:             fir.call @_FortranAMinlocDim({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32, !fir.box<none>, i1) -> none
-// CHECK-NOT:         fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
+// CHECK:             fir.call @_FortranAMinlocDimx1_i32_i32_contract_simplified({{.*}}) fastmath<contract> : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>) -> ()
+
+// CHECK-LABEL:  func.func private @_FortranAMinlocDimx1_i32_i32_contract_simplified(%arg0: !fir.ref<!fir.box<none>>, %arg1: !fir.box<none>, %arg2: !fir.box<none>) attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK-NEXT:    %[[V0:.*]] = fir.alloca i32
+// CHECK-NEXT:    %c0_i32 = arith.constant 0 : i32
+// CHECK-NEXT:    %c1 = arith.constant 1 : index
+// CHECK-NEXT:    %[[V1:.*]] = fir.allocmem !fir.array<1xi32>
+// CHECK-NEXT:    %[[V2:.*]] = fir.shape %c1 : (index) -> !fir.shape<1>
+// CHECK-NEXT:    %[[V3:.*]] = fir.embox %[[V1]](%[[V2]]) : (!fir.heap<!fir.array<1xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<1xi32>>>
+// CHECK-NEXT:    %c0 = arith.constant 0 : index
+// CHECK-NEXT:    %[[V4:.*]] = fir.coordinate_of %[[V3]], %c0 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:    fir.store %c0_i32 to %[[V4]] : !fir.ref<i32>
+// CHECK-NEXT:    %c0_0 = arith.constant 0 : index
+// CHECK-NEXT:    %[[V5:.*]] = fir.convert %arg1 : (!fir.box<none>) -> !fir.box<!fir.array<?xi32>>
+// CHECK-NEXT:    %c1_i32 = arith.constant 1 : i32
+// CHECK-NEXT:    %c0_i32_1 = arith.constant 0 : i32
+// CHECK-NEXT:    fir.store %c0_i32_1 to %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:    %c2147483647_i32 = arith.constant 2147483647 : i32
+// CHECK-NEXT:    %c1_2 = arith.constant 1 : index
+// CHECK-NEXT:    %c0_3 = arith.constant 0 : index
+// CHECK-NEXT:    %[[V6:.*]]:3 = fir.box_dims %[[V5]], %c0_3 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
+// CHECK-NEXT:    %[[V7:.*]] = arith.subi %[[V6]]#1, %c1_2 : index
+// CHECK-NEXT:    %[[V8:.*]] = fir.do_loop %arg3 = %c0_0 to %[[V7]] step %c1_2 iter_args(%arg4 = %c2147483647_i32) -> (i32) {
+// CHECK-NEXT:      %c1_i32_4 = arith.constant 1 : i32
+// CHECK-NEXT:      fir.store %c1_i32_4 to %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:      %[[V12:.*]] = fir.coordinate_of %[[V5]], %arg3 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:      %[[V13:.*]] = fir.load %[[V12]] : !fir.ref<i32>
+// CHECK-NEXT:      %[[V14:.*]] = arith.cmpi slt, %[[V13]], %arg4 : i32
+// CHECK-NEXT:      %[[V15:.*]] = fir.if %[[V14]] -> (i32) {
+// CHECK-NEXT:        %c1_i32_5 = arith.constant 1 : i32
+// CHECK-NEXT:        %c0_6 = arith.constant 0 : index
+// CHECK-NEXT:        %[[V16:.*]] = fir.coordinate_of %[[V3]], %c0_6 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:        %[[V17:.*]] = fir.convert %arg3 : (index) -> i32
+// CHECK-NEXT:        %[[V18:.*]] = arith.addi %[[V17]], %c1_i32_5 : i32
+// CHECK-NEXT:        fir.store %[[V18]] to %[[V16]] : !fir.ref<i32>
+// CHECK-NEXT:        fir.result %[[V13]] : i32
+// CHECK-NEXT:      } else {
+// CHECK-NEXT:        fir.result %arg4 : i32
+// CHECK-NEXT:      }
+// CHECK-NEXT:      fir.result %[[V15]] : i32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[V9:.*]] = fir.load %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:    %[[V10:.*]] = arith.cmpi eq, %[[V9]], %c1_i32 : i32
+// CHECK-NEXT:    fir.if %[[V10]] {
+// CHECK-NEXT:      %c2147483647_i32_4 = arith.constant 2147483647 : i32
+// CHECK-NEXT:      %[[V12]] = arith.cmpi eq, %c2147483647_i32_4, %[[V8]] : i32
+// CHECK-NEXT:      fir.if %[[V12]] {
+// CHECK-NEXT:        %c0_5 = arith.constant 0 : index
+// CHECK-NEXT:        %[[V13]] = fir.coordinate_of %[[V3]], %c0_5 : (!fir.box<!fir.heap<!fir.array<1xi32>>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:        fir.store %c1_i32 to %[[V13]] : !fir.ref<i32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[V11:.*]] = fir.convert %arg0 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK-NEXT:    %[[V12:.*]] = fir.convert %[[V1]] : (!fir.heap<!fir.array<1xi32>>) -> !fir.heap<i32>
+// CHECK-NEXT:    %[[V13:.*]] = fir.embox %[[V12]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
+// CHECK-NEXT:    fir.store %[[V13]] to %[[V11]] : !fir.ref<!fir.box<!fir.heap<i32>>>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+
 
 // -----
 // Check Minloc is not simplified when dimension of inputArr is unknown

``````````

</details>


https://github.com/llvm/llvm-project/pull/81619


More information about the flang-commits mailing list