[Mlir-commits] [mlir] Support for dynamic dimensions in 'tensor.splat' (PR #74626)

Rafael Ubal llvmlistbot at llvm.org
Wed Dec 6 09:14:51 PST 2023


https://github.com/rafaelubalmw created https://github.com/llvm/llvm-project/pull/74626

This feature had been marked as `TODO` in the `tensor.splat` documentation for a while. This PR includes:

- Support for dynamically shaped tensors in the return type of `tensor.splat` with the syntax suggested in the `TODO` comment.

- Updated op documentation.

- Bufferization support.

- Updates to op folders and canonicalization patterns affected by the new feature.

- Unit tests for valid/invalid syntax, valid/invalid folding, and lowering through bufferization.

- Additional op builders resembling those available in `tensor.empty`.

>From 66287c8d3d23cfd3003baf82160013514b4bedb5 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Tue, 5 Dec 2023 22:54:16 -0500
Subject: [PATCH 1/4] Progress in 'tensor.splat' extensions

---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       | 49 +++++++++++++------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 43 ++++++++++++++++
 2 files changed, 77 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index f50e3464867be..60f188607e454 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1728,6 +1728,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
@@ -1736,38 +1737,56 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
   let summary = "tensor splat or broadcast operation";
   let description = [{
     Broadcast the operand to all elements of the result tensor. The operand is
-    required to be of integer/index/float type, and the result tensor must be
-    statically shaped.
+    required to be of integer/index/float type.
 
-    Example:
+    An additional argument of type `index` must be provided for each dynamic
+    dimension present in the result type.
+
+    Example for a statically shaped tensor:
 
     ```mlir
     %s = arith.constant 10.1 : f32
     %t = tensor.splat %s : tensor<8x16xf32>
     ```
 
-    TODO: This operation is easy to extend to broadcast to dynamically shaped
-          tensors:
+    Example for a tensor containing dynamic dimensions:
 
     ```mlir
-    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
-    // to the sizes of the two dynamic dimensions.
-    %m = "foo"() : () -> (index)
-    %n = "bar"() : () -> (index)
-    %t = tensor.splat %s [%m, %n] : tensor<?x?xf32>
+    // Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding
+    // to dimensions 0 and 2 of the resulting tensor, respectively.
+    %m = arith.constant 10 : index
+    %n = arith.constant 30 : index
+    %t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32>
     ```
   }];
 
   let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
-  let results = (outs AnyStaticShapeTensor:$aggregate);
+                                 "integer/index/float type">:$input,
+                       Variadic<Index>:$dynamicSizes); // One per dynamic dim.
+  let results = (outs AnyRankedTensor:$aggregate);
 
   let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+    // Build with an explicit result type and one `index` value for each
+    // dynamic dimension of that type (none for a statically shaped type).
+    OpBuilder<(ins "Value":$element,
+                   "Type":$aggregateType,
+                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+    // Build from a shape that marks dynamic dimensions with kDynamic, taking
+    // the element type from `element`; pass one size value per kDynamic entry.
+    OpBuilder<(ins "Value":$element,
+                   "ArrayRef<int64_t>":$staticShape,
+                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+    // Build from mixed sizes: an attribute gives a static dimension size and
+    // a Value supplies a dynamic one.
+    OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)>
+  ];
+
+  let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? attr-dict `:` type($aggregate)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f15695383d34a..b5e15d8a6d457 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3430,11 +3430,54 @@ LogicalResult ScatterOp::verify() {
 // SplatOp
 //===----------------------------------------------------------------------===//
 
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    Type aggregateType, ValueRange dynamicSizes) {
+  build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
+  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
+  build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    ArrayRef<OpFoldResult> sizes) {
+  SmallVector<int64_t> staticShape;
+  SmallVector<Value> dynamicSizes;
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape); // Split attrs/values.
+  build(builder, result, element, staticShape, dynamicSizes);
+}
+
 void SplatOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "splat");
 }
 
+LogicalResult SplatOp::verify() {
+  if (getType().getNumDynamicDims() != // Exactly one size per dynamic dim.
+      static_cast<int64_t>(getDynamicSizes().size()))
+    return emitOpError("incorrect number of dynamic sizes, has ")
+           << getDynamicSizes().size() << ", expected "
+           << getType().getNumDynamicDims();
+  return success();
+}
+
+LogicalResult
+SplatOp::reifyResultShapes(OpBuilder &builder,
+                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  unsigned ctr = 0; // Next unconsumed entry of getDynamicSizes().
+  for (int64_t i = 0; i < getType().getRank(); ++i) {
+    if (getType().isDynamicDim(i)) { // Dynamic sizes are consumed in order.
+      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
+    } else { // Static extent, materialized as an index attribute.
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+    }
+  }
+  return success();
+}
+
 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())

>From 823899dd3977768c4d99c63efc4d356661f2d46d Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 6 Dec 2023 09:49:22 -0500
Subject: [PATCH 2/4] Added unit tests

---
 mlir/test/Dialect/Tensor/bufferize.mlir | 20 ++++++++++++++++++++
 mlir/test/Dialect/Tensor/invalid.mlir   |  8 ++++++++
 mlir/test/Dialect/Tensor/ops.mlir       | 10 ++++++++++
 3 files changed, 38 insertions(+)

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index a8b3c6af9ae89..e3c6ebbeb9d91 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -602,3 +602,23 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
   %t = tensor.splat %f : tensor<10x2x4xf32>
   return %t : tensor<10x2x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.splat.dynamic(
+// CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[N:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG:     %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
+// CHECK:         %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:         %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
+// CHECK:         () {
+// CHECK:           linalg.yield %[[F]] : f32
+// CHECK:         }
+// CHECK:         return %[[MAPPED]] : tensor<?x3x?xf32>
+// CHECK:       }
+func.func @tensor.splat.dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf32> {
+  %0 = tensor.splat %f[%m, %n] : tensor<?x3x?xf32> // %m and %n size dims 0 and 2.
+  return %0 : tensor<?x3x?xf32>
+}
+
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 9b6c2327879cf..943a6df16ce01 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -456,6 +456,14 @@ func.func @invalid_splat(%v : vector<8xf32>) {
 
 // -----
 
+func.func @invalid_splat(%v: f32, %m: index) {
+  // expected-error@+1 {{incorrect number of dynamic sizes, has 1, expected 2}}
+  %w = tensor.splat %v[%m] : tensor<?x8x?xf32>
+  return
+}
+
+// -----
+
 func.func @gather_empty_dims(
     %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
   // expected-error@+1 {{gather_dims must be non-empty}}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2282da38803af..2b0a74acce082 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -309,6 +309,16 @@ func.func @test_splat_op(%s : f32) {
   return
 }
 
+// CHECK-LABEL: func @test_splat_op_dynamic
+// CHECK-SAME: [[S:arg[0-9]+]]: f32
+// CHECK-SAME: [[M:arg[0-9]+]]: index
+// CHECK-SAME: [[N:arg[0-9]+]]: index
+func.func @test_splat_op_dynamic(%s: f32, %m: index, %n: index) {
+  // CHECK: tensor.splat %[[S]][%[[M]], %[[N]]] : tensor<?x8x?xf32>
+  %v = tensor.splat %s[%m, %n] : tensor<?x8x?xf32>
+  return
+}
+
 // -----
 
 // CHECK-LABEL: func.func @gather_scatter(

>From 8e6a26b2e3d2cf5bbfa71f6c8da462a69299c1d3 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 6 Dec 2023 10:43:49 -0500
Subject: [PATCH 3/4] Added unit tests for no-fold cases

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  6 +++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 40 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b5e15d8a6d457..8fad57eea64ae 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1736,7 +1736,7 @@ class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
-    if (!splatOp)
+    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::SplatOp>(
@@ -3483,6 +3483,10 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};
 
+  // Do not fold if the splat is not statically shaped.
+  if (!getType().hasStaticShape())
+    return {};
+
   // SplatElementsAttr::get treats single value for second arg as being a
   // splat.
   return SplatElementsAttr::get(getType(), {constOperand});
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 84c44a09aa3dd..6b86341911f59 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1204,6 +1204,19 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
 
 // -----
 
+// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
+// CHECK-SAME: %[[F:.+]]: f32
+// CHECK-SAME: %[[M:.+]]: index
+func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+  // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
+  %c0 = tensor.splat %arg[%m] : tensor<2x?xf32> // Dynamic: reshape must not fold.
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32>
+  return %0 : tensor<2x2x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
   %c0 = tensor.splat %arg : tensor<2x2x2xf32>
   %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
@@ -1217,6 +1230,20 @@ func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
 //       CHECK:   return %[[CST]]
 
 // -----
+
+// CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
+// CHECK-SAME: %[[F:.+]]: f32
+// CHECK-SAME: %[[M:.+]]: index
+func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
+  %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32> // Dynamic: reshape must not fold.
+  %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
   %c0 = arith.constant dense<42> : tensor<2x8xi16>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]
@@ -1627,6 +1654,19 @@ func.func @splat_fold() -> tensor<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @splat_dynamic_no_fold
+// CHECK-SAME: %[[M:.+]]: index
+func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
+  // CHECK: %[[F:.+]] = arith.constant
+  %f = arith.constant 1.0 : f32
+
+  // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
+  %t = tensor.splat %f[%m] : tensor<4x?xf32> // Dynamic: no constant fold.
+  return %t : tensor<4x?xf32>
+}
+
+// -----
+
 // There was an issue in cast + insert_slice folding generating invalid ir.
 // https://github.com/llvm/llvm-project/issues/53099
 // CHECK-LABEL: func @insert_slice_cast

>From 1fdb7487fabe5bb8c775e343091ae9442100de7b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Wed, 6 Dec 2023 12:09:29 -0500
Subject: [PATCH 4/4] Changed 'tensor.splat' example

---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 60f188607e454..251a53ed3b888 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1745,7 +1745,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
     Example for a statically shaped tensor:
 
     ```mlir
-    %s = arith.constant 10.1 : f32
+    %s = arith.constant 1.0 : f32
     %t = tensor.splat %s : tensor<8x16xf32>
     ```
 



More information about the Mlir-commits mailing list