[flang-commits] [flang] 5c92507 - [flang][hlfir] add hlfir.shape_of
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Mon Apr 17 06:27:57 PDT 2023
Author: Tom Eccles
Date: 2023-04-17T13:25:53Z
New Revision: 5c925079927766adf72116b508d061c4a5803b56
URL: https://github.com/llvm/llvm-project/commit/5c925079927766adf72116b508d061c4a5803b56
DIFF: https://github.com/llvm/llvm-project/commit/5c925079927766adf72116b508d061c4a5803b56.diff
LOG: [flang][hlfir] add hlfir.shape_of
This is an operation which returns the fir.shape for a hlfir.expr.
A hlfir.expr can be defined by:
- A transformational intrinsic (e.g. hlfir.matmul)
- hlfir.as_expr
- hlfir.elemental
hlfir.elemental is easy because there is a compulsory shape operand.
hlfir.as_expr is defined as operating on a variable (defined using a
hlfir.declare). hlfir.declare has an optional shape argument. The
transformational intrinsics do not have an associated shape.
If all extents are known at compile time, the extents for the shape can
be fetched from the hlfir.expr's type. For example, the result of a
hlfir.matmul with arguments whose extents are known at compile time will
have constant extents which can be queried from the type. In this case
the hlfir.shape_of will be canonicalised to a fir.shape operation using
those extents.
If not all extents are known at compile time, shapes have to be read
from boxes after bufferization. In the case of the transformational
intrinsics, the shape read from the result box can be queried from the
hlfir.declare operation for the buffer allocated to that hlfir.expr (via
the hlfir.as_expr).
Differential Revision: https://reviews.llvm.org/D146830
Added:
flang/test/HLFIR/shapeof.fir
Modified:
flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 9673804e09f1a..0386b37dbda34 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -78,6 +78,11 @@ bool isI1Type(mlir::Type);
// scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)
bool isMaskArgument(mlir::Type);
+/// Generate a fir.shape whose extents are the constant extents recorded in the
+/// type \p expr. Returns a null mlir::Value ({}) if any extent is unknown.
+mlir::Value genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc,
+ const hlfir::ExprType &expr);
+
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 15c5735c1e8eb..c942d97f4218b 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -763,4 +763,34 @@ def hlfir_CopyOutOp : hlfir_Op<"copy_out", []> {
}];
}
+def hlfir_ShapeOfOp : hlfir_Op<"shape_of", [Pure]> {
+ let summary = "Get the shape of a hlfir.expr";
+ let description = [{
+ Gets the runtime shape of a hlfir.expr. In lowering to FIR, the
+ hlfir.shape_of operation will be replaced by an fir.shape.
+ It is not valid to request the shape of a hlfir.expr which has no shape.
+ }];
+
+ let arguments = (ins hlfir_ExprType:$expr);
+
+ let results = (outs fir_ShapeType);
+
+ let hasVerifier = 1;
+
+ // When every extent of the operand type is a compile-time constant, the
+ // canonicalizer rewrites hlfir.shape_of into a fir.shape built from constants.
+ // Exposing the shape this early helps inform bufferization decisions.
+ let hasCanonicalizeMethod = 1;
+
+ let extraClassDeclaration = [{
+ std::size_t getRank();
+ }];
+
+ let assemblyFormat = [{
+ $expr attr-dict `:` functional-type(operands, results)
+ }];
+
+ let builders = [OpBuilder<(ins "mlir::Value":$expr)>];
+}
+
#endif // FORTRAN_DIALECT_HLFIR_OPS
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index 2cadd6880cb1f..cf6b332028c78 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -11,13 +11,16 @@
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.cpp.inc"
@@ -158,3 +161,23 @@ bool hlfir::isMaskArgument(mlir::Type type) {
// input is a scalar, so allow i1 too
return mlir::isa<fir::LogicalType>(elementType) || isI1Type(elementType);
}
+
+/// Generate a fir.shape for an expression type whose extents are all known at
+/// compile time.
+///
+/// \param builder builder used to create the arith.constant and fir.shape ops
+/// \param loc     location attached to every created operation
+/// \param expr    expression type whose extents are read
+/// \return the result of the new fir.shape, or a null mlir::Value ({}) if any
+///         extent is unknown. No operation is created in that case.
+mlir::Value hlfir::genExprShape(mlir::OpBuilder &builder,
+                                const mlir::Location &loc,
+                                const hlfir::ExprType &expr) {
+  // Validate every extent before creating anything. Bailing out half-way
+  // through (e.g. for a `2x?` expression) would leave orphaned constants in
+  // the IR, and a rewrite pattern calling this helper would then report
+  // failure after having modified the IR.
+  for (std::int64_t extent : expr.getShape())
+    if (extent == hlfir::ExprType::getUnknownExtent())
+      return {};
+
+  mlir::IndexType indexTy = builder.getIndexType();
+  llvm::SmallVector<mlir::Value> extents;
+  extents.reserve(expr.getRank());
+
+  // One index constant per (known) extent, in dimension order.
+  for (std::int64_t extent : expr.getShape())
+    extents.emplace_back(builder.create<mlir::arith::ConstantOp>(
+        loc, indexTy, builder.getIntegerAttr(indexTy, extent)));
+
+  fir::ShapeType shapeTy =
+      fir::ShapeType::get(builder.getContext(), expr.getRank());
+  fir::ShapeOp shape = builder.create<fir::ShapeOp>(loc, shapeTy, extents);
+  return shape.getResult();
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index ef3b4d57e1f7f..88acb2b4ac790 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -13,6 +13,7 @@
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -886,5 +887,55 @@ void hlfir::CopyInOp::build(mlir::OpBuilder &builder,
var_is_present);
}
+//===----------------------------------------------------------------------===//
+// ShapeOfOp
+//===----------------------------------------------------------------------===//
+
+/// Build a hlfir.shape_of on \p expr. The result type is a fir.shape with the
+/// same rank as the hlfir.expr type of \p expr.
+void hlfir::ShapeOfOp::build(mlir::OpBuilder &builder,
+                             mlir::OperationState &result, mlir::Value expr) {
+  auto operandType = expr.getType().cast<hlfir::ExprType>();
+  auto rank = operandType.getRank();
+  mlir::Type shapeType = fir::ShapeType::get(builder.getContext(), rank);
+  build(builder, result, shapeType, expr);
+}
+
+/// Rank of the fir.shape type produced by this operation.
+std::size_t hlfir::ShapeOfOp::getRank() {
+  return getResult().getType().cast<fir::ShapeType>().getRank();
+}
+
+/// Verify that the operand has a shape and that the rank of the result
+/// fir.shape matches the rank of the operand's hlfir.expr type.
+mlir::LogicalResult hlfir::ShapeOfOp::verify() {
+  auto operandType = getExpr().getType().cast<hlfir::ExprType>();
+  const std::size_t operandRank = operandType.getShape().size();
+
+  // An expression type with an empty shape has nothing to query.
+  if (operandRank == 0)
+    return emitOpError("cannot get the shape of a shape-less expression");
+
+  if (getRank() != operandRank)
+    return emitOpError("result rank and expr rank do not match");
+
+  return mlir::success();
+}
+
+/// Fold a hlfir.shape_of into a fir.shape when every extent of the operand's
+/// hlfir.expr type is known at compile time.
+///
+/// \return success if \p shapeOf was replaced, failure if the operand type has
+///         an unknown extent (the shape is only available at runtime).
+mlir::LogicalResult
+hlfir::ShapeOfOp::canonicalize(ShapeOfOp shapeOf,
+                               mlir::PatternRewriter &rewriter) {
+  mlir::Location loc = shapeOf.getLoc();
+  hlfir::ExprType expr = shapeOf.getExpr().getType().cast<hlfir::ExprType>();
+
+  // genExprShape returns a null value when some extent is unknown.
+  mlir::Value shape = hlfir::genExprShape(rewriter, loc, expr);
+  if (!shape)
+    // shape information is not available at compile time
+    return mlir::failure();
+
+  // replaceOp redirects all uses to the new fir.shape and erases shapeOf in
+  // one step.
+  rewriter.replaceOp(shapeOf, shape);
+  return mlir::success();
+}
+
#define GET_OP_CLASSES
#include "flang/Optimizer/HLFIR/HLFIROps.cpp.inc"
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index a8ba337ad8b5a..7ecb7cd221839 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -500,3 +500,15 @@ func.func @bad_parent_comp6(%arg0: !fir.box<!fir.array<10x!fir.type<t2{i:i32,j:i
%2 = hlfir.parent_comp %arg0 shape %1 : (!fir.box<!fir.array<10x!fir.type<t2{i:i32,j:i32}>>>, !fir.shape<1>) -> !fir.ref<!fir.array<10x!fir.type<t1{i:i32}>>>
return
}
+
+// -----
+func.func @bad_shapeof(%arg0: !hlfir.expr<!fir.char<1,10>>) {
+ // expected-error@+1 {{'hlfir.shape_of' op cannot get the shape of a shape-less expression}}
+ %0 = hlfir.shape_of %arg0 : (!hlfir.expr<!fir.char<1,10>>) -> !fir.shape<1>
+}
+
+// -----
+func.func @bad_shapeof2(%arg0: !hlfir.expr<10xi32>) {
+ // expected-error@+1 {{'hlfir.shape_of' op result rank and expr rank do not match}}
+ %0 = hlfir.shape_of %arg0 : (!hlfir.expr<10xi32>) -> !fir.shape<42>
+}
diff --git a/flang/test/HLFIR/shapeof.fir b/flang/test/HLFIR/shapeof.fir
new file mode 100644
index 0000000000000..b91efc276b62e
--- /dev/null
+++ b/flang/test/HLFIR/shapeof.fir
@@ -0,0 +1,29 @@
+// Test hlfir.shape_of operation parse, verify (no errors), and unparse
+// RUN: fir-opt %s | fir-opt | FileCheck --check-prefix CHECK --check-prefix CHECK-ALL %s
+
+// Test canonicalization: constant extents fold to fir.shape, unknown ones do not
+// RUN: fir-opt %s --canonicalize | FileCheck --check-prefix CHECK-CANON --check-prefix CHECK-ALL %s
+
+func.func @shapeof(%arg0: !hlfir.expr<2x2xi32>) -> !fir.shape<2> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<2x2xi32>) -> !fir.shape<2>
+ return %shape : !fir.shape<2>
+}
+// CHECK-ALL-LABEL: func.func @shapeof
+// CHECK-ALL: %[[EXPR:.*]]: !hlfir.expr<2x2xi32>
+
+// CHECK-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<2x2xi32>) -> !fir.shape<2>
+
+// CHECK-CANON-NEXT: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-CANON-NEXT: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C2]] : (index, index) -> !fir.shape<2>
+
+// CHECK-ALL-NEXT: return %[[SHAPE]]
+
+// An unknown extent blocks canonicalization: hlfir.shape_of is kept in both runs
+func.func @shapeof2(%arg0: !hlfir.expr<?x2xi32>) -> !fir.shape<2> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
+ return %shape : !fir.shape<2>
+}
+// CHECK-ALL-LABEL: func.func @shapeof2
+// CHECK-ALL: %[[EXPR:.*]]: !hlfir.expr<?x2xi32>
+// CHECK-ALL-NEXT: %[[SHAPE:.*]] = hlfir.shape_of %[[EXPR]] : (!hlfir.expr<?x2xi32>) -> !fir.shape<2>
+// CHECK-ALL-NEXT: return %[[SHAPE]]
More information about the flang-commits
mailing list