[flang-commits] [flang] 683a6e1 - [flang][hlfir] lower hlfir.shape_of
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Mon Apr 17 06:28:03 PDT 2023
Author: Tom Eccles
Date: 2023-04-17T13:25:54Z
New Revision: 683a6e1c9e5396f64086c07bec334a38acd0ec7a
URL: https://github.com/llvm/llvm-project/commit/683a6e1c9e5396f64086c07bec334a38acd0ec7a
DIFF: https://github.com/llvm/llvm-project/commit/683a6e1c9e5396f64086c07bec334a38acd0ec7a.diff
LOG: [flang][hlfir] lower hlfir.shape_of
If possible, the shape is obtained from the bufferized storage of the expr
argument.
The simple cases should already have been resolved during lowering. This
is mostly intended for cases where shape information is added between
lowering and the end of bufferization (for example, transformational
intrinsics with assumed-shape arguments).
Depends on: D146832
Differential Revision: https://reviews.llvm.org/D146833
Added:
flang/test/HLFIR/shapeof-lowering.fir
Modified:
flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 4b631b2f99a5d..21fe2d9f89372 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -27,8 +27,9 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
-#include <mlir/Support/LogicalResult.h>
+#include "llvm/ADT/TypeSwitch.h"
namespace hlfir {
#define GEN_PASS_DEF_BUFFERIZEHLFIR
@@ -169,6 +170,38 @@ struct AsExprOpConversion : public mlir::OpConversionPattern<hlfir::AsExprOp> {
}
};
+struct ShapeOfOpConversion
+    : public mlir::OpConversionPattern<hlfir::ShapeOfOp> {
+  using mlir::OpConversionPattern<hlfir::ShapeOfOp>::OpConversionPattern;
+
+  /// Replace hlfir.shape_of by the shape of the bufferized operand storage,
+  /// falling back on the static extents of its !hlfir.expr type.
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::ShapeOfOp shapeOf, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    const mlir::Location loc = shapeOf.getLoc();
+    auto module = shapeOf->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+    const mlir::Value expr = adaptor.getExpr();
+
+    mlir::Value shape;
+    if (hlfir::Entity storage{getBufferizedExprStorage(expr)};
+        storage.isVariable())
+      shape = hlfir::genShape(loc, builder, storage);
+    else if (auto exprTy = expr.getType().dyn_cast_or_null<hlfir::ExprType>())
+      shape = hlfir::genExprShape(builder, loc, exprTy);
+
+    // Expected to be unreachable: no bufferized variable and no static extents
+    // were available to build the shape from.
+    if (!shape)
+      return emitError(loc,
+                       "Unresolvable hlfir.shape_of where extents are unknown");
+
+    rewriter.replaceOp(shapeOf, shape);
+    return mlir::success();
+  }
+};
+
struct ApplyOpConversion : public mlir::OpConversionPattern<hlfir::ApplyOp> {
using mlir::OpConversionPattern<hlfir::ApplyOp>::OpConversionPattern;
explicit ApplyOpConversion(mlir::MLIRContext *ctx)
@@ -529,11 +562,11 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
auto module = this->getOperation();
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
- patterns
- .insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
- AssociateOpConversion, ConcatOpConversion, DestroyOpConversion,
- ElementalOpConversion, EndAssociateOpConversion,
- NoReassocOpConversion, SetLengthOpConversion>(context);
+ patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
+ AssociateOpConversion, ConcatOpConversion,
+ DestroyOpConversion, ElementalOpConversion,
+ EndAssociateOpConversion, NoReassocOpConversion,
+ SetLengthOpConversion, ShapeOfOpConversion>(context);
mlir::ConversionTarget target(*context);
target.addIllegalOp<hlfir::ApplyOp, hlfir::AssociateOp, hlfir::ElementalOp,
hlfir::EndAssociateOp, hlfir::SetLengthOp,
diff --git a/flang/test/HLFIR/shapeof-lowering.fir b/flang/test/HLFIR/shapeof-lowering.fir
new file mode 100644
index 0000000000000..73e2270a0cd4a
--- /dev/null
+++ b/flang/test/HLFIR/shapeof-lowering.fir
@@ -0,0 +1,58 @@
+// Test hlfir.shape_of lowering
+// RUN: fir-opt %s -bufferize-hlfir | FileCheck %s
+
+// Operand bufferized by hlfir.as_expr: shape comes from the wrapped variable.
+func.func @shapeof_asexpr(%arg0: !fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.shape<1> {
+  %c0 = arith.constant 0 : index
+  %59:3 = fir.box_dims %arg0, %c0 : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+  %60 = fir.box_addr %arg0 : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.heap<!fir.array<?xf32>>
+  %61 = fir.shape_shift %59#0, %59#1 : (index, index) -> !fir.shapeshift<1>
+  %62:2 = hlfir.declare %60(%61) {uniq_name = ".tmp.intrinsic_result"} : (!fir.heap<!fir.array<?xf32>>, !fir.shapeshift<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.heap<!fir.array<?xf32>>)
+  %true = arith.constant true
+  %63 = hlfir.as_expr %62#0 move %true : (!fir.box<!fir.array<?xf32>>, i1) -> !hlfir.expr<?xf32>
+  %64 = hlfir.shape_of %63 : (!hlfir.expr<?xf32>) -> !fir.shape<1>
+  return %64 : !fir.shape<1>
+}
+// CHECK-LABEL: @shapeof_asexpr
+// CHECK: %[[ARG0:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]]
+// CHECK-NEXT: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]]
+// CHECK-NEXT: %[[SHPE_SHFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
+// CHECK-NEXT: %[[VAR:.*]]:2 = hlfir.declare %[[BOX_ADDR]](%[[SHPE_SHFT]])
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[TUPLE0:.*]] = fir.undefined tuple
+// CHECK-NEXT: %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[TRUE]]
+// CHECK-NEXT: %[[TUPLE2:.*]] = fir.insert_value %[[TUPLE1]], %[[VAR]]#0
+// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1
+// CHECK-NEXT: return %[[SHAPE]]
+
+// Operand is an hlfir.elemental: shape comes from its bufferized temporary.
+func.func @shapeof_elemental() -> !fir.shape<1> {
+  %c1 = arith.constant 1 : index
+  %0 = fir.shape %c1 : (index) -> !fir.shape<1>
+  %1 = hlfir.elemental %0 : (!fir.shape<1>) -> !hlfir.expr<?xindex> {
+  ^bb0(%arg3: index):
+    hlfir.yield_element %arg3 : index
+  }
+  %2 = hlfir.shape_of %1 : (!hlfir.expr<?xindex>) -> !fir.shape<1>
+  return %2 : !fir.shape<1>
+}
+// CHECK-LABEL: @shapeof_elemental
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]]
+// CHECK: fir.do_loop %{{.*}} = %{{.*}} to %[[C1]] step
+// CHECK: return %[[SHAPE]]
+
+// Operand is not bufferized: shape comes from the static extents of its type.
+func.func @shapeof_fallback(%arg0: !hlfir.expr<1x2x3xi32>) -> !fir.shape<3> {
+  %shape = hlfir.shape_of %arg0 : (!hlfir.expr<1x2x3xi32>) -> !fir.shape<3>
+  return %shape : !fir.shape<3>
+}
+// CHECK-LABEL: @shapeof_fallback
+// CHECK: %[[EXPR:.*]]: !hlfir.expr<1x2x3xi32>
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-NEXT: %[[SHAPE:.*]] = fir.shape %[[C1]], %[[C2]], %[[C3]] :
+// CHECK-NEXT: return %[[SHAPE]]
More information about the flang-commits
mailing list