[Mlir-commits] [mlir] 8602204 - [mlir][tensor] Relax input type requirement on `tensor.splat` (#145893)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 30 00:49:26 PDT 2025
Author: Markus Böck
Date: 2025-06-30T09:49:19+02:00
New Revision: 8602204d9fc483c7c58fa4e4d422d9bffb4e4e95
URL: https://github.com/llvm/llvm-project/commit/8602204d9fc483c7c58fa4e4d422d9bffb4e4e95
DIFF: https://github.com/llvm/llvm-project/commit/8602204d9fc483c7c58fa4e4d422d9bffb4e4e95.diff
LOG: [mlir][tensor] Relax input type requirement on `tensor.splat` (#145893)
`tensor.splat` is currently restricted to only accepting input values
that are of integer, index or float type.
This is much more restrictive than the tensor type itself as well as any
lowerings of it.
This PR therefore removes this restriction by using `AnyType` for the
input value. Whether the type is actually valid or not for a tensor
remains verified through the type equality of the result tensor element
type and the input type.
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 47962f75558ea..7d396e5c64c28 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
]> {
let summary = "tensor splat or broadcast operation";
let description = [{
- Broadcast the operand to all elements of the result tensor. The operand is
- required to be of integer/index/float type.
+ Broadcast the operand to all elements of the result tensor.
An additional argument of type `index` must be provided for each dynamic
dimension present in the result type.
@@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
```
}];
- let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
- "integer/index/float type">:$input,
+ let arguments = (ins AnyType:$input,
Variadic<Index>:$dynamicSizes);
let results = (outs AnyRankedTensor:$aggregate);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index c0adc8a49bf70..296ca02564e35 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.splat_other(
+// CHECK-SAME: %[[F:.*]]: !test.memref_element)
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element>
+// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: %[[MAPPED:.*]] = linalg.map
+// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>)
+// CHECK: linalg.yield %[[F]]
+// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element>
+func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> {
+ %t = tensor.splat %f : tensor<10x2x4x!test.memref_element>
+ return %t : tensor<10x2x4x!test.memref_element>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.concat(
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index f35d52e700084..665657a67dc61 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
// -----
-func.func @invalid_splat(%v : vector<8xf32>) {
- // expected-error @+1 {{must be integer/index/float type}}
- %w = tensor.splat %v : tensor<8xvector<8xf32>>
+// expected-note @+1 {{prior use here}}
+func.func @invalid_splat(%v : f32) {
+ // expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
+ %w = tensor.splat %v : tensor<1xi32>
return
}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 930986211cb6d..681a934ba0698 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -313,13 +313,17 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// -----
// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func.func @test_splat_op(%s : f32) {
- // CHECK: tensor.splat [[S]] : tensor<8xf32>
+// CHECK-SAME: %[[S:.*]]: f32
+// CHECK-SAME: %[[P:.*]]: !llvm.ptr
+func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
+ // CHECK: tensor.splat %[[S]] : tensor<8xf32>
%v = tensor.splat %s : tensor<8xf32>
- // CHECK: tensor.splat [[S]] : tensor<4xf32>
+ // CHECK: tensor.splat %[[S]] : tensor<4xf32>
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
+
+ // CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr>
+ %w = tensor.splat %p : tensor<8x!llvm.ptr>
return
}
More information about the Mlir-commits
mailing list