[Mlir-commits] [mlir] [mlir][Vector] Support 0-d vectors natively in VectorStoreToMemrefStoreLowering. (PR #112937)

Harrison Hao llvmlistbot at llvm.org
Tue Nov 5 04:16:52 PST 2024


https://github.com/harrisonGPU updated https://github.com/llvm/llvm-project/pull/112937

>From 5e07a46f98d641cc43f49780b9993940ef0c735f Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Tue, 5 Nov 2024 20:11:01 +0800
Subject: [PATCH] [mlir][Vector] Support 0-d vectors natively in
 VectorStoreToMemrefStoreLowering.

---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 12 ++-----
 .../VectorToLLVM/vector-to-llvm.mlir          |  5 ++-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  | 36 +++++++++----------
 .../vector-transfer-to-vector-load-store.mlir |  8 ++---
 4 files changed, 26 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..1cb3baaef82baf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering
       return rewriter.notifyMatchFailure(storeOp, "not single element vector");
 
     Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unifiy once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
+    SmallVector<int64_t> indices(vecType.getRank(), 0);
+    extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.getValueToStore(), indices);
 
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1de24fd0403ce..abbdbe02ce6c1e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
 // CHECK-LABEL: func @vector_store_op_0d
 // CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
+// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
 
 // -----
 
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index c583a48eba2704..ceebeeffcf2677 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
 func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDot %a, %a: vector<4xi8> -> i64
   return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
 func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDot %a, %a: vector<4xi16> -> i64
   return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
 func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
   %r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
   return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
 func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
   return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
 func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
   // CHECK: min version: v1.0
   // CHECK: max version: v1.6
-  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
   // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
   %r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
   return %r: i64
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..f75f8f8489efc1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
     %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-//  CHECK-NEXT:   %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-//  CHECK-NEXT:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   memref.store %[[S]], %[[MEM]][] : memref<f32>
     vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
+//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
     vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
     return



More information about the Mlir-commits mailing list