[Mlir-commits] [mlir] [mlir][vector] Fix 0-d vector transfer mask inference (PR #116526)
Diego Caballero
llvmlistbot at llvm.org
Wed Nov 20 20:19:08 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/116526
>From eae66c71d073eadbd273057e5db35155ce09069e Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Sat, 16 Nov 2024 20:51:01 -0800
Subject: [PATCH 1/2] [mlir][vector] Fix 0-d vector transfer mask inference
When we infer the mask of a transfer operation that results in a
single `i1` element, we can represent it using a `vector<i1>` or a
`vector<1xi1>`. To avoid issues with this type mismatch, this PR
fixes the mask inference logic to always generate `vector<1xi1>` for
these cases. We can enable 0-d masks if they are eventually needed.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++++
mlir/test/Dialect/Vector/invalid.mlir | 16 ++++++++++++++++
mlir/test/Dialect/Vector/ops.mlir | 14 ++++++++++++++
3 files changed, 34 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..6ba6d30099ce91 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4122,6 +4122,10 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
+ // Turn a 0-D mask into a single-element 1-D mask.
+ if (maskShape.empty())
+ maskShape.push_back(1);
+
SmallVector<bool> scalableDims =
applyPermutationMap(invPermMap, vecType.getScalableDims());
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..bfbdc83e382272 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1752,6 +1752,22 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>
// -----
+// We can support 0-D masks if eventually needed.
+func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
+ %idx0: index, %idx1: index,
+ %m0: vector<i1>) -> vector<1x1x4xi32> {
+ %cst = arith.constant 0 : i32
+ // expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
+ %res = vector.mask %m0 {
+ %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+ : tensor<2x4xi32>, vector<1x1x4xi32>
+ vector.yield %0 : vector<1x1x4xi32>
+ } : vector<i1> -> vector<1x1x4xi32>
+ return %res : vector<1x1x4xi32>
+}
+
+// -----
+
func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..04d9ff0546160a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1028,6 +1028,20 @@ func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -
return %0 : vector<16xf32>
}
+// CHECK-LABEL: func @vector_mask_scalar_broadcast_transfer
+func.func @vector_mask_scalar_broadcast_transfer(%arg0: tensor<2x4xi32>,
+ %idx0: index, %idx1: index,
+ %m0: vector<1xi1>) -> vector<1x1x4xi32> {
+ %cst = arith.constant 0 : i32
+ // CHECK: vector.mask %{{.*}} { vector.transfer_read {{.*}} } : vector<1xi1> -> vector<1x1x4xi32>
+ %res = vector.mask %m0 {
+ %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+ : tensor<2x4xi32>, vector<1x1x4xi32>
+ vector.yield %0 : vector<1x1x4xi32>
+ } : vector<1xi1> -> vector<1x1x4xi32>
+ return %res : vector<1x1x4xi32>
+}
+
// CHECK-LABEL: func @vector_scalable_insert(
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
>From 2208f2bfccc61588bf4b9579791e3bb51db2ee82 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Wed, 20 Nov 2024 20:18:48 -0800
Subject: [PATCH 2/2] Feedback
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 +++-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
mlir/test/Dialect/Vector/invalid.mlir | 1 -
3 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..cc4cafa869e63a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2475,7 +2475,9 @@ def Vector_MaskOp : Vector_Op<"mask", [
should not. The `vector.mask` operation returns the value produced by the
masked execution of the nested operation, if any. The masked-off lanes in
the result vector are taken from the corresponding lanes of the pass-thru
- argument, if provided, or left unmodified, otherwise.
+ argument, if provided, or left unmodified, otherwise. At this point, 0-D
+ vectors are not supported by `vector.mask`. They may be supported in the
+ future.
The `vector.mask` operation does not prescribe how a maskable operation
should be masked or how a masked operation should be lowered. Masking
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6ba6d30099ce91..1b2f9b7abba5e3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4122,7 +4122,8 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
- // Turn a 0-D mask into a single-element 1-D mask.
+ // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
+ // 0-D mask into a single-element 1-D mask.
if (maskShape.empty())
maskShape.push_back(1);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index bfbdc83e382272..0c093b0ccff141 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1752,7 +1752,6 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>
// -----
-// We can support 0-D masks if eventually needed.
func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
%idx0: index, %idx1: index,
%m0: vector<i1>) -> vector<1x1x4xi32> {
More information about the Mlir-commits
mailing list