[flang-commits] [flang] 74adc3e - [flang][hlfir] fix missing conversion in transpose simplification

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Jun 21 09:59:29 PDT 2023


Author: Tom Eccles
Date: 2023-06-21T16:54:58Z
New Revision: 74adc3e0ebfb42a48f02d8d7094d6848a37a21f5

URL: https://github.com/llvm/llvm-project/commit/74adc3e0ebfb42a48f02d8d7094d6848a37a21f5
DIFF: https://github.com/llvm/llvm-project/commit/74adc3e0ebfb42a48f02d8d7094d6848a37a21f5.diff

LOG: [flang][hlfir] fix missing conversion in transpose simplification

Just replacing the operation did not replace all of the uses when the
types of the expression before and after this pass differed (due to
differing shape information). The new hlfir.elemental is now created with
the same result type as the original hlfir.transpose, so the shape
information is always kept the same.

This fixes https://github.com/llvm/llvm-project/issues/63399

Differential Revision: https://reviews.llvm.org/D153333

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Builder/HLFIRTools.h
    flang/lib/Optimizer/Builder/HLFIRTools.cpp
    flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
    flang/test/HLFIR/simplify-hlfir-intrinsics.fir

Removed: 
    


################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index d20cdf76a1e73..9dceee4b37b4f 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -363,11 +363,13 @@ using ElementalKernelGenerator = std::function<hlfir::Entity(
     mlir::Location, fir::FirOpBuilder &, mlir::ValueRange)>;
 /// Generate an hlfir.elementalOp given call back to generate the element
 /// value at for each iteration.
+/// If exprType is specified, this will be the return type of the elemental op
 hlfir::ElementalOp genElementalOp(mlir::Location loc,
                                   fir::FirOpBuilder &builder,
                                   mlir::Type elementType, mlir::Value shape,
                                   mlir::ValueRange typeParams,
-                                  const ElementalKernelGenerator &genKernel);
+                                  const ElementalKernelGenerator &genKernel,
+                                  mlir::Type exprType = mlir::Type{});
 
 /// Structure to describe a loop nest.
 struct LoopNest {

diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index aad7b72ba01f7..a11c235510fe9 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -722,12 +722,12 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
                               isPolymorphic);
 }
 
-hlfir::ElementalOp
-hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
-                      mlir::Type elementType, mlir::Value shape,
-                      mlir::ValueRange typeParams,
-                      const ElementalKernelGenerator &genKernel) {
-  mlir::Type exprType = getArrayExprType(elementType, shape, false);
+hlfir::ElementalOp hlfir::genElementalOp(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
+    mlir::Value shape, mlir::ValueRange typeParams,
+    const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
+  if (!exprType)
+    exprType = getArrayExprType(elementType, shape, false);
   auto elementalOp =
       builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
   auto insertPt = builder.saveInsertionPoint();

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 21df162265cda..f1c8d68960600 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -59,9 +59,16 @@ class TransposeAsElementalConversion
       return val;
     };
     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
-        loc, builder, elementType, resultShape, typeParams, genKernel);
+        loc, builder, elementType, resultShape, typeParams, genKernel,
+        transpose.getResult().getType());
 
-    rewriter.replaceOp(transpose, elementalOp.getResult());
+    // it wouldn't be safe to replace block arguments with a different
+    // hlfir.expr type. Types can differ due to differing amounts of shape
+    // information
+    assert(elementalOp.getResult().getType() ==
+           transpose.getResult().getType());
+
+    rewriter.replaceOp(transpose, elementalOp);
     return mlir::success();
   }
 

diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
index e75b6fba885ba..eac89a6423921 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics.fir
@@ -93,3 +93,94 @@ func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
+
+// expr with multiple uses
+func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) {
+  %0 = hlfir.transpose %arg0 : (!hlfir.expr<2x2xf32>) -> !hlfir.expr<2x2xf32>
+  %1 = hlfir.shape_of %0 : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
+  %2 = hlfir.elemental %1 : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+  ^bb0(%arg2: index, %arg3: index):
+    %3 = hlfir.apply %0, %arg2, %arg3 : (!hlfir.expr<2x2xf32>, index, index) -> f32
+    %4 = math.cos %3 fastmath<contract> : f32
+    hlfir.yield_element %4 : f32
+  }
+  hlfir.assign %2 to %arg1 realloc : !hlfir.expr<2x2xf32>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
+  hlfir.destroy %2 : !hlfir.expr<2x2xf32>
+  hlfir.destroy %0 : !hlfir.expr<2x2xf32>
+  return
+}
+// CHECK-LABEL: func.func @transpose4(
+// CHECK-SAME:      %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
+// CHECK-SAME:      %[[ARG1:.*]]:
+// CHECK:         %[[SHAPE0:.*]] = fir.shape
+// CHECK:         %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+// CHECK:         ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK:           %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
+// CHECK:           hlfir.yield_element %[[ELE]] : f32
+// CHECK:         }
+// CHECK:         %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]] : (!hlfir.expr<2x2xf32>) -> !fir.shape<2>
+// CHECK:         %[[COS:.*]] = hlfir.elemental %[[SHAPE1]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
+// CHECK:         ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK:           %[[ELE:.*]] = hlfir.apply %[[TRANSPOSE]], %[[I]], %[[J]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
+// CHECK:           %[[COS_ELE:.*]] = math.cos %[[ELE]] fastmath<contract> : f32
+// CHECK:           hlfir.yield_element %[[COS_ELE]] : f32
+// CHECK:         }
+// CHECK:         hlfir.assign %[[COS]] to %[[ARG1]] realloc
+// CHECK:         hlfir.destroy %[[COS]] : !hlfir.expr<2x2xf32>
+// CHECK:         hlfir.destroy %[[TRANSPOSE]] : !hlfir.expr<2x2xf32>
+// CHECK:         return
+// CHECK:       }
+
+// regression test
+func.func @transpose5(%arg0: !fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>> {fir.host_assoc}) attributes {fir.internal_proc} {
+  %0 = fir.address_of(@_QFEb) : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
+  %1:2 = hlfir.declare %0 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>)
+  %c0_i32 = arith.constant 0 : i32
+  %2 = fir.coordinate_of %arg0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %3 = fir.load %2 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %4 = fir.box_addr %3 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
+  %c0 = arith.constant 0 : index
+  %5:3 = fir.box_dims %3, %c0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %c1 = arith.constant 1 : index
+  %6:3 = fir.box_dims %3, %c1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %7 = fir.shape %5#1, %6#1 : (index, index) -> !fir.shape<2>
+  %8:2 = hlfir.declare %4(%7) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
+  %c1_i32 = arith.constant 1 : i32
+  %9 = fir.coordinate_of %arg0, %c1_i32 : (!fir.ref<tuple<!fir.box<!fir.array<2x2xf64>>, !fir.box<!fir.array<2x2xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %10 = fir.load %9 : !fir.ref<!fir.box<!fir.array<2x2xf64>>>
+  %11 = fir.box_addr %10 : (!fir.box<!fir.array<2x2xf64>>) -> !fir.ref<!fir.array<2x2xf64>>
+  %c0_0 = arith.constant 0 : index
+  %12:3 = fir.box_dims %10, %c0_0 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %c1_1 = arith.constant 1 : index
+  %13:3 = fir.box_dims %10, %c1_1 : (!fir.box<!fir.array<2x2xf64>>, index) -> (index, index, index)
+  %14 = fir.shape %12#1, %13#1 : (index, index) -> !fir.shape<2>
+  %15:2 = hlfir.declare %11(%14) {uniq_name = "_QFEc"} : (!fir.ref<!fir.array<2x2xf64>>, !fir.shape<2>) -> (!fir.ref<!fir.array<2x2xf64>>, !fir.ref<!fir.array<2x2xf64>>)
+  %16 = hlfir.transpose %8#0 : (!fir.ref<!fir.array<2x2xf64>>) -> !hlfir.expr<2x2xf64>
+  %17 = hlfir.shape_of %16 : (!hlfir.expr<2x2xf64>) -> !fir.shape<2>
+  %18 = hlfir.elemental %17 : (!fir.shape<2>) -> !hlfir.expr<?x?xf64> {
+  ^bb0(%arg1: index, %arg2: index):
+    %19 = hlfir.apply %16, %arg1, %arg2 : (!hlfir.expr<2x2xf64>, index, index) -> f64
+    %20 = math.cos %19 fastmath<contract> : f64
+    hlfir.yield_element %20 : f64
+  }
+  hlfir.assign %18 to %1#0 realloc : !hlfir.expr<?x?xf64>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf64>>>>
+  hlfir.destroy %18 : !hlfir.expr<?x?xf64>
+  hlfir.destroy %16 : !hlfir.expr<2x2xf64>
+  return
+}
+// CHECK-LABEL: func.func @transpose5(
+// ...
+// CHECK:         %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0:.*]]
+// CHECK:         ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
+// CHECK:           %[[ELE:.*]] = hlfir.designate %[[ARRAY:.*]] (%[[J]], %[[I]])
+// CHECK:           %[[LOAD:.*]] = fir.load %[[ELE]]
+// CHECK:           hlfir.yield_element %[[LOAD]]
+// CHECK:         }
+// CHECK:         %[[SHAPE1:.*]] = hlfir.shape_of %[[TRANSPOSE]]
+// CHECK:         %[[COS:.*]] = hlfir.elemental %[[SHAPE1]]
+// ...
+// CHECK:         hlfir.assign %[[COS]] to %{{.*}} realloc
+// CHECK:         hlfir.destroy %[[COS]]
+// CHECK:         hlfir.destroy %[[TRANSPOSE]]
+// CHECK:         return
+// CHECK:       }


        


More information about the flang-commits mailing list