[Mlir-commits] [mlir] [mlir][Vector] Support 0-d vectors natively in VectorStoreToMemrefStoreLowering. (PR #112937)
Harrison Hao
llvmlistbot at llvm.org
Tue Nov 5 04:16:52 PST 2024
https://github.com/harrisonGPU updated https://github.com/llvm/llvm-project/pull/112937
>From 5e07a46f98d641cc43f49780b9993940ef0c735f Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Tue, 5 Nov 2024 20:11:01 +0800
Subject: [PATCH] [mlir][Vector] Support 0-d vectors natively in
VectorStoreToMemrefStoreLowering.
---
.../Vector/Transforms/LowerVectorTransfer.cpp | 12 ++-----
.../VectorToLLVM/vector-to-llvm.mlir | 5 ++-
mlir/test/Dialect/SPIRV/IR/availability.mlir | 36 +++++++++----------
.../vector-transfer-to-vector-load-store.mlir | 8 ++---
4 files changed, 26 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..1cb3baaef82baf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -530,15 +530,9 @@ struct VectorStoreToMemrefStoreLowering
return rewriter.notifyMatchFailure(storeOp, "not single element vector");
Value extracted;
- if (vecType.getRank() == 0) {
- // TODO: Unifiy once ExtractOp supports 0-d vectors.
- extracted = rewriter.create<vector::ExtractElementOp>(
- storeOp.getLoc(), storeOp.getValueToStore());
- } else {
- SmallVector<int64_t> indices(vecType.getRank(), 0);
- extracted = rewriter.create<vector::ExtractOp>(
- storeOp.getLoc(), storeOp.getValueToStore(), indices);
- }
+ SmallVector<int64_t> indices(vecType.getRank(), 0);
+ extracted = rewriter.create<vector::ExtractOp>(
+ storeOp.getLoc(), storeOp.getValueToStore(), indices);
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1de24fd0403ce..abbdbe02ce6c1e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2971,9 +2971,8 @@ func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : in
// CHECK-LABEL: func @vector_store_op_0d
// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
-// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<1xf32> to f32
+// CHECK: memref.store %[[cast2]], %{{.*}}[%{{.*}}, %{{.*}}]
// -----
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index c583a48eba2704..ceebeeffcf2677 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -58,7 +58,7 @@ func.func @module_physical_storage_buffer64_vulkan() {
func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -68,7 +68,7 @@ func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SDot %a, %a: vector<4xi8> -> i64
return %r: i64
@@ -78,7 +78,7 @@ func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SDot %a, %a: vector<4xi16> -> i64
return %r: i64
@@ -88,7 +88,7 @@ func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SUDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -98,7 +98,7 @@ func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SUDot %a, %a: vector<4xi8> -> i64
return %r: i64
@@ -108,7 +108,7 @@ func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SUDot %a, %a: vector<4xi16> -> i64
return %r: i64
@@ -118,7 +118,7 @@ func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.UDot %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -128,7 +128,7 @@ func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.UDot %a, %a: vector<4xi8> -> i64
return %r: i64
@@ -138,7 +138,7 @@ func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.UDot %a, %a: vector<4xi16> -> i64
return %r: i64
@@ -148,7 +148,7 @@ func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -158,7 +158,7 @@ func.func @sdot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SDotAccSat %a, %a, %acc: vector<4xi8> -> i64
return %r: i64
@@ -168,7 +168,7 @@ func.func @sdot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SDotAccSat %a, %a, %acc: vector<4xi16> -> i64
return %r: i64
@@ -178,7 +178,7 @@ func.func @sdot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.SUDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -188,7 +188,7 @@ func.func @sudot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi8> -> i64
return %r: i64
@@ -198,7 +198,7 @@ func.func @sudot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.SUDotAccSat %a, %a, %acc: vector<4xi16> -> i64
return %r: i64
@@ -208,7 +208,7 @@ func.func @sudot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
%r = spirv.UDotAccSat %a, %a, %a, <PackedVectorFormat4x8Bit>: i32 -> i32
return %r: i32
@@ -218,7 +218,7 @@ func.func @udot_acc_sat_scalar_i32_i32(%a: i32) -> i32 {
func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
%r = spirv.UDotAccSat %a, %a, %acc: vector<4xi8> -> i64
return %r: i64
@@ -228,7 +228,7 @@ func.func @udot_acc_sat_vector_4xi8_i64(%a: vector<4xi8>, %acc: i64) -> i64 {
func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
- // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: extensions: [ [SPV_KHR_16bit_storage] ]
// CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
%r = spirv.UDotAccSat %a, %a, %acc: vector<4xi16> -> i64
return %r: i64
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..f75f8f8489efc1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -7,15 +7,13 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
%f0 = arith.constant 0.0 : f32
// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
%0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
-// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
+// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref<f32>
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
-// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
+// CHECK-NEXT: %[[V:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
+// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
return
More information about the Mlir-commits mailing list