[flang-commits] [flang] 5ab5cdc - [flang][hlfir] get extents from hlfir.shape_of

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Mon Apr 17 06:28:01 PDT 2023


Author: Tom Eccles
Date: 2023-04-17T13:25:54Z
New Revision: 5ab5cdc1e00e90865fb9907fe8d3a1e0fe3972c8

URL: https://github.com/llvm/llvm-project/commit/5ab5cdc1e00e90865fb9907fe8d3a1e0fe3972c8
DIFF: https://github.com/llvm/llvm-project/commit/5ab5cdc1e00e90865fb9907fe8d3a1e0fe3972c8.diff

LOG: [flang][hlfir] get extents from hlfir.shape_of

If the extents were known at compile time, the hlfir.shape_of should already
have been canonicalised into a fir.shape operation. Therefore, the extents at
this point are not known at compile time. Use hlfir.get_extent (one per
dimension) to delay resolving the real extents until after the expression is
bufferized.

Depends On: D146831

Differential Revision: https://reviews.llvm.org/D146832

Added: 
    flang/test/HLFIR/extents-of-shape-of.f90

Modified: 
    flang/lib/Optimizer/Builder/HLFIRTools.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 5fdf9928b244b..4603c6d0e256d 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -24,7 +24,7 @@
 // Return explicit extents. If the base is a fir.box, this won't read it to
 // return the extents and will instead return an empty vector.
 static llvm::SmallVector<mlir::Value>
-getExplicitExtentsFromShape(mlir::Value shape) {
+getExplicitExtentsFromShape(mlir::Value shape, fir::FirOpBuilder &builder) {
   llvm::SmallVector<mlir::Value> result;
   auto *shapeOp = shape.getDefiningOp();
   if (auto s = mlir::dyn_cast_or_null<fir::ShapeOp>(shapeOp)) {
@@ -35,15 +35,23 @@ getExplicitExtentsFromShape(mlir::Value shape) {
     result.append(e.begin(), e.end());
   } else if (mlir::dyn_cast_or_null<fir::ShiftOp>(shapeOp)) {
     return {};
+  } else if (auto s = mlir::dyn_cast_or_null<hlfir::ShapeOfOp>(shapeOp)) {
+    fir::ShapeType shapeTy = shape.getType().cast<fir::ShapeType>();
+    result.reserve(shapeTy.getRank());
+    for (unsigned i = 0; i < shapeTy.getRank(); ++i) {
+      auto op = builder.create<hlfir::GetExtentOp>(shape.getLoc(), shape, i);
+      result.emplace_back(op.getResult());
+    }
   } else {
     TODO(shape.getLoc(), "read fir.shape to get extents");
   }
   return result;
 }
 static llvm::SmallVector<mlir::Value>
-getExplicitExtents(fir::FortranVariableOpInterface var) {
+getExplicitExtents(fir::FortranVariableOpInterface var,
+                   fir::FirOpBuilder &builder) {
   if (mlir::Value shape = var.getShape())
-    return getExplicitExtentsFromShape(var.getShape());
+    return getExplicitExtentsFromShape(var.getShape(), builder);
   return {};
 }
 
@@ -385,7 +393,7 @@ hlfir::genBounds(mlir::Location loc, fir::FirOpBuilder &builder,
   assert((shape.getType().isa<fir::ShapeShiftType>() ||
           shape.getType().isa<fir::ShapeType>()) &&
          "shape must contain extents");
-  auto extents = getExplicitExtentsFromShape(shape);
+  auto extents = getExplicitExtentsFromShape(shape, builder);
   auto lowers = getExplicitLboundsFromShape(shape);
   assert(lowers.empty() || lowers.size() == extents.size());
   mlir::Type idxTy = builder.getIndexType();
@@ -440,7 +448,7 @@ llvm::SmallVector<mlir::Value> getVariableExtents(mlir::Location loc,
   llvm::SmallVector<mlir::Value> extents;
   if (fir::FortranVariableOpInterface varIface =
           variable.getIfVariableInterface()) {
-    extents = getExplicitExtents(varIface);
+    extents = getExplicitExtents(varIface, builder);
     if (!extents.empty())
       return extents;
   }
@@ -493,7 +501,8 @@ mlir::Value hlfir::genShape(mlir::Location loc, fir::FirOpBuilder &builder,
 llvm::SmallVector<mlir::Value>
 hlfir::getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder,
                        mlir::Value shape) {
-  llvm::SmallVector<mlir::Value> extents = getExplicitExtentsFromShape(shape);
+  llvm::SmallVector<mlir::Value> extents =
+      getExplicitExtentsFromShape(shape, builder);
   mlir::Type indexType = builder.getIndexType();
   for (auto &extent : extents)
     extent = builder.createConvert(loc, indexType, extent);
@@ -504,7 +513,7 @@ mlir::Value hlfir::genExtent(mlir::Location loc, fir::FirOpBuilder &builder,
                              hlfir::Entity entity, unsigned dim) {
   entity = followShapeInducingSource(entity);
   if (auto shape = tryRetrievingShapeOrShift(entity)) {
-    auto extents = getExplicitExtentsFromShape(shape);
+    auto extents = getExplicitExtentsFromShape(shape, builder);
     if (!extents.empty()) {
       assert(extents.size() > dim && "bad inquiry");
       return extents[dim];

diff  --git a/flang/test/HLFIR/extents-of-shape-of.f90 b/flang/test/HLFIR/extents-of-shape-of.f90
new file mode 100644
index 0000000000000..ff1a657dc0ea5
--- /dev/null
+++ b/flang/test/HLFIR/extents-of-shape-of.f90
@@ -0,0 +1,31 @@
+! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck %s
+subroutine foo(a, b)
+  real :: a(:, :), b(:, :)
+  interface
+    elemental subroutine elem_sub(x)
+      real, intent(in) :: x
+    end subroutine
+  end interface
+  call elem_sub(matmul(a, b))
+end subroutine
+! CHECK-LABEL: func.func @_QPfoo
+! CHECK:           %[[A_ARG:.*]]: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "a"}
+! CHECK:           %[[B_ARG:.*]]: !fir.box<!fir.array<?x?xf32>> {fir.bindc_name = "b"}
+! CHECK-DAG:     %[[A_VAR:.*]]:2 = hlfir.declare %[[A_ARG]]
+! CHECK-DAG:     %[[B_VAR:.*]]:2 = hlfir.declare %[[B_ARG]]
+! CHECK-NEXT:    %[[MUL:.*]] = hlfir.matmul %[[A_VAR]]#0 %[[B_VAR]]#0
+! CHECK-NEXT:    %[[SHAPE:.*]] = hlfir.shape_of %[[MUL]] : (!hlfir.expr<?x?xf32>) -> !fir.shape<2>
+! CHECK-NEXT:    %[[EXT0:.*]] = hlfir.get_extent %[[SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index
+! CHECK-NEXT:    %[[EXT1:.*]] = hlfir.get_extent %[[SHAPE]] {dim = 1 : index} : (!fir.shape<2>) -> index
+! CHECK-NEXT:    %[[C1:.*]] = arith.constant 1 : index
+! CHECK-NEXT:    fir.do_loop %[[ARG2:.*]] = %[[C1]] to %[[EXT1]] step %[[C1]] {
+! CHECK-NEXT:      fir.do_loop %[[ARG3:.*]] = %[[C1]] to %[[EXT0]] step %[[C1]] {
+! CHECK-NEXT:        %[[ELE:.*]] = hlfir.apply %[[MUL]], %[[ARG3]], %[[ARG2]] : (!hlfir.expr<?x?xf32>, index, index) -> f32
+! CHECK-NEXT:        %[[ASSOC:.*]]:3 = hlfir.associate %[[ELE]] {uniq_name = "adapt.valuebyref"} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+! CHECK-NEXT:        fir.call
+! CHECK-NEXT:        hlfir.end_associate
+! CHECK-NEXT:      }
+! CHECK-NEXT:    }
+! CHECK-NEXT:    hlfir.destroy %[[MUL]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:  }


        


More information about the flang-commits mailing list