[flang-commits] [flang] 8ec0f22 - Update fir.insert_on_range syntax to make the range more explicit (NFC)
Mehdi Amini via flang-commits
flang-commits at lists.llvm.org
Tue Nov 23 18:06:25 PST 2021
Author: Mehdi Amini
Date: 2021-11-24T02:06:17Z
New Revision: 8ec0f221843c51096cf3e7a479e780be371388a8
URL: https://github.com/llvm/llvm-project/commit/8ec0f221843c51096cf3e7a479e780be371388a8
DIFF: https://github.com/llvm/llvm-project/commit/8ec0f221843c51096cf3e7a479e780be371388a8.diff
LOG: Update fir.insert_on_range syntax to make the range more explicit (NFC)
Also replace ArrayAttr with IndexElementsAttr to model subscript dimensions.
An array of attributes is sparse, inefficient storage, with an API that
requires unpacking and repacking integers at every call site.
Instead, we can store a dense array of integers as an IndexElementsAttr.
Reviewed By: clementval, kiranchandramohan
Differential Revision: https://reviews.llvm.org/D112899
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/test/Fir/convert-to-llvm.fir
flang/test/Fir/fir-ops.fir
flang/test/Fir/invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 031b244f95446..1368009b71da9 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -1971,18 +1971,18 @@ def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoSideEffect]> {
```mlir
%a = fir.undefined !fir.array<10x10xf32>
%c = arith.constant 3.0 : f32
- %1 = fir.insert_on_range %a, %c, [0 : index, 7 : index, 0 : index, 2 : index] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
+ %1 = fir.insert_on_range %a, %c from (0, 0) to (7, 2) : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
```
The first 28 elements of %1, with coordinates from (0,0) to (7,2), have
the value 3.0.
}];
- let arguments = (ins fir_SequenceType:$seq, AnyType:$val, ArrayAttr:$coor);
+ let arguments = (ins fir_SequenceType:$seq, AnyType:$val, IndexElementsAttr:$coor);
let results = (outs fir_SequenceType);
let assemblyFormat = [{
- $seq `,` $val `,` $coor attr-dict `:` functional-type(operands, results)
+ $seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
}];
let verifier = "return ::verify(*this);";
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1e1f3eefe4d72..7583d5dfcabb4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -929,14 +929,16 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
return success();
}
- bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
+ bool isFullRange(mlir::DenseIntElementsAttr indexes,
+ fir::SequenceType seqTy) const {
auto extents = seqTy.getShape();
- if (indexes.size() / 2 != extents.size())
+ if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
return false;
+ auto cur_index = indexes.value_begin<int64_t>();
for (unsigned i = 0; i < indexes.size(); i += 2) {
- if (indexes[i].cast<IntegerAttr>().getInt() != 0)
+ if (*(cur_index++) != 0)
return false;
- if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
+ if (*(cur_index++) != extents[i / 2] - 1)
return false;
}
return true;
@@ -1728,14 +1730,10 @@ struct InsertOnRangeOpConversion
SmallVector<uint64_t> lBounds;
SmallVector<uint64_t> uBounds;
- // Extract integer value from the attribute
- SmallVector<int64_t> coordinates = llvm::to_vector<4>(
- llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-
// Unzip the upper and lower bound and convert to a row major format.
- for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
+ mlir::DenseIntElementsAttr coor = range.coor();
+ auto reversedCoor = llvm::reverse(coor.getValues<int64_t>());
+ for (auto i = reversedCoor.begin(), e = reversedCoor.end(); i != e; ++i) {
uBounds.push_back(*i++);
lBounds.push_back(*i);
}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 262df2a101a40..9ec3bc52382aa 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -17,10 +17,14 @@
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1374,16 +1378,62 @@ void fir::FieldIndexOp::build(mlir::OpBuilder &builder,
// InsertOnRangeOp
//===----------------------------------------------------------------------===//
+static ParseResult
+parseCustomRangeSubscript(mlir::OpAsmParser &parser,
+ mlir::DenseIntElementsAttr &coord) {
+ llvm::SmallVector<int64_t> lbounds;
+ llvm::SmallVector<int64_t> ubounds;
+ if (parser.parseKeyword("from") ||
+ parser.parseCommaSeparatedList(
+ AsmParser::Delimiter::Paren,
+ [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) ||
+ parser.parseKeyword("to") ||
+ parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&] {
+ return parser.parseInteger(ubounds.emplace_back(0));
+ }))
+ return failure();
+ llvm::SmallVector<int64_t> zippedBounds;
+ for (auto zip : llvm::zip(lbounds, ubounds)) {
+ zippedBounds.push_back(std::get<0>(zip));
+ zippedBounds.push_back(std::get<1>(zip));
+ }
+ coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds);
+ return success();
+}
+
+void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, InsertOnRangeOp op,
+ mlir::DenseIntElementsAttr coord) {
+ printer << "from (";
+ auto enumerate = llvm::enumerate(coord.getValues<int64_t>());
+ // Even entries are the lower bounds.
+ llvm::interleaveComma(
+ make_filter_range(
+ enumerate,
+ [](auto indexed_value) { return indexed_value.index() % 2 == 0; }),
+ printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+ printer << ") to (";
+ // Odd entries are the upper bounds.
+ llvm::interleaveComma(
+ make_filter_range(
+ enumerate,
+ [](auto indexed_value) { return indexed_value.index() % 2 != 0; }),
+ printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+ printer << ")";
+}
+
/// Range bounds must be nonnegative, and the range must not be empty.
static mlir::LogicalResult verify(fir::InsertOnRangeOp op) {
if (fir::hasDynamicSize(op.seq().getType()))
return op.emitOpError("must have constant shape and size");
- if (op.coor().size() < 2 || op.coor().size() % 2 != 0)
+ mlir::DenseIntElementsAttr coor = op.coor();
+ if (coor.size() < 2 || coor.size() % 2 != 0)
return op.emitOpError("has uneven number of values in ranges");
bool rangeIsKnownToBeNonempty = false;
- for (auto i = op.coor().end(), b = op.coor().begin(); i != b;) {
- int64_t ub = (*--i).cast<IntegerAttr>().getInt();
- int64_t lb = (*--i).cast<IntegerAttr>().getInt();
+ for (auto i = coor.getValues<int64_t>().end(),
+ b = coor.getValues<int64_t>().begin();
+ i != b;) {
+ int64_t ub = (*--i);
+ int64_t lb = (*--i);
if (lb < 0 || ub < 0)
return op.emitOpError("negative range bound");
if (rangeIsKnownToBeNonempty)
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 33b7941a45818..1ba4e544eb60f 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -80,7 +80,7 @@ fir.global @symbol : i64 {
fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (0, 0) to (31, 31) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
@@ -97,7 +97,7 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
fir.global internal @_QEmultiarray : !fir.array<32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32xi32>
- %2 = fir.insert_on_range %0, %c0_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (5) to (31) : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
fir.has_value %2 : !fir.array<32xi32>
}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index a6ae02a12a54d..631200f397c7a 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -617,10 +617,10 @@ func @test_misc_ops(%arr1 : !fir.ref<!fir.array<?x?xf32>>, %m : index, %n : inde
%c1_i32 = arith.constant 9 : i32
// CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
- // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]], [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+ // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]] from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
// CHECK: fir.call @noret1([[ARR3]]) : (!fir.array<10xi32>) -> ()
%arr2 = fir.zero_bits !fir.array<10xi32>
- %arr3 = fir.insert_on_range %arr2, %c1_i32, [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+ %arr3 = fir.insert_on_range %arr2, %c1_i32 from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
fir.call @noret1(%arr3) : (!fir.array<10xi32>) -> ()
// CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
@@ -664,6 +664,14 @@ func @test_const_complex() {
return
}
+// CHECK-LABEL: @insert_on_range_multi_dim
+// CHECK-SAME: %[[ARR:.*]]: !fir.array<10x20xi32>, %[[CST:.*]]: i32
+func @insert_on_range_multi_dim(%arr : !fir.array<10x20xi32>, %cst : i32) {
+ // CHECK: fir.insert_on_range %[[ARR]], %[[CST]] from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+ %arr3 = fir.insert_on_range %arr, %cst from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+ return
+}
+
// CHECK-LABEL: @test_shift
func @test_shift(%arg0: !fir.box<!fir.array<?xf32>>) -> !fir.ref<f32> {
%c4 = arith.constant 4 : index
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 8bc2ac6793e8c..98ee4a4538b07 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -428,7 +428,7 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0, 31, 0]> : tensor<3xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
@@ -438,7 +438,7 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0]> : tensor<1xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
@@ -448,7 +448,7 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op negative range bound}}
- %2 = fir.insert_on_range %0, %c0_i32, [-1 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (-1) to (0) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
@@ -458,7 +458,7 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<32x32xi32>
// expected-error@+1 {{'fir.insert_on_range' op empty range}}
- %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (10) to (9) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
fir.has_value %2 : !fir.array<32x32xi32>
}
@@ -468,7 +468,7 @@ fir.global internal @_QEmultiarray : !fir.array<?xi32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<?xi32>
// expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
fir.has_value %2 : !fir.array<?xi32>
}
@@ -478,7 +478,7 @@ fir.global internal @_QEmultiarray : !fir.array<*:i32> {
%c0_i32 = arith.constant 1 : i32
%0 = fir.undefined !fir.array<*:i32>
// expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
- %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
+ %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
fir.has_value %2 : !fir.array<*:i32>
}
More information about the flang-commits
mailing list