[Mlir-commits] [mlir] [mlir][vector] Verify vector in `transferOp` has the same base elemental type as the source (PR #108710)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 14 09:53:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

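This PR fixes a bug in `verifyTransferOp` that allowed the vector type of a `vector.transfer_read` or `vector.transfer_write` op to have a different base element type than its shaped operand (the source of `transfer_read`, or the destination of `transfer_write`). Fixes #108360.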

---
Full diff: https://github.com/llvm/llvm-project/pull/108710.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+41) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..50a86070058d7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3871,9 +3871,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
         "requires source to be a memref or ranked tensor type");
 
   auto elementType = shapedType.getElementType();
+  Type baseElementalType = elementType;
   DataLayout dataLayout = DataLayout::closest(op);
   if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
     // Memref or tensor has vector element type.
+    baseElementalType = vectorElementType.getElementType();
     unsigned sourceVecSize =
         dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
         vectorElementType.getShape().back();
@@ -3915,6 +3917,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                              "the same rank as the vector type");
   }
 
+  if (baseElementalType != vectorType.getElementType())
+    return op->emitOpError("expects a vector of the same base elemental "
+                           "type as the source");
+
   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..5739d2bda16c6b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -475,6 +475,27 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 3.0 : f32
+  // expected-error@+1 {{expects a vector of the same base elemental type as the source}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %cst : memref<?x?xf32>, vector<128xi32>
+  return
+}
+
+// -----
+
+func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2xf32>>) {
+  %c3 = arith.constant 3 : index
+  %f0 = arith.constant 0.0 : f32
+  %vf0 = vector.splat %f0 : vector<2xf32>
+  // expected-error@+1 {{expects a vector of the same base elemental type as the source}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<2xf32>>, vector<2xi32>
+  return
+}
+
+// -----
+
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
@@ -638,6 +659,26 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 // -----
 
+func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant dense<1> : vector<3x4xi32>
+  // expected-error@+1 {{expects a vector of the same base elemental type as the source}}
+  vector.transfer_write %cst, %arg0[%c3, %c3] : vector<3x4xi32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+func.func @test_vector.transfer_write(%arg0: memref<?x?xvector<2xf32>>) {
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant dense<1> : vector<2xi32>
+  // expected-error@+1 {{expects a vector of the same base elemental type as the source}}
+  vector.transfer_write %cst, %arg0[%c3, %c3] : vector<2xi32>, memref<?x?xvector<2xf32>>
+  return
+}
+
+// -----
+
 func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected offsets of same size as destination vector rank}}
   %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/108710


More information about the Mlir-commits mailing list