[Mlir-commits] [mlir] [mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. (PR #126294)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Fri Feb 7 11:30:00 PST 2025


https://github.com/mshahneo created https://github.com/llvm/llvm-project/pull/126294

Add support for lowering single-element vector.{load|store} to SPIR-V. Because the SPIR-V type conversion turns single-element vectors into scalars, lowering vector.{load|store} to spirv.{Load|Store} needs special handling for this case.

>From fc9329fb72437a7a145c6125f6e04cf22ee58969 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Fri, 7 Feb 2025 19:13:53 +0000
Subject: [PATCH] [mlir][vector][spirv] Handle 1-element vector.{load|store}
 lowering.

Add support for lowering single-element vector.{load|store} to SPIR-V.
Because SPIR-V type conversion turns single-element vectors into scalars,
lowering vector.{load|store} to spirv.{Load|Store} needs special handling.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 29 +++++++++++---
 .../VectorToSPIRV/vector-to-spirv.mlir        | 39 +++++++++++++++++++
 2 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
 
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = loadOp.getVectorType();
-    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+    // Use the converted vector type instead of original (single element vector
+    // would get converted to scalar).
+    auto spirvVectorType = typeConverter.convertType(vectorType);
+    auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both are going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
                                                castedAccessChain);
 
     return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both are going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
                                                 adaptor.getValueToStore());
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
   return %0: vector<4xf32>
 }
 
+
+// CHECK-LABEL: @vector_load_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+//       CHECK:   %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+//       CHECK:   return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return %0: vector<1xf32>
+}
+
+
 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
 //       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }
 
+// CHECK-LABEL: @vector_store_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>
+//       CHECK:  %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:  %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+//       CHECK:  %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:  %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:  %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:  %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+//       CHECK:  spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+  %idx = arith.constant 0 : index
+  vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>



More information about the Mlir-commits mailing list