[Mlir-commits] [mlir] 3598c24 - [mlir][linalg] Change linalg.broadcast `dimensions` attribute to represent added dimensions.
Oleg Shyshkov
llvmlistbot at llvm.org
Mon Nov 21 04:17:00 PST 2022
Author: Oleg Shyshkov
Date: 2022-11-21T13:16:41+01:00
New Revision: 3598c24983be90a582cdafb7864e302193c340f4
URL: https://github.com/llvm/llvm-project/commit/3598c24983be90a582cdafb7864e302193c340f4
DIFF: https://github.com/llvm/llvm-project/commit/3598c24983be90a582cdafb7864e302193c340f4.diff
LOG: [mlir][linalg] Change linalg.broadcast `dimensions` attribute to represent added dimensions.
The original [RFC](https://discourse.llvm.org/t/rfc-primitive-ops-add-broadcastop-to-linalg/66313) defined `dimensions` as a map from input dimensions to init dimensions, but a discussion in https://reviews.llvm.org/D138291 concluded that it is more natural for `dimensions` to list the added dimensions. This is also more consistent with `linalg.reduce`.
Differential Revision: https://reviews.llvm.org/D138408
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e822435ce4d84..815d542ca83cf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -463,19 +463,14 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Static broadcast operator";
let description = [{
- Broadcast the input into the given shape by adding dimensions.
-
- Each index in `dimensions` attribute maps input dimension into the
- corresponding target dimension. The length of the `dimensions` list should
- match the `input` rank and dimensions should be in sorted order. There is no
- ambiguity at compile-time about shape information.
+ Broadcast the input into the given shape by adding `dimensions`.
Example:
```
%bcast = linalg.broadcast
ins(%input:tensor<16xf32>)
inits(%init:tensor<16x64xf32>)
- dimensions = [0]
+ dimensions = [1]
```
}];
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ea5263c3147fc..5ce936e6431c8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1511,10 +1511,6 @@ void BroadcastOp::print(OpAsmPrinter &p) {
LogicalResult BroadcastOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
- if (!llvm::is_sorted(dimensionsRef))
- return emitOpError() << "dimensions should be in sorted order, implicit "
- "transpose is not supported";
-
auto inputType = getInput().getType();
auto initType = getInit().getType();
@@ -1524,34 +1520,35 @@ LogicalResult BroadcastOp::verify() {
auto inputShape = inputType.getShape();
auto initShape = initType.getShape();
- if ((size_t)inputRank != dimensionsRef.size())
- return emitOpError()
- << "input rank does match the number of dimensions. expected: "
- << inputRank << ", got: " << dimensionsRef.size();
-
- // Mapping from init dims to input dims.
- const int64_t kUnmappedDim = -1;
- SmallVector<int64_t> reverseDimMap(initRank, kUnmappedDim);
+ if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
+ return emitOpError() << "input rank plus added dimensions does not "
+ "match init rank. input rank: "
+ << inputRank
+ << ", dimensions size: " << dimensionsRef.size()
+ << ", init rank: " << initRank;
for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
if (dim < 0 || dim >= initRank)
return emitOpError() << "dimension " << idx
<< " is out of range. expected range: [0, "
<< initRank - 1 << "], got: " << dim;
+ }
- reverseDimMap[dim] = idx;
+ // Mapping from input dims to init dims.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensionsRef, dim))
+ dimMap.push_back(dim);
}
- for (const auto &[idx, inputDimIdx] : llvm::enumerate(reverseDimMap)) {
- if (inputDimIdx != kUnmappedDim) {
- // This dimensions is mapped from the input. Init and input dims should
- // match.
- if (inputShape[inputDimIdx] != initShape[idx])
- return emitOpError()
- << "input dim " << inputDimIdx << " should match init dim "
- << idx << ". input: " << inputShape[inputDimIdx]
- << ", init: " << initShape[idx];
- }
+ for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
+ // This dimensions is mapped from the input. Init and input dims should
+ // match.
+ if (inputShape[inputDimIdx] != initShape[initDimIdx])
+ return emitOpError() << "input dim " << inputDimIdx
+ << " should match init dim " << initDimIdx
+ << ". input: " << inputShape[inputDimIdx]
+ << ", init: " << initShape[initDimIdx];
}
return success();
@@ -1566,8 +1563,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
Builder builder(getContext());
int64_t rank = getInit().getType().getRank();
return builder.getAffineMapArrayAttr(
- {builder.getMultiDimIdentityMap(rank).getSubMap(
- llvm::to_vector_of<unsigned>(getDimensions())),
+ {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
builder.getMultiDimIdentityMap(rank)});
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9eddc1c73bf64..03540be889905 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -676,27 +676,14 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
// -----
-func.func @broadcast_unsorted_dims(
- %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
- -> tensor<4x8x16xf32> {
- // expected-error @+1 {{'linalg.broadcast' op dimensions should be in sorted order}}
- %bcast = linalg.broadcast
- ins(%input:tensor<4x16xf32>)
- outs(%init:tensor<4x8x16xf32>)
- dimensions = [1, 0]
- func.return %bcast : tensor<4x8x16xf32>
-}
-
-// -----
-
func.func @broadcast_input_dims_rank_mismatch(
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
- // expected-error @+1 {{'linalg.broadcast' op input rank does match the number of dimensions. expected: 2, got: 1}}
+ // expected-error @+1 {{'linalg.broadcast' op input rank plus added dimensions does not match init rank. }}
%bcast = linalg.broadcast
ins(%input:tensor<4x16xf32>)
outs(%init:tensor<4x8x16xf32>)
- dimensions = [0]
+ dimensions = [1, 2]
func.return %bcast : tensor<4x8x16xf32>
}
@@ -705,11 +692,11 @@ func.func @broadcast_input_dims_rank_mismatch(
func.func @broadcast_unsorted_dims(
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
- // expected-error @+1 {{'linalg.broadcast' op dimension 1 is out of range. expected range: [0, 2], got: 5}}
+ // expected-error @+1 {{'linalg.broadcast' op dimension 0 is out of range. expected range: [0, 2], got: 5}}
%bcast = linalg.broadcast
ins(%input:tensor<4x16xf32>)
outs(%init:tensor<4x8x16xf32>)
- dimensions = [0, 5]
+ dimensions = [5]
func.return %bcast : tensor<4x8x16xf32>
}
@@ -722,7 +709,7 @@ func.func @broadcast_mapped_dim_mismatch(
%bcast = linalg.broadcast
ins(%input:tensor<4x16xf32>)
outs(%init:tensor<5x8x16xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<5x8x16xf32>
}
@@ -735,6 +722,6 @@ func.func @broadcast_size_1_extension_not_supported(
%bcast = linalg.broadcast
ins(%input:tensor<1x16xf32>)
outs(%init:tensor<4x?x16xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<4x?x16xf32>
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 9d100d5117fdd..424539b7e86f2 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -395,7 +395,7 @@ func.func @broadcast(%input: tensor<8x32xf32>,
%bcast = linalg.broadcast
ins(%input:tensor<8x32xf32>)
outs(%init:tensor<8x16x32xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<8x16x32xf32>
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 64c2bea1f7ee1..8f0c83fe202e1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -525,7 +525,7 @@ func.func @broadcast_static_sizes(%input: tensor<8x32xf32>,
%bcast = linalg.broadcast
ins(%input:tensor<8x32xf32>)
outs(%init:tensor<8x16x32xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<8x16x32xf32>
}
// CHECK-LABEL: func @broadcast_static_sizes
@@ -542,7 +542,7 @@ func.func @broadcast_with_dynamic_sizes(
%bcast = linalg.broadcast
ins(%input:tensor<8x?xf32>)
outs(%init:tensor<8x16x?xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return %bcast : tensor<8x16x?xf32>
}
// CHECK-LABEL: func @broadcast_with_dynamic_sizes
@@ -558,7 +558,7 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
linalg.broadcast
ins(%input:memref<8x32xf32>)
outs(%init:memref<8x16x32xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return
}
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index b2e3fd5eec3b1..f0d1938e79dd7 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -248,7 +248,7 @@ func.func @broadcast(%input: memref<8x32xf32>,
linalg.broadcast
ins(%input:memref<8x32xf32>)
outs(%init:memref<8x16x32xf32>)
- dimensions = [0, 2]
+ dimensions = [1]
func.return
}
// CHECK-LABEL: func.func @broadcast(
More information about the Mlir-commits
mailing list