[Mlir-commits] [mlir] [mlir][vector] Verify vector in `transferOp` has the same base elemental type as the source (PR #108710)
Longsheng Mou
llvmlistbot at llvm.org
Sat Sep 14 09:53:04 PDT 2024
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/108710
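This PR fixes a bug in `verifyTransferOp` that allowed the vector of a `vector.transfer_read` / `vector.transfer_write` op to have a different base elemental type than its source. The base elemental type is the source's element type; when the source's elements are themselves vectors, it is the element type of those vectors. Fixes #108360.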
>From e35e2942658293d2226efaa7b22382f212aefbc8 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Sun, 15 Sep 2024 00:42:50 +0800
Subject: [PATCH] [mlir][vector] Verify vector in `transferOp` has the same
base elemental type as the source
This PR fixes a bug in `verifyTransferOp` that allowed the vector
of a `vector.transfer_read` / `vector.transfer_write` op to have a
different base elemental type than its source.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 ++++
mlir/test/Dialect/Vector/invalid.mlir | 41 ++++++++++++++++++++++++
2 files changed, 47 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..50a86070058d7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3871,9 +3871,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
"requires source to be a memref or ranked tensor type");
auto elementType = shapedType.getElementType();
+ Type baseElementalType = elementType;
DataLayout dataLayout = DataLayout::closest(op);
if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
// Memref or tensor has vector element type.
+ baseElementalType = vectorElementType.getElementType();
unsigned sourceVecSize =
dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
vectorElementType.getShape().back();
@@ -3915,6 +3917,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
"the same rank as the vector type");
}
+ if (baseElementalType != vectorType.getElementType())
+ return op->emitOpError("expects a vector of the same base elemental "
+ "type as the source");
+
if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..5739d2bda16c6b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -475,6 +475,27 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
// -----
+func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) { // f32 source read into an i32 vector: base element types differ
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 3.0 : f32
+ // expected-error @+1 {{expects a vector of the same base elemental type as the source}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %cst : memref<?x?xf32>, vector<128xi32>
+ return
+}
+
+// -----
+
+func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2xf32>>) { // vector<2xf32> elements read into an i32 vector
+ %c3 = arith.constant 3 : index
+ %f0 = arith.constant 0.0 : f32
+ %vf0 = vector.splat %f0 : vector<2xf32>
+ // expected-error @+1 {{expects a vector of the same base elemental type as the source}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<2xf32>>, vector<2xi32>
+ return
+}
+
+// -----
+
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
@@ -638,6 +659,26 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
// -----
+func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) { // i32 vector written into an f32 memref
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant dense<1> : vector<3x4xi32>
+ // expected-error @+1 {{expects a vector of the same base elemental type as the source}}
+ vector.transfer_write %cst, %arg0[%c3, %c3] : vector<3x4xi32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+func.func @test_vector.transfer_write(%arg0: memref<?x?xvector<2xf32>>) { // i32 vector written into vector<2xf32> elements
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant dense<1> : vector<2xi32>
+ // expected-error @+1 {{expects a vector of the same base elemental type as the source}}
+ vector.transfer_write %cst, %arg0[%c3, %c3] : vector<2xi32>, memref<?x?xvector<2xf32>>
+ return
+}
+
+// -----
+
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
// expected-error at +1 {{expected offsets of same size as destination vector rank}}
%1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
More information about the Mlir-commits
mailing list