[Mlir-commits] [mlir] [mlir][vector] Fix 0-d vector transfer mask inference (PR #116526)

llvmlistbot at llvm.org
Sat Nov 16 20:59:27 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

When the inferred mask of a transfer operation has a single `i1` element, it can be represented either as a 0-D `vector<i1>` or as a 1-D `vector<1xi1>`. To avoid a type mismatch between the two (for example, `vector.mask` rejects a `vector<i1>` mask operand), this PR changes the mask inference logic to always produce `vector<1xi1>` in these cases. Support for 0-D masks can be added later if it is needed.

See: https://github.com/llvm/llvm-project/issues/116197
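The following standalone C++ sketch shows how a mask shape can end up empty and what the new `maskShape.empty()` check does about it. It is a simplified model, not the MLIR code. `inferMaskShapeSketch` and `PermResult` are hypothetical names, `std::vector` stands in for `SmallVector`, and the permutation map is encoded as a list of source-dimension indices or broadcast markers instead of an `AffineMap`. The model assumes the mask keeps one dimension per source dimension that the map reads, sized by the vector dimension it maps to, which approximates composing the inverse permutation map with the vector shape as in the diff. With the broadcast-only map `(d0, d1) -> (0, 0, 0)` used in the tests, no dimension survives. That is the 0-D case the PR turns into `vector<1xi1>`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical simplified model of a transfer permutation map: one entry per
// vector dimension, holding either the source dimension index it reads (d_i)
// or std::nullopt for a broadcast constant `0`.
using PermResult = std::optional<unsigned>;

// Sketch (not the MLIR implementation) of mask shape inference. It produces
// one mask dimension per source dimension that the map actually uses, ordered
// by source dimension index and sized by the first vector dimension that dim
// maps to. Broadcast results contribute no mask dimension.
std::vector<int64_t> inferMaskShapeSketch(const std::vector<int64_t> &vecShape,
                                          const std::vector<PermResult> &permMap,
                                          unsigned numSourceDims) {
  std::vector<int64_t> maskShape;
  for (unsigned d = 0; d < numSourceDims; ++d) {
    for (std::size_t r = 0; r < permMap.size(); ++r) {
      if (permMap[r] && *permMap[r] == d) {
        maskShape.push_back(vecShape[r]);
        break; // Only the first occurrence of d defines its mask size.
      }
    }
  }
  // Mirrors the fix in this PR: a 0-D mask becomes a single-element 1-D mask,
  // i.e. vector<1xi1> instead of vector<i1>.
  if (maskShape.empty())
    maskShape.push_back(1);
  return maskShape;
}
```

For example, `inferMaskShapeSketch({1, 1, 4}, {std::nullopt, std::nullopt, std::nullopt}, 2)` models the `vector<1x1x4xi32>` read from the tests and yields `{1}`. A transposing map such as `(d0, d1) -> (d1, d0)` on a `2x3` vector yields `{3, 2}`.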

---
Full diff: https://github.com/llvm/llvm-project/pull/116526.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+16) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..6ba6d30099ce91 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4122,6 +4122,10 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
 
+  // Turn a 0-D mask into a single-element 1-D mask.
+  if (maskShape.empty())
+    maskShape.push_back(1);
+
   SmallVector<bool> scalableDims =
       applyPermutationMap(invPermMap, vecType.getScalableDims());
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..bfbdc83e382272 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1752,6 +1752,22 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>
 
 // -----
 
+// We can support 0-D masks if eventually needed.
+func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
+                               %idx0: index, %idx1: index,
+                               %m0: vector<i1>) -> vector<1x1x4xi32> {
+  %cst = arith.constant 0 : i32
+  // expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
+  %res = vector.mask %m0 {
+    %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+      : tensor<2x4xi32>, vector<1x1x4xi32>
+    vector.yield %0 : vector<1x1x4xi32>
+  } : vector<i1> -> vector<1x1x4xi32>
+  return %res : vector<1x1x4xi32>
+}
+
+// -----
+
 func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
   // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
   %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..04d9ff0546160a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1028,6 +1028,20 @@ func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @vector_mask_scalar_broadcast_transfer
+func.func @vector_mask_scalar_broadcast_transfer(%arg0: tensor<2x4xi32>,
+                                                 %idx0: index, %idx1: index,
+                                                 %m0: vector<1xi1>) -> vector<1x1x4xi32> {
+  %cst = arith.constant 0 : i32
+  // CHECK: vector.mask %{{.*}} { vector.transfer_read {{.*}} } : vector<1xi1> -> vector<1x1x4xi32>
+  %res = vector.mask %m0 {
+    %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+      : tensor<2x4xi32>, vector<1x1x4xi32>
+    vector.yield %0 : vector<1x1x4xi32>
+  } : vector<1xi1> -> vector<1x1x4xi32>
+  return %res : vector<1x1x4xi32>
+}
+
 // CHECK-LABEL: func @vector_scalable_insert(
 // CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
 // CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/116526

