[Mlir-commits] [mlir] [mlir][vector] Verify vector in `transferOp` has the same base elemental type as the source (PR #108710)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 14 09:53:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
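This PR fixes a bug in `verifyTransferOp` that allowed the vector of a vector transfer op (`vector.transfer_read` / `vector.transfer_write`) to have a different base element type than the source. Here, the base element type is the source's scalar element type, looking through a vector element type if the memref/tensor has one. Fixes #108360.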
---
Full diff: https://github.com/llvm/llvm-project/pull/108710.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+6)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+41)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..50a86070058d7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3871,9 +3871,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
"requires source to be a memref or ranked tensor type");
auto elementType = shapedType.getElementType();
+ Type baseElementalType = elementType;
DataLayout dataLayout = DataLayout::closest(op);
if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
// Memref or tensor has vector element type.
+ baseElementalType = vectorElementType.getElementType();
unsigned sourceVecSize =
dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
vectorElementType.getShape().back();
@@ -3915,6 +3917,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
"the same rank as the vector type");
}
+ if (baseElementalType != vectorType.getElementType())
+ return op->emitOpError("expects a vector of the same base elemental "
+ "type as the source");
+
if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c95b8bd5ed6147..5739d2bda16c6b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -475,6 +475,27 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
// -----
+func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 3.0 : f32
+ // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %cst : memref<?x?xf32>, vector<128xi32>
+ return
+}
+
+// -----
+
+func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2xf32>>) {
+ %c3 = arith.constant 3 : index
+ %f0 = arith.constant 0.0 : f32
+ %vf0 = vector.splat %f0 : vector<2xf32>
+ // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<2xf32>>, vector<2xi32>
+ return
+}
+
+// -----
+
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
@@ -638,6 +659,26 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
// -----
+func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant dense<1> : vector<3x4xi32>
+ // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+ vector.transfer_write %cst, %arg0[%c3, %c3] : vector<3x4xi32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+func.func @test_vector.transfer_write(%arg0: memref<?x?xvector<2xf32>>) {
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant dense<1> : vector<2xi32>
+ // expected-error at +1 {{expects a vector of the same base elemental type as the source}}
+ vector.transfer_write %cst, %arg0[%c3, %c3] : vector<2xi32>, memref<?x?xvector<2xf32>>
+ return
+}
+
+// -----
+
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
// expected-error at +1 {{expected offsets of same size as destination vector rank}}
%1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/108710
More information about the Mlir-commits mailing list