[Mlir-commits] [mlir] [mlir][vector] Verify vector in `transferOp` has the same base elemental type as the source (PR #108710)

Longsheng Mou llvmlistbot at llvm.org
Sat Sep 14 09:53:04 PDT 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/108710

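This PR fixes a bug in `verifyTransferOp` that allowed the vector type of a `vector.transfer_read` or `vector.transfer_write` op to have a different base element type than its source. The base element type is the source's element type or, when the source memref/tensor has vector elements, the element type of those inner vectors. Fixes #108360.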

>From e35e2942658293d2226efaa7b22382f212aefbc8 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Sun, 15 Sep 2024 00:42:50 +0800
Subject: [PATCH] [mlir][vector] Verify vector in `transferOp` has the same
 base elemental type as the source

This PR fixes a bug in `verifyTransferOp` that allowed the vector type
of a `vector.transfer_read` or `vector.transfer_write` op to have a
different base element type than its source.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp |  6 ++++
 mlir/test/Dialect/Vector/invalid.mlir    | 41 ++++++++++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..50a86070058d7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3871,9 +3871,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
         "requires source to be a memref or ranked tensor type");
 
   auto elementType = shapedType.getElementType();
+  Type baseElementalType = elementType;
   DataLayout dataLayout = DataLayout::closest(op);
   if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
     // Memref or tensor has vector element type.
+    baseElementalType = vectorElementType.getElementType();
     unsigned sourceVecSize =
         dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
         vectorElementType.getShape().back();
@@ -3915,6 +3917,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                              "the same rank as the vector type");
   }
 
+  if (baseElementalType != vectorType.getElementType())
+    return op->emitOpError("expects a vector of the same base elemental "
+                           "type as the source");
+
   if (permutationMap.getNumSymbols() != 0)
     return op->emitOpError("requires permutation_map without symbols");
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..5739d2bda16c6b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -475,6 +475,27 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 3.0 : f32
+  // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %cst : memref<?x?xf32>, vector<128xi32>
+  return
+}
+
+// -----
+
+func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2xf32>>) {
+  %c3 = arith.constant 3 : index
+  %f0 = arith.constant 0.0 : f32
+  %vf0 = vector.splat %f0 : vector<2xf32>
+  // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+  %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<2xf32>>, vector<2xi32>
+  return
+}
+
+// -----
+
 func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
@@ -638,6 +659,26 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 // -----
 
+func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant dense<1> : vector<3x4xi32>
+  // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+  vector.transfer_write %cst, %arg0[%c3, %c3] : vector<3x4xi32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+func.func @test_vector.transfer_write(%arg0: memref<?x?xvector<2xf32>>) {
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant dense<1> : vector<2xi32>
+  // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+  vector.transfer_write %cst, %arg0[%c3, %c3] : vector<2xi32>, memref<?x?xvector<2xf32>>
+  return
+}
+
+// -----
+
 func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected offsets of same size as destination vector rank}}
   %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>



More information about the Mlir-commits mailing list