[Mlir-commits] [mlir] [mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. (PR #126294)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 7 11:30:34 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Md Abdullah Shahneous Bari (mshahneo)

<details>
<summary>Changes</summary>

Add support for lowering single-element vector.{load|store} ops to SPIR-V. Because the SPIR-V type converter turns single-element vectors into scalars, lowering vector.{load|store} to spirv.{load|store} needs special handling for this case.

---
Full diff: https://github.com/llvm/llvm-project/pull/126294.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+23-6) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+39) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
 
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = loadOp.getVectorType();
-    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+    // Use the converted vector type instead of original (single element vector
+    // would get converted to scalar).
+    auto spirvVectorType = typeConverter.convertType(vectorType);
+    auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both is going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
                                                castedAccessChain);
 
     return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both is going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
                                                 adaptor.getValueToStore());
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
   return %0: vector<4xf32>
 }
 
+
+// CHECK-LABEL: @vector_load_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+//       CHECK:   %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+//       CHECK:   return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return %0: vector<1xf32>
+}
+
+
 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
 //       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }
 
+// CHECK-LABEL: @vector_store_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>
+//       CHECK:  %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:  %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+//       CHECK:  %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:  %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:  %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:  %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+//       CHECK:  spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+  %idx = arith.constant 0 : index
+  vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/126294


More information about the Mlir-commits mailing list