[Mlir-commits] [mlir] 0b781db - [mlir] Add new builders to linalg.reshape.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Jun 11 03:48:12 PDT 2020
Author: Alexander Belyaev
Date: 2020-06-11T12:47:35+02:00
New Revision: 0b781db9087977b758c27bb02be1f80cd00bf0d7
URL: https://github.com/llvm/llvm-project/commit/0b781db9087977b758c27bb02be1f80cd00bf0d7
DIFF: https://github.com/llvm/llvm-project/commit/0b781db9087977b758c27bb02be1f80cd00bf0d7.diff
LOG: [mlir] Add new builders to linalg.reshape.
Differential Revision: https://reviews.llvm.org/D81640
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/EDSC/builder-api-test.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index ce3a1cef3b35..fbb3f062a494 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -34,6 +34,9 @@ class PoolingMaxOp;
class PoolingMinOp;
class PoolingSumOp;
+/// A reassociation group given as a list of dimension indices; each index is
+/// later turned into an AffineDimExpr for that dimension position.
+/// NOTE(review): "Indicies" is a misspelling of "Indices"; renaming requires
+/// updating every user (builders, helpers, tests) together.
+using ReassociationIndicies = SmallVector<int64_t, 2>;
+/// A reassociation group given as a list of affine expressions.
+using ReassociationExprs = SmallVector<AffineExpr, 2>;
+
/// Returns the name mangled library call name to disambiguate between different
/// overloads at the C level. The name mangling scheme is basic and uses MLIR
/// type names:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1615957ff0c3..64f92ee5d406 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -64,16 +64,38 @@ def Linalg_RangeOp :
class Linalg_ReshapeLikeOp<string mnemonic> :
Linalg_Op<mnemonic, [NoSideEffect]> {
let builders = [
- // Builder for a contracting reshape whose result type is computed from
+ // Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
OpBuilder<"OpBuilder &b, OperationState &result, Value src, "
- "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
- "ArrayRef<NamedAttribute> attrs = {}">,
- // Builder for a reshape whose result type is passed explicitly. This may be
- // either a contracting or expanding reshape.
- OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, Value src,"
- "ArrayRef<ArrayRef<AffineExpr>> reassociation, "
- "ArrayRef<NamedAttribute> attrs = {}">];
+ "ArrayRef<ReassociationExprs> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Variant taking each reassociation group as a list of dimension indices;
+ // the indices are converted to AffineDimExprs and forwarded to the builder above.
+ OpBuilder<"OpBuilder &b, OperationState &result, Value src, "
+ "ArrayRef<ReassociationIndicies> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}", [{
+ auto reassociationMaps =
+ convertReassociationIndiciesToMaps(b, reassociation);
+ build(b, result, src, reassociationMaps, attrs);
+ }]>,
+
+ // Builders for a reshape whose result type is passed explicitly. This may
+ // be either a contracting or expanding reshape.
+ OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, "
+ "Value src, ArrayRef<ReassociationExprs> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Index-based variant of the explicit-result-type builder above.
+ OpBuilder<"OpBuilder &b, OperationState &result, Type resultType, "
+ "Value src, ArrayRef<ReassociationIndicies> reassociation, "
+ "ArrayRef<NamedAttribute> attrs = {}", [{
+ auto reassociationMaps =
+ convertReassociationIndiciesToMaps(b, reassociation);
+ // `resultType` must be forwarded: omitting it resolves to the builder
+ // that infers a collapsed result type, so expanding reshapes would get
+ // the wrong type.
+ build(b, result, resultType, src, reassociationMaps, attrs);
+ }]>
+ ];
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index db4587fce014..717baa02de61 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -476,9 +476,9 @@ static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
}
template <typename AffineExprTy>
-unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
+unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
unsigned pos = 0;
- for (auto exprs : exprArrays) {
+ for (const auto &exprs : exprArrays) {
for (auto expr : exprs) {
expr.walk([&pos](AffineExpr e) {
if (auto d = e.dyn_cast<AffineExprTy>())
@@ -490,23 +490,37 @@ unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
}
+/// Builds one AffineMap per reassociation group, in group order. Every map
+/// has `maxDim + 1` dimensions and zero symbols, where `maxDim` comes from
+/// getMaxPosOfType<AffineDimExpr> (presumably the highest dimension position
+/// used in any group — confirm against that helper). Asserts (debug builds
+/// only) that no AffineSymbolExpr occurs and that every group is non-empty;
+/// each map is created in the context of its group's first expression.
static SmallVector<AffineMap, 4>
-getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
+getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
"Expected symbol-less expressions");
SmallVector<AffineMap, 4> maps;
maps.reserve(reassociation.size());
- for (auto exprs : reassociation) {
- assert(exprs.size() != 0);
+ for (const auto &exprs : reassociation) {
+ assert(!exprs.empty());
maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
}
return maps;
}
-void mlir::linalg::ReshapeOp::build(
- OpBuilder &b, OperationState &result, Value src,
- ArrayRef<ArrayRef<AffineExpr>> reassociation,
- ArrayRef<NamedAttribute> attrs) {
+/// Converts each group of dimension indices in `reassociationIndicies` into
+/// the matching group of AffineDimExprs (index `i` becomes dimension `i`),
+/// preserving the order of groups and of indices within each group. The
+/// result feeds the reshape builders taking `ArrayRef<ReassociationExprs>`.
+/// Indices must be non-negative: getAffineDimExpr takes an unsigned position,
+/// so a negative index would otherwise wrap to a huge dimension.
+static SmallVector<ReassociationExprs, 2>
+convertReassociationIndiciesToMaps(
+ OpBuilder &b, ArrayRef<ReassociationIndicies> reassociationIndicies) {
+ SmallVector<ReassociationExprs, 2> reassociationMaps;
+ reassociationMaps.reserve(reassociationIndicies.size());
+ for (const auto &indicies : reassociationIndicies) {
+ ReassociationExprs reassociationMap;
+ reassociationMap.reserve(indicies.size());
+ for (int64_t index : indicies) {
+ assert(index >= 0 && "expected non-negative dimension index");
+ reassociationMap.push_back(b.getAffineDimExpr(index));
+ }
+ reassociationMaps.push_back(std::move(reassociationMap));
+ }
+ return reassociationMaps;
+}
+
+void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
+ Value src,
+ ArrayRef<ReassociationExprs> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
auto maps = getSymbolLessAffineMaps(reassociation);
auto memRefType = src.getType().cast<MemRefType>();
auto resultType = computeReshapeCollapsedType(memRefType, maps);
@@ -515,10 +529,10 @@ void mlir::linalg::ReshapeOp::build(
b.getAffineMapArrayAttr(maps));
}
-void mlir::linalg::ReshapeOp::build(
- OpBuilder &b, OperationState &result, Type resultType, Value src,
- ArrayRef<ArrayRef<AffineExpr>> reassociation,
- ArrayRef<NamedAttribute> attrs) {
+void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
+ Type resultType, Value src,
+ ArrayRef<ReassociationExprs> reassociation,
+ ArrayRef<NamedAttribute> attrs) {
auto maps = getSymbolLessAffineMaps(reassociation);
build(b, result, resultType, src, attrs);
result.addAttribute(ReshapeOp::getReassociationAttrName(),
@@ -622,7 +636,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
void mlir::linalg::TensorReshapeOp::build(
OpBuilder &b, OperationState &result, Value src,
- ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<ReassociationExprs> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto maps = getSymbolLessAffineMaps(reassociation);
auto resultType = computeTensorReshapeCollapsedType(
@@ -634,7 +648,7 @@ void mlir::linalg::TensorReshapeOp::build(
void mlir::linalg::TensorReshapeOp::build(
OpBuilder &b, OperationState &result, Type resultType, Value src,
- ArrayRef<ArrayRef<AffineExpr>> reassociation,
+ ArrayRef<ReassociationExprs> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto maps = getSymbolLessAffineMaps(reassociation);
build(b, result, resultType, src, attrs);
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 4b01f7110532..3435926f867e 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -941,6 +941,8 @@ TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
// CHECK: linalg.reshape {{.*}} [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] : memref<32x16xf32> into memref<4x8x16xf32>
// clang-format on
TEST_FUNC(linalg_metadata_ops) {
+ using linalg::ReassociationExprs;
+
auto f32Type = FloatType::getF32(&globalContext());
auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
@@ -950,9 +952,10 @@ TEST_FUNC(linalg_metadata_ops) {
AffineExpr i, j, k;
bindDims(&globalContext(), i, j, k);
Value v(f.getArgument(0));
- auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
- linalg_reshape(memrefType, reshaped,
- ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
+ SmallVector<ReassociationExprs, 2> maps = {ReassociationExprs({i, j}),
+ ReassociationExprs({k})};
+ auto reshaped = linalg_reshape(v, maps);
+ linalg_reshape(memrefType, reshaped, maps);
f.print(llvm::outs());
f.erase();
More information about the Mlir-commits mailing list