[Mlir-commits] [mlir] [mlir][Vector] Remove 0-d corner case condition. (PR #112937)

Harrison Hao llvmlistbot at llvm.org
Thu Nov 14 19:10:21 PST 2024


https://github.com/harrisonGPU updated https://github.com/llvm/llvm-project/pull/112937

>From 5e07a46f98d641cc43f49780b9993940ef0c735f Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Tue, 5 Nov 2024 20:11:01 +0800
Subject: [PATCH 1/6] [mlir][Vector] Support 0-d vectors natively in
 VectorStoreToMemrefStoreLowering.

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 12 ++-----
 .../VectorToLLVM/vector-to-llvm.mlir          |  5 ++-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  | 36 +++++++++----------
 .../vector-transfer-to-vector-load-store.mlir |  8 ++---
 4 files changed, 26 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..1cb3baaef82baf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering
       return rewriter.notifyMatchFailure(storeOp, "not single element vector");
 
     Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unifiy once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
+    SmallVector<int64_t> indices(vecType.getRank(), 0);
+    extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.getValueToStore(), indices);
 
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1de24fd0403ce..abbdbe02ce6c1e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 // CHECK-LABEL: func @vector_store_op_0d
 // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
+// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
 
 // -----
 
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index c583a48eba2704..ceebeeffcf2677 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
 func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..f75f8f8489efc1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
     %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-//  CHECK-NEXT:   %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-//  CHECK-NEXT:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   memref.store %[[S]], %[[MEM]][] : memref<f32>
     vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
+//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
     vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
     return

>From 3f9f7b0753fabc59d4dda6072c72fb6939b64858 Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Tue, 5 Nov 2024 14:41:31 +0000
Subject: [PATCH 2/6] [MLIR] Fix lit test failures.

---
 mlir/test/Dialect/SPIRV/IR/availability.mlir | 36 ++++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index ceebeeffcf2677..c583a48eba2704 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
 func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64

>From 5f61d162694bbb3e7d620d7dc1f4a100bc17716d Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Wed, 6 Nov 2024 15:48:41 +0000
Subject: [PATCH 3/6] [MLIR] Update comments.

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 1cb3baaef82baf..6c50473232e1b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -518,7 +518,7 @@ struct VectorLoadToMemrefLoadLowering
   }
 };
 
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
+/// Replace a vector.store with a vector.extract + memref.store.
 struct VectorStoreToMemrefStoreLowering
     : public OpRewritePattern<vector::StoreOp> {
   using OpRewritePattern::OpRewritePattern;

>From 6db82e0463ce7a0852925ea615c73b1dae63b23d Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Wed, 6 Nov 2024 15:51:12 +0000
Subject: [PATCH 4/6] [MLIR] Update comments again.

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 6c50473232e1b8..6f033cbe025098 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -518,7 +518,7 @@ struct VectorLoadToMemrefLoadLowering
   }
 };
 
-/// Replace a vector.store with a vector.extract + memref.store.
+/// Replace a vector.store with a vector.extractelement + memref.store.
 struct VectorStoreToMemrefStoreLowering
     : public OpRewritePattern<vector::StoreOp> {
   using OpRewritePattern::OpRewritePattern;

>From 6a5ac3ca2853b2b5822e8c642cd671038c1c6988 Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Sun, 10 Nov 2024 13:16:53 +0000
Subject: [PATCH 5/6] [MLIR][Vector] Remove 0-d corner case condition.

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 20 +++++++++----------
 .../VectorToLLVM/vector-to-llvm.mlir          |  5 +++--
 .../vector-transfer-to-vector-load-store.mlir |  8 +++++---
 3 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 6f033cbe025098..a953b242207018 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -97,9 +97,6 @@ struct TransferReadPermutationLowering
   matchAndRewriteMaskableOp(vector::TransferReadOp op,
                             MaskingOpInterface maskOp,
                             PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (op.getTransferRank() == 0)
-      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
     // TODO: Support transfer_read inside MaskOp case.
     if (maskOp)
       return rewriter.notifyMatchFailure(op, "Masked case not supported");
@@ -326,9 +323,6 @@ struct TransferOpReduceRank
   matchAndRewriteMaskableOp(vector::TransferReadOp op,
                             MaskingOpInterface maskOp,
                             PatternRewriter &rewriter) const override {
-    // TODO: support 0-d corner case.
-    if (op.getTransferRank() == 0)
-      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
     // TODO: support masked case.
     if (maskOp)
       return rewriter.notifyMatchFailure(op, "Masked case not supported");
@@ -518,7 +512,7 @@ struct VectorLoadToMemrefLoadLowering
   }
 };
 
-/// Replace a vector.store with a vector.extractelement + memref.store.
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
 struct VectorStoreToMemrefStoreLowering
     : public OpRewritePattern<vector::StoreOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -530,9 +524,15 @@ struct VectorStoreToMemrefStoreLowering
       return rewriter.notifyMatchFailure(storeOp, "not single element vector");
 
     Value extracted;
-    SmallVector<int64_t> indices(vecType.getRank(), 0);
-    extracted = rewriter.create<vector::ExtractOp>(
-        storeOp.getLoc(), storeOp.getValueToStore(), indices);
+    if (vecType.getRank() == 0) {
+      // TODO: Unifiy once ExtractOp supports 0-d vectors.
+      extracted = rewriter.create<vector::ExtractElementOp>(
+          storeOp.getLoc(), storeOp.getValueToStore());
+    } else {
+      SmallVector<int64_t> indices(vecType.getRank(), 0);
+      extracted = rewriter.create<vector::ExtractOp>(
+          storeOp.getLoc(), storeOp.getValueToStore(), indices);
+    }
 
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index abbdbe02ce6c1e..c1de24fd0403ce 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2971,8 +2971,9 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 // CHECK-LABEL: func @vector_store_op_0d
 // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
-// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
+// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f75f8f8489efc1..f90111b4c88618 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -7,13 +7,15 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
     %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-//  CHECK-NEXT:   memref.store %[[S]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
+//  CHECK-NEXT:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
     vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-//  CHECK-NEXT:   %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
+//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
     vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
     return

>From 700b0ea025a57fe888f25b492b40de7502159b6f Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Fri, 15 Nov 2024 11:00:28 +0800
Subject: [PATCH 6/6] [MLIR] Remove transfer vector lowering patterns.

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |   2 -
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 -
 .../Vector/Transforms/LowerVectorTransfer.cpp |  12 +-
 .../Conversion/GPUCommon/transfer_write.mlir  |   5 +-
 .../VectorToLLVM/vector-to-llvm.mlir          |  32 +-
 .../VectorToLLVM/vector-xfer-to-llvm.mlir     | 319 ++----------------
 .../test/Dialect/Vector/transform-vector.mlir |   3 +-
 .../vector-transfer-to-vector-load-store.mlir |  93 +++--
 8 files changed, 97 insertions(+), 372 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..155b2241b7a93d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1910,8 +1910,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4623b9667998cc..7635e10822a345 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -74,8 +74,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns,
                                             VectorTransformsOptions());
-    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
-    populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
@@ -84,7 +82,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
   populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
-  populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
       converter, patterns, reassociateFPReductions, force32BitVectorIndices);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index a953b242207018..484363c6b1d8de 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -636,10 +636,10 @@ struct TransferWriteToVectorStoreLowering
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
     PatternBenefit benefit) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext(),
-                                                   maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
+  // patterns.add<TransferReadToVectorLoadLowering,
+  //              TransferWriteToVectorStoreLowering>(patterns.getContext(),
+  //                                                  maxTransferRank, benefit);
+  // patterns
+  //     .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+  //         patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/GPUCommon/transfer_write.mlir b/mlir/test/Conversion/GPUCommon/transfer_write.mlir
index cd62b7b13fa9ae..d1127e4203c7b2 100644
--- a/mlir/test/Conversion/GPUCommon/transfer_write.mlir
+++ b/mlir/test/Conversion/GPUCommon/transfer_write.mlir
@@ -3,10 +3,7 @@
   func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: index, %arg3: vector<1xf32>) {
     %c0 = arith.constant 0 : index
     vector.warp_execute_on_lane_0(%arg0)[32] {
-      // CHECK:%[[val:[0-9]+]] = llvm.extractelement
-      // CHECK:%[[base:[0-9]+]] = llvm.extractvalue
-      // CHECK:%[[ptr:[0-9]+]] = llvm.getelementptr %[[base]]
-      // CHECK:llvm.store %[[val]], %[[ptr]]
+      // CHECK: vector.transfer_write %arg9, %[[MEM:.*]][%[[IDX:.*]], %[[IDX]]] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
       vector.transfer_write %arg3, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
     }
     return
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1de24fd0403ce..be230cfcbd6e5b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2953,12 +2953,16 @@ func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : ind
 }
 
 // CHECK-LABEL: func @vector_load_op_0d
-// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
-// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
-// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
-// CHECK: return %[[cast]] : vector<f32>
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %arg2 : index to i64
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %arg1 : index to i64
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S3:.*]] = llvm.extractvalue %[[S2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S4:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[S5:.*]] = llvm.mul %[[S1]], %[[S4]] : i64
+// CHECK: %[[S6:.*]] = llvm.add %[[S5]], %[[S0]] : i64
+// CHECK: %[[S7:.*]] = llvm.getelementptr %[[S3]][%[[S6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[S8:.*]] = llvm.load %[[S7]] {alignment = 4 : i64} : !llvm.ptr -> vector<1xf32>
+// CHECK: %[[S9:.*]] = builtin.unrealized_conversion_cast %[[S8]] : vector<1xf32> to vector<f32>
 
 // -----
 
@@ -2969,11 +2973,17 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 }
 
 // CHECK-LABEL: func @vector_store_op_0d
-// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
-// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[S0:.*]] = builtin.unrealized_conversion_cast %arg2 : index to i64
+// CHECK: %[[S1:.*]] = builtin.unrealized_conversion_cast %arg1 : index to i64
+// CHECK: %[[S2:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S3:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[S4:.*]] = builtin.unrealized_conversion_cast %[[S3]] : vector<f32> to vector<1xf32>
+// CHECK: %[[S5:.*]] = llvm.extractvalue %[[S2]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[S7:.*]] = llvm.mul %[[S1]], %[[S6]] : i64
+// CHECK: %[[S8:.*]] = llvm.add %[[S7]], %[[S0]] : i64
+// CHECK: %[[S9:.*]] = llvm.getelementptr %[[S5]][%[[S8]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.store %[[S4]], %[[S9]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
index 8f01cc2b8d44c3..112e868e12107e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
@@ -12,67 +12,11 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
   return %f: vector<17xf32>
 }
 // CHECK-LABEL: func @transfer_read_write_1d
-//  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
-//  CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
-//       CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
-//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
-//       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
-//  CMP32-SAME: index to i32
-//  CMP64-SAME: index to i64
-//       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
-//       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
-//  CMP64-SAME: : vector<17xi64>
-//
-// 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
-//
-// 5. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
-//  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
-//
-// 6. Rewrite as a masked read.
-//       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]],
-//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
-//  CHECK-SAME: -> vector<17xf32>
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
-//       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex_b:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
-//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
-//       CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
-//  CMP32-SAME: index to i32
-//       CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
-//       CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-//       CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-//  CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
-//
-// 4. Bitcast to vector form.
-//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
-//
-// 5. Rewrite as a masked write.
-//       CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
-//  CHECK-SAME: {alignment = 4 : i32} :
-//  CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
+// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref<?xf32>, vector<17xf32>
+// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<17xf32>, memref<?xf32>
+
+// -----
 
 func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
   %f7 = arith.constant 7.0: f32
@@ -85,62 +29,9 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
   return %f: vector<[17]xf32>
 }
 // CHECK-LABEL: func @transfer_read_write_1d_scalable
-//  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
-//  CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
-//       CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
-//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
-//       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
-//       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
-//       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
-//  CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
-//
-// 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
-//
-// 5. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
-//  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
-//
-// 6. Rewrite as a masked read.
-//       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[gep]], %[[mask]],
-//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
-//  CHECK-SAME: -> vector<[17]xf32>
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
-//       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
-//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex_b:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
-//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
-//       CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to [[$IDX_TYPE]]
-//       CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
-//       CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-//       CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-//  CHECK-SAME: %[[boundVect_b]] : vector<[17]x[[$IDX_TYPE]]>
-//
-// 4. Bitcast to vector form.
-//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
-//
-// 5. Rewrite as a masked write.
-//       CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
-//  CHECK-SAME: {alignment = 4 : i32} :
-//  CHECK-SAME: vector<[17]xf32>, vector<[17]xi1> into !llvm.ptr
+// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref<?xf32>, vector<[17]xf32>
+// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<[17]xf32>, memref<?xf32>
 
 // -----
 
@@ -155,15 +46,9 @@ func.func @transfer_read_write_index_1d(%A : memref<?xindex>, %base: index) -> v
   return %f: vector<17xindex>
 }
 // CHECK-LABEL: func @transfer_read_write_index_1d
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
-//       CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
-//       CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
-
-//       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
-//  CHECK-SAME: (!llvm.ptr, vector<17xi1>, vector<17xi64>) -> vector<17xi64>
-
-//       CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
-//  CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr
+// CHECK: %[[CST:.*]] = arith.constant 7 : index
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref<?xindex>, vector<17xindex>
+// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<17xindex>, memref<?xindex>
 
 func.func @transfer_read_write_index_1d_scalable(%A : memref<?xindex>, %base: index) -> vector<[17]xindex> {
   %f7 = arith.constant 7: index
@@ -175,16 +60,10 @@ func.func @transfer_read_write_index_1d_scalable(%A : memref<?xindex>, %base: in
     vector<[17]xindex>, memref<?xindex>
   return %f: vector<[17]xindex>
 }
-// CHECK-LABEL: func @transfer_read_write_index_1d
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xindex>
-//       CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<[17]xindex>
-//       CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<[17]xindex> to vector<[17]xi64>
-
-//       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
-//  CHECK-SAME: (!llvm.ptr, vector<[17]xi1>, vector<[17]xi64>) -> vector<[17]xi64>
-
-//       CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
-//  CHECK-SAME: vector<[17]xi64>, vector<[17]xi1> into !llvm.ptr
+// CHECK-LABEL: func @transfer_read_write_index_1d_scalable
+// CHECK: %[[CST:.*]] = arith.constant 7 : index
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref<?xindex>, vector<[17]xindex>
+// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<[17]xindex>, memref<?xindex>
 
 // -----
 
@@ -196,24 +75,8 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
   return %f: vector<17xf32>
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
-//  CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-//       CHECK: %[[c1:.*]] = arith.constant 1 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
-//
-// Compute the in-bound index (dim - offset)
-//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
-//
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-//  CHECK-SAME: vector<17x[[$IDX_TYPE]]>
-//
-// Create bound vector to compute in-bound mask:
-//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
-//       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
-//       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
-//       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg2], %[[CST]] : memref<?x?xf32>, vector<17xf32>
 
 func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
   %f7 = arith.constant 7.0: f32
@@ -222,23 +85,10 @@ func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index,
     memref<?x?xf32>, vector<[17]xf32>
   return %f: vector<[17]xf32>
 }
-// CHECK-LABEL: func @transfer_read_2d_to_1d
-//  CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
-//       CHECK: %[[c1:.*]] = arith.constant 1 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
-//
-// Compute the in-bound index (dim - offset)
-//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
-//
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
-//
-// Create bound vector to compute in-bound mask:
-//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
-//       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
-//       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
-//       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK-LABEL: func @transfer_read_2d_to_1d_scalable
+// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg2], %[[CST]] : memref<?x?xf32>, vector<[17]xf32>
+// CHECK: return %[[TRANSFER_READ]] : vector<[17]xf32>
 
 // -----
 
@@ -253,126 +103,7 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %bas
   return %f: vector<17xf32>
 }
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-//
-// 1. Check address space for GEP is correct.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
-//
-// 2. Check address space of the memref is correct.
-//       CHECK: %[[c0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
-//
-// 3. Check address space for GEP is correct.
-//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
-
-func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32, 3>, %base: index) -> vector<[17]xf32> {
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    memref<?xf32, 3>, vector<[17]xf32>
-  vector.transfer_write %f, %A[%base]
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    vector<[17]xf32>, memref<?xf32, 3>
-  return %f: vector<[17]xf32>
-}
-// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
-//
-// 1. Check address space for GEP is correct.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
-//
-// 2. Check address space of the memref is correct.
-//       CHECK: %[[c0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
-//
-// 3. Check address space for GEP is correct.
-//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
-
-// -----
-
-func.func @transfer_read_1d_inbounds(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7 {in_bounds = [true]} :
-    memref<?xf32>, vector<17xf32>
-  return %f: vector<17xf32>
-}
-// CHECK-LABEL: func @transfer_read_1d_inbounds
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-//
-// 1. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
-//
-// 2. Rewrite as a load.
-//       CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<17xf32>
-
-func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index) -> vector<[17]xf32> {
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7 {in_bounds = [true]} :
-    memref<?xf32>, vector<[17]xf32>
-  return %f: vector<[17]xf32>
-}
-// CHECK-LABEL: func @transfer_read_1d_inbounds_scalable
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
-//
-// 1. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
-//  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
-//
-// 2. Rewrite as a load.
-//       CHECK: %[[loaded:.*]] = llvm.load %[[gep]] {alignment = 4 : i64} : !llvm.ptr -> vector<[17]xf32>
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_write_1d_mask
-// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
-// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
-// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
-// CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
-// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
-// CHECK: return %[[r]]
-func.func @transfer_read_write_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5xf32> {
-  %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<5xf32>
-  vector.transfer_write %f, %A[%base], %m : vector<5xf32>, memref<?xf32>
-  return %f: vector<5xf32>
-}
-
-// CHECK-LABEL: func @transfer_read_write_1d_mask_scalable
-// CHECK-SAME: %[[mask:[a-zA-Z0-9]*]]: vector<[5]xi1>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
-// CHECK: %[[mask1:.*]] = arith.andi %[[cmpi]], %[[mask]]
-// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask1]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
-// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi_1]], %[[mask]]
-// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask2]]
-// CHECK: return %[[r]]
-func.func @transfer_read_write_1d_mask_scalable(%A : memref<?xf32>, %base : index, %m : vector<[5]xi1>) -> vector<[5]xf32> {
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7, %m : memref<?xf32>, vector<[5]xf32>
-  vector.transfer_write %f, %A[%base], %m : vector<[5]xf32>, memref<?xf32>
-  return %f: vector<[5]xf32>
-}
-
-// -----
-
-// Can't lower xfer_read/xfer_write on tensors, but this shouldn't crash
-
-// CHECK-LABEL: func @transfer_read_write_tensor
-//       CHECK:   vector.transfer_read
-//       CHECK:   vector.transfer_write
-func.func @transfer_read_write_tensor(%A: tensor<?xf32>, %base : index) -> vector<4xf32> {
-  %f7 = arith.constant 7.0: f32
-  %c0 = arith.constant 0: index
-  %f = vector.transfer_read %A[%base], %f7 : tensor<?xf32>, vector<4xf32>
-  %w = vector.transfer_write %f, %A[%c0] : vector<4xf32>, tensor<?xf32>
-  "test.some_use"(%w) : (tensor<?xf32>) -> ()
-  return %f : vector<4xf32>
-}
+// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
+// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1], %[[CST]] : memref<?xf32, 3>, vector<17xf32>
+// CHECK: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1] : vector<17xf32>, memref<?xf32, 3>
+// CHECK: return %[[TRANSFER_READ]] : 
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4b38db79bff3e1..e590e462c728a5 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -5,8 +5,9 @@ func.func @matmul_tensors(
   %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>)
     -> tensor<8x32xf32> {
 // CHECK-NOT: linalg
+// CHECK: vector.transfer_read {{.*}} : memref<8x16xf32>, vector<2xf32>
 // CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32>
-// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<8x32xf32>
   %0 = linalg.matmul  ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>)
                      outs(%arg2: tensor<8x32xf32>)
     -> tensor<8x32xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..7acfdad930b8e6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -1,23 +1,17 @@
 // RUN: mlir-opt %s --transform-interpreter -canonicalize --split-input-file | FileCheck %s
 
-// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
+// CHECK-LABEL: func @vector_transfer_ops_0d_memref
 //  CHECK-SAME:   %[[MEM:.*]]: memref<f32>
 //  CHECK-SAME:   %[[VEC:.*]]: vector<1x1x1xf32>
-func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
+//  CHECK-NEXT:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-NEXT:   %[[V:.*]] = vector.transfer_read %arg0[], %[[CST]] : memref<f32>, vector<f32>
+//  CHECK-NEXT:   vector.transfer_write %0, %arg0[] : vector<f32>, memref<f32>
+//  CHECK-NEXT:   vector.store %arg1, %arg0[] : memref<f32>, vector<1x1x1xf32>
+func.func @vector_transfer_ops_0d_memref(%M: memref<f32>, %v: vector<1x1x1xf32>) {
     %f0 = arith.constant 0.0 : f32
-
-//  CHECK-NEXT:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
-    %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
-
-//  CHECK-NEXT:   %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-//  CHECK-NEXT:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
-    vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
-
-//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
-    vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
-
+    %0 = vector.transfer_read %M[], %f0 : memref<f32>, vector<f32>
+    vector.transfer_write %0, %M[] : vector<f32>, memref<f32>
+    vector.store %v, %M[] : memref<f32>, vector<1x1x1xf32>
     return
 }
 
@@ -36,13 +30,11 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
 }
 
 // transfer_read/write are lowered to vector.load/store
-// CHECK-LABEL:   func @transfer_to_load(
-// CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32>,
-// CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
-// CHECK-NEXT:      vector.store  %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
-// CHECK-NEXT:    }
+// CHECK-LABEL: func @transfer_to_load(
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %0, %arg0[%arg1, %arg1] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<4xf32>
 
 func.func @transfer_to_load(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf32> {
   %cf0 = arith.constant 0.0 : f32
@@ -70,12 +62,10 @@ func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %idx : index, %mask :
 
 // n-D results are also supported.
 // CHECK-LABEL:   func @transfer_2D(
-// CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32>,
-// CHECK-SAME:      %[[IDX:.*]]: index) -> vector<2x4xf32> {
-// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
-// CHECK-NEXT:      vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<2x4xf32>
-// CHECK-NEXT:    }
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true, true]} : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.transfer_write  %[[TRANSFER_READ]], %arg0[%arg1, %arg1] {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<2x4xf32>
 
 func.func @transfer_2D(%mem : memref<8x8xf32>, %idx : index) -> vector<2x4xf32> {
   %cf0 = arith.constant 0.0 : f32
@@ -88,10 +78,10 @@ func.func @transfer_2D(%mem : memref<8x8xf32>, %idx : index) -> vector<2x4xf32>
 // CHECK-LABEL:   func @transfer_vector_element(
 // CHECK-SAME:      %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<2x4xf32> {
-// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
-// CHECK-NEXT:      vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<2x4xf32>
-// CHECK-NEXT:    }
+// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x4xf32>
+// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1],  %[[CST]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1, %arg1] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
+// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<2x4xf32>
 
 func.func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %idx : index) -> vector<2x4xf32> {
   %cf0 = arith.constant dense<0.0> : vector<2x4xf32>
@@ -157,10 +147,10 @@ func.func @transfer_not_inbounds(%mem : memref<8x8xf32>, %idx : index) -> vector
 // CHECK-LABEL:   func @transfer_nondefault_layout(
 // CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>, vector<4xf32>
-// CHECK-NEXT:      vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32, #{{.*}}>,  vector<4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
-// CHECK-NEXT:    }
+// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true]} : memref<8x8xf32, #map>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[TRANSFER_READ]], %arg0[%arg1, %arg1] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32, #map>
+// CHECK-NEXT: return %[[TRANSFER_READ]] : vector<4xf32>
 
 #layout = affine_map<(d0, d1) -> (d0*16 + d1)>
 func.func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %idx : index) -> vector<4xf32> {
@@ -191,11 +181,11 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf
 // CHECK-LABEL:   func @transfer_broadcasting(
 // CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
+// CHECK-NEXT:      %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] : memref<8x8xf32>, vector<f32>
+// CHECK-NEXT:      %[[BROADCAST:.*]] = vector.broadcast %[[TRANSFER_READ]] : vector<f32> to vector<4xf32>
+// CHECK-NEXT:      return %[[BROADCAST]] : vector<4xf32>
 // CHECK-NEXT:    }
-
 #broadcast_1d = affine_map<(d0, d1) -> (0)>
 func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf32> {
   %cf0 = arith.constant 0.0 : f32
@@ -208,9 +198,9 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector
 // CHECK-LABEL:   func @transfer_scalar(
 // CHECK-SAME:      %[[MEM:.*]]: memref<?x?xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<1xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<1xf32>
+// CHECK-NEXT:      %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
+// CHECK-NEXT:      return %[[TRANSFER_READ]] : vector<1xf32>
 // CHECK-NEXT:    }
 func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32> {
   %cf0 = arith.constant 0.0 : f32
@@ -222,9 +212,10 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
 // CHECK-LABEL:   func @transfer_broadcasting_2D(
 // CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4x4xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<4x4xf32>
+// CHECK-NEXT:      %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]] : memref<8x8xf32>, vector<f32>
+// CHECK-NEXT:      %[[BROADCAST:.*]] = vector.broadcast %[[TRANSFER_READ]] : vector<f32> to vector<4x4xf32>
+// CHECK-NEXT:      return %[[BROADCAST]] : vector<4x4xf32>
 // CHECK-NEXT:    }
 
 #broadcast_2d = affine_map<(d0, d1) -> (0, 0)>
@@ -240,9 +231,9 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %idx : index) -> vec
 // CHECK-LABEL:   func @transfer_broadcasting_complex(
 // CHECK-SAME:      %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
-// CHECK-NEXT:      return %[[RES]] : vector<3x2x4x5xf32>
+// CHECK-NEXT:      %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[TRANSFER_READ:.*]] = vector.transfer_read %arg0[%arg1, %arg1, %arg1, %arg1, %arg1], %[[CST]] {in_bounds = [true, true, true, true], permutation_map = #map2} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
+// CHECK-NEXT:      return %[[TRANSFER_READ]] : vector<3x2x4x5xf32>
 // CHECK-NEXT:    }
 
 #broadcast_2d_in_4d = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
@@ -322,8 +313,8 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
-// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
-// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
+// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref<?x?xf32>, vector<f32>
+// CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<8xf32>
 
   return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
          vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
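
For reference, here is a minimal hand-written sketch (not part of the patch itself) of the IR shape the unified single-element store lowering targets. The function name @store_single_element is made up for illustration; the op forms follow the test checks above, with vector.extract taking all-zero indices (and an empty index list in the 0-d case) instead of vector.extractelement:

  func.func @store_single_element(%mem : memref<f32>, %vec : vector<1x1x1xf32>) {
    vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
    return
  }

  // Expected lowering (sketch, assuming the memref-store pattern applies):
  //   %e = vector.extract %vec[0, 0, 0] : f32 from vector<1x1x1xf32>
  //   memref.store %e, %mem[] : memref<f32>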



More information about the Mlir-commits mailing list