[Mlir-commits] [mlir] 214d32c - Support for dynamic dimensions in 'tensor.splat' (#74626)

llvmlistbot at llvm.org
Fri Dec 15 05:54:49 PST 2023


Author: Rafael Ubal
Date: 2023-12-15T13:54:45Z
New Revision: 214d32ccd2be05262a328563d3792ec0d36404b0

URL: https://github.com/llvm/llvm-project/commit/214d32ccd2be05262a328563d3792ec0d36404b0
DIFF: https://github.com/llvm/llvm-project/commit/214d32ccd2be05262a328563d3792ec0d36404b0.diff

LOG: Support for dynamic dimensions in 'tensor.splat' (#74626)

This feature had been marked as a `TODO` in the `tensor.splat`
documentation for a while. This PR includes:

- Support for dynamically shaped tensors in the result type of
`tensor.splat`, using the syntax suggested in the `TODO` comment (see the
sketch below).

- Updated op documentation.

- Bufferization support.

- Updates in op folders affected by the new feature.

- Unit tests for valid/invalid syntax, valid/invalid folding, and
lowering through bufferization.

- Additional op builders resembling those available in `tensor.empty`.
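For quick reference, here is the new syntax in action, adapted from the
updated op documentation in this patch (a minimal sketch; the constant
values are illustrative):

```mlir
// One index operand is supplied per '?' in the result type, in dimension
// order: %m binds to dimension 0 and %n to dimension 2.
%s = arith.constant 1.0 : f32
%m = arith.constant 10 : index
%n = arith.constant 30 : index
%t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32>
```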

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 6d3371e9071d05..eb0c79c01bee1a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1728,6 +1728,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
@@ -1736,38 +1737,56 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
   let summary = "tensor splat or broadcast operation";
   let description = [{
     Broadcast the operand to all elements of the result tensor. The operand is
-    required to be of integer/index/float type, and the result tensor must be
-    statically shaped.
+    required to be of integer/index/float type.
 
-    Example:
+    An additional argument of type `index` must be provided for each dynamic
+    dimension present in the result type.
+
+    Example for a statically shaped tensor:
 
     ```mlir
-    %s = arith.constant 10.1 : f32
+    %s = arith.constant 1.0 : f32
     %t = tensor.splat %s : tensor<8x16xf32>
     ```
 
-    TODO: This operation is easy to extend to broadcast to dynamically shaped
-          tensors:
+    Example for a tensor containing dynamic dimensions:
 
     ```mlir
-    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
-    // to the sizes of the two dynamic dimensions.
-    %m = "foo"() : () -> (index)
-    %n = "bar"() : () -> (index)
-    %t = tensor.splat %s [%m, %n] : tensor<?x?xf32>
+    // Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding
+    // to dimensions 0 and 2 of the resulting tensor, respectively.
+    %m = arith.constant 10 : index
+    %n = arith.constant 30 : index
+    %t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32>
     ```
   }];
 
   let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
-  let results = (outs AnyStaticShapeTensor:$aggregate);
+                                 "integer/index/float type">:$input,
+                       Variadic<Index>:$dynamicSizes);
+  let results = (outs AnyRankedTensor:$aggregate);
 
   let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+    // Build with an explicit result type and a list of values corresponding
+    // to the dynamic sizes present in the result type.
+    OpBuilder<(ins "Value":$element,
+                   "Type":$aggregateType,
+                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+    // Build with a result tensor shape and a list of values corresponding to
+    // the elements in the result tensor shape set to ShapedType::kDynamic.
+    OpBuilder<(ins "Value":$element,
+                   "ArrayRef<int64_t>":$staticShape,
+                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+    // Build with mixed static/dynamic sizes, where an attribute represents
+    // a static dimension and a value represents a dynamic dimension.
+    OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)>
+  ];
+
+  let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? attr-dict `:` type($aggregate)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

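As a reading aid for the new declarative assembly format above: the bracketed
size list is printed and parsed only when dynamic sizes are present, and the
new verifier requires exactly one `index` operand per dynamic dimension. A
minimal sketch of the two accepted forms (value names are illustrative):

```mlir
%s = arith.constant 1.0 : f32
%m = arith.constant 16 : index

// Fully static result type: the size list is omitted entirely.
%a = tensor.splat %s : tensor<4x4xf32>

// One index operand per '?' in the result type, in dimension order.
%b = tensor.splat %s[%m] : tensor<4x?xf32>
```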
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9ef4ae84536841..1b0cdbd0f4f739 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1704,7 +1704,7 @@ class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
-    if (!splatOp)
+    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::SplatOp>(
@@ -3400,16 +3400,63 @@ LogicalResult ScatterOp::verify() {
 // SplatOp
 //===----------------------------------------------------------------------===//
 
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    Type aggregateType, ValueRange dynamicSizes) {
+  build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
+  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
+  build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    ArrayRef<OpFoldResult> sizes) {
+  SmallVector<int64_t> staticShape;
+  SmallVector<Value> dynamicSizes;
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
+  build(builder, result, element, staticShape, dynamicSizes);
+}
+
 void SplatOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "splat");
 }
 
+LogicalResult SplatOp::verify() {
+  if (getType().getNumDynamicDims() !=
+      static_cast<int64_t>(getDynamicSizes().size()))
+    return emitOpError("incorrect number of dynamic sizes, has ")
+           << getDynamicSizes().size() << ", expected "
+           << getType().getNumDynamicDims();
+  return success();
+}
+
+LogicalResult
+SplatOp::reifyResultShapes(OpBuilder &builder,
+                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  unsigned ctr = 0;
+  for (int64_t i = 0; i < getType().getRank(); ++i) {
+    if (getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
+    } else {
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+    }
+  }
+  return success();
+}
+
 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};
 
+  // Do not fold if the result type is not statically shaped.
+  if (!getType().hasStaticShape())
+    return {};
+
   // SplatElementsAttr::get treats single value for second arg as being a
   // splat.
   return SplatElementsAttr::get(getType(), {constOperand});

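To make the new fold guard concrete: a constant splat with a statically
shaped result can fold to a `SplatElementsAttr` (printed as `dense<...>`),
but a dynamically shaped result cannot, since elements attributes require a
fully static type; `reifyResultShapes` instead exposes the dynamic extents
(`%m` below) to shape-aware consumers. A hedged sketch of the two cases
(`%m` is an assumed `index` value):

```mlir
%f = arith.constant 1.0 : f32
%m = arith.constant 8 : index

// Statically shaped: folding can replace this op with the constant
// dense<1.000000e+00> : tensor<4xf32>.
%t0 = tensor.splat %f : tensor<4xf32>

// Dynamically shaped: the guard in SplatOp::fold bails out, so the op
// survives folding (see splat_dynamic_no_fold below).
%t1 = tensor.splat %f[%m] : tensor<?xf32>
```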
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index a8b3c6af9ae893..815bc383af95a6 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -602,3 +602,23 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
   %t = tensor.splat %f : tensor<10x2x4xf32>
   return %t : tensor<10x2x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.splat_dynamic(
+// CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[N:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
+// CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
+// CHECK:         () {
+// CHECK:           linalg.yield %[[F]] : f32
+// CHECK:         }
+// CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>
+// CHECK:       }
+func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf32> {
+  %0 = tensor.splat %f[%m, %n] : tensor<?x3x?xf32>
+  return %0 : tensor<?x3x?xf32>
+}
+

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 8542fc9567412b..ed964071358ace 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1248,6 +1248,19 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
 
 // -----
 
+// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
+// CHECK-SAME: %[[F:.+]]: f32
+// CHECK-SAME: %[[M:.+]]: index
+func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+  // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
+  %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32>
+  return %0 : tensor<2x2x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
   %c0 = tensor.splat %arg : tensor<2x2x2xf32>
   %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
@@ -1261,6 +1274,20 @@ func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
 //       CHECK:   return %[[CST]]
 
 // -----
+
+// CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
+// CHECK-SAME: %[[F:.+]]: f32
+// CHECK-SAME: %[[M:.+]]: index
+func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
+  %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
+  %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
   %c0 = arith.constant dense<42> : tensor<2x8xi16>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]
@@ -1671,6 +1698,19 @@ func.func @splat_fold() -> tensor<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @splat_dynamic_no_fold
+// CHECK-SAME: %[[M:.+]]: index
+func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
+  // CHECK: %[[F:.+]] = arith.constant
+  %f = arith.constant 1.0 : f32
+
+  // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
+  %t = tensor.splat %f[%m] : tensor<4x?xf32>
+  return %t : tensor<4x?xf32>
+}
+
+// -----
+
 // There was an issue in cast + insert_slice folding generating invalid ir.
 // https://github.com/llvm/llvm-project/issues/53099
 // CHECK-LABEL: func @insert_slice_cast

diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index bdada43e325c55..735e5146e9dbc8 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -443,6 +443,14 @@ func.func @invalid_splat(%v : vector<8xf32>) {
 
 // -----
 
+func.func @invalid_splat(%v: f32, %m: index) {
+  // expected-error @+1 {{incorrect number of dynamic sizes, has 1, expected 2}}
+  %w = tensor.splat %v[%m] : tensor<?x8x?xf32>
+  return
+}
+
+// -----
+
 func.func @gather_empty_dims(
     %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
   // expected-error @+1 {{gather_dims must be non-empty}}

diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2282da38803af0..2b0a74acce0826 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -309,6 +309,16 @@ func.func @test_splat_op(%s : f32) {
   return
 }
 
+// CHECK-LABEL: func @test_splat_op_dynamic
+// CHECK-SAME: [[S:arg[0-9]+]]: f32
+// CHECK-SAME: [[M:arg[0-9]+]]: index
+// CHECK-SAME: [[N:arg[0-9]+]]: index
+func.func @test_splat_op_dynamic(%s: f32, %m: index, %n: index) {
+  // CHECK: tensor.splat %[[S]][%[[M]], %[[N]]] : tensor<?x8x?xf32>
+  %v = tensor.splat %s[%m, %n] : tensor<?x8x?xf32>
+  return
+}
+
 // -----
 
 // CHECK-LABEL: func.func @gather_scatter(

More information about the Mlir-commits mailing list