[Mlir-commits] [mlir] Support for dynamic dimensions in 'tensor.splat' (PR #74626)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 6 09:15:21 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rafael Ubal (rafaelubalmw)
<details>
<summary>Changes</summary>
Support for dynamically shaped result tensors had been marked as a `TODO` in the `tensor.splat` documentation for a while. This PR includes:
- Support for dynamically shaped tensors in the return type of `tensor.splat` with the syntax suggested in the `TODO` comment.
- Updated op documentation.
- Bufferization support.
- Updates to the op folders affected by the new feature.
- Unit tests for valid/invalid syntax, valid/invalid folding, and lowering through bufferization.
- Additional op builders resembling those available in `tensor.empty`.
---
Full diff: https://github.com/llvm/llvm-project/pull/74626.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+35-16)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+48-1)
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+20)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+40)
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+8)
- (modified) mlir/test/Dialect/Tensor/ops.mlir (+10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index f50e3464867be..251a53ed3b888 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1728,6 +1728,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
def Tensor_SplatOp : Tensor_Op<"splat", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
Pure,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
@@ -1736,38 +1737,56 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
let summary = "tensor splat or broadcast operation";
let description = [{
Broadcast the operand to all elements of the result tensor. The operand is
- required to be of integer/index/float type, and the result tensor must be
- statically shaped.
+ required to be of integer/index/float type.
- Example:
+ An additional argument of type `index` must be provided for each dynamic
+ dimension present in the result type.
+
+ Example for a statically shaped tensor:
```mlir
- %s = arith.constant 10.1 : f32
+ %s = arith.constant 1.0 : f32
%t = tensor.splat %s : tensor<8x16xf32>
```
- TODO: This operation is easy to extend to broadcast to dynamically shaped
- tensors:
+ Example for a tensor containing dynamic dimensions:
```mlir
- // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
- // to the sizes of the two dynamic dimensions.
- %m = "foo"() : () -> (index)
- %n = "bar"() : () -> (index)
- %t = tensor.splat %s [%m, %n] : tensor<?x?xf32>
+ // Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding
+ // to dimensions 0 and 2 of the resulting tensor, respectively.
+ %m = arith.constant 10 : index
+ %n = arith.constant 30 : index
+ %t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32>
```
}];
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
- "integer/index/float type">:$input);
- let results = (outs AnyStaticShapeTensor:$aggregate);
+ "integer/index/float type">:$input,
+ Variadic<Index>:$dynamicSizes);
+ let results = (outs AnyRankedTensor:$aggregate);
let builders = [
- OpBuilder<(ins "Value":$element, "Type":$aggregateType),
- [{ build($_builder, $_state, aggregateType, element); }]>];
- let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+ // Build with an explicit result type and one index value per dynamic
+ // dimension of that type, listed in dimension order.
+ OpBuilder<(ins "Value":$element,
+ "Type":$aggregateType,
+ CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+ // Build a ranked tensor of shape `staticShape` whose element type is that
+ // of `element`; one value is needed per ShapedType::kDynamic entry.
+ OpBuilder<(ins "Value":$element,
+ "ArrayRef<int64_t>":$staticShape,
+ CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+ // Build from mixed sizes: an attribute entry is a static dimension size
+ // and a value entry is the runtime size of a dynamic dimension.
+ OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)>
+ ];
+
+ // The bracketed dynamic-size list is optional: static results omit it.
+ let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? attr-dict `:` type($aggregate)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f15695383d34a..8fad57eea64ae 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1736,7 +1736,7 @@ class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
- if (!splatOp)
+ if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
return failure();
rewriter.replaceOpWithNewOp<tensor::SplatOp>(
@@ -3430,16 +3430,63 @@ LogicalResult ScatterOp::verify() {
// SplatOp
//===----------------------------------------------------------------------===//
+/// Builds a splat of `element` with result type `aggregateType`;
+/// `dynamicSizes` holds one index value per dynamic dimension of that type.
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+ Type aggregateType, ValueRange dynamicSizes) {
+ build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+/// Builds a splat whose result is a ranked tensor of shape `staticShape` with
+/// the element type of `element`.
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+ ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
+ auto aggregateType = RankedTensorType::get(staticShape, element.getType());
+ build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+/// Builds a splat from mixed static/dynamic sizes, which are split into a
+/// static shape and dynamic size values via dispatchIndexOpFoldResults.
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+ ArrayRef<OpFoldResult> sizes) {
+ SmallVector<int64_t> staticShape;
+ SmallVector<Value> dynamicSizes;
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
+ build(builder, result, element, staticShape, dynamicSizes);
+}
+
void SplatOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "splat");
}
+/// Verifies that exactly one dynamic size operand is supplied for each
+/// dynamic dimension of the result type.
+LogicalResult SplatOp::verify() {
+ if (getType().getNumDynamicDims() !=
+ static_cast<int64_t>(getDynamicSizes().size()))
+ return emitOpError("incorrect number of dynamic sizes, has ")
+ << getDynamicSizes().size() << ", expected "
+ << getType().getNumDynamicDims();
+ return success();
+}
+
+/// Reifies the result shape: dynamic dimensions map to the dynamic size
+/// operands (in order), static dimensions become index attributes.
+LogicalResult
+SplatOp::reifyResultShapes(OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+ // `ctr` advances through getDynamicSizes() once per dynamic dimension.
+ unsigned ctr = 0;
+ for (int64_t i = 0; i < getType().getRank(); ++i) {
+ if (getType().isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
+ } else {
+ reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+ }
+ }
+ return success();
+}
+
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
return {};
+ // Do not fold if the splat is not statically shaped
+ if (!getType().hasStaticShape())
+ return {};
+
// SplatElementsAttr::get treats single value for second arg as being a
// splat.
return SplatElementsAttr::get(getType(), {constOperand});
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index a8b3c6af9ae89..e3c6ebbeb9d91 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -602,3 +602,23 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
%t = tensor.splat %f : tensor<10x2x4xf32>
return %t : tensor<10x2x4xf32>
}
+
+// -----
+
+// The dynamic sizes of the splat become the dynamic operands of memref.alloc.
+// CHECK-LABEL: func @tensor.splat.dynamic(
+// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[N:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
+// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
+// CHECK: () {
+// CHECK: linalg.yield %[[F]] : f32
+// CHECK: }
+// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>
+// CHECK: }
+func.func @tensor.splat.dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf32> {
+ %0 = tensor.splat %f[%m, %n] : tensor<?x3x?xf32>
+ return %0 : tensor<?x3x?xf32>
+}
+
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 84c44a09aa3dd..6b86341911f59 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1204,6 +1204,19 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
// -----
+// The splat has a dynamic dimension, so FoldReshapeWithSplat must leave the
+// expand_shape in place instead of rewriting it into a new splat.
+// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
+// CHECK-SAME: %[[F:.+]]: f32
+// CHECK-SAME: %[[M:.+]]: index
+func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> {
+ // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
+ %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32>
+ return %0 : tensor<2x2x?xf32>
+}
+
+// -----
+
func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
%c0 = tensor.splat %arg : tensor<2x2x2xf32>
%0 = tensor.collapse_shape %c0 [[0], [1, 2]]
@@ -1217,6 +1230,20 @@ func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
// CHECK: return %[[CST]]
// -----
+
+// The splat has a dynamic dimension, so FoldReshapeWithSplat must leave the
+// collapse_shape in place instead of rewriting it into a new splat.
+// CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
+// CHECK-SAME: %[[F:.+]]: f32
+// CHECK-SAME: %[[M:.+]]: index
+func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
+ // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
+ %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
+ %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
+ return %0 : tensor<2x?xf32>
+}
+
+// -----
+
func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
%c0 = arith.constant dense<42> : tensor<2x8xi16>
%0 = tensor.expand_shape %c0 [[0], [1, 2]]
@@ -1627,6 +1654,19 @@ func.func @splat_fold() -> tensor<4xf32> {
// -----
+// SplatOp::fold only produces a constant for statically shaped results, so a
+// splat of a constant with a dynamic dimension must stay a tensor.splat.
+// CHECK-LABEL: func @splat_dynamic_no_fold
+// CHECK-SAME: %[[M:.+]]: index
+func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
+ // CHECK: %[[F:.+]] = arith.constant
+ %f = arith.constant 1.0 : f32
+
+ // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
+ %t = tensor.splat %f[%m] : tensor<4x?xf32>
+ return %t : tensor<4x?xf32>
+}
+
+// -----
+
// There was an issue in cast + insert_slice folding generating invalid ir.
// https://github.com/llvm/llvm-project/issues/53099
// CHECK-LABEL: func @insert_slice_cast
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 9b6c2327879cf..943a6df16ce01 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -456,6 +456,14 @@ func.func @invalid_splat(%v : vector<8xf32>) {
// -----
+// The result type has two dynamic dimensions but only one size is supplied,
+// which SplatOp::verify must reject.
+func.func @invalid_splat(%v: f32, %m: index) {
+ // expected-error@+1 {{incorrect number of dynamic sizes, has 1, expected 2}}
+ %w = tensor.splat %v[%m] : tensor<?x8x?xf32>
+ return
+}
+
+// -----
+
func.func @gather_empty_dims(
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
// expected-error at +1 {{gather_dims must be non-empty}}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2282da38803af..2b0a74acce082 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -309,6 +309,16 @@ func.func @test_splat_op(%s : f32) {
return
}
+// Round-trips a splat whose result type has two dynamic dimensions.
+// CHECK-LABEL: func @test_splat_op_dynamic
+// CHECK-SAME: [[S:arg[0-9]+]]: f32
+// CHECK-SAME: [[M:arg[0-9]+]]: index
+// CHECK-SAME: [[N:arg[0-9]+]]: index
+func.func @test_splat_op_dynamic(%s: f32, %m: index, %n: index) {
+ // CHECK: tensor.splat %[[S]][%[[M]], %[[N]]] : tensor<?x8x?xf32>
+ %v = tensor.splat %s[%m, %n] : tensor<?x8x?xf32>
+ return
+}
+
// -----
// CHECK-LABEL: func.func @gather_scatter(
``````````
</details>
https://github.com/llvm/llvm-project/pull/74626
More information about the Mlir-commits
mailing list