[Mlir-commits] [mlir] 8b12acd - [mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. (#126294)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 7 13:31:52 PST 2025
Author: Md Abdullah Shahneous Bari
Date: 2025-02-07T15:31:47-06:00
New Revision: 8b12acd2a4a030ad0be76295b98815f93b2631d8
URL: https://github.com/llvm/llvm-project/commit/8b12acd2a4a030ad0be76295b98815f93b2631d8
DIFF: https://github.com/llvm/llvm-project/commit/8b12acd2a4a030ad0be76295b98815f93b2631d8.diff
LOG: [mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. (#126294)
Add support for single-element vector.{load|store} lowering to SPIR-V.
Since SPIR-V converts single-element vectors to scalars, lowering
vector.{load|store} to spirv.{Load|Store} needs special handling for them.
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = loadOp.getVectorType();
- auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
- Value castedAccessChain =
- rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+ // Use the converted vector type instead of original (single element vector
+ // would get converted to scalar).
+ auto spirvVectorType = typeConverter.convertType(vectorType);
+ auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+ // For single element vectors, we don't need to bitcast the access chain to
+ // the original vector type. Both is going to be the same, a pointer
+ // to a scalar.
+ Value castedAccessChain = (vectorType.getNumElements() == 1)
+ ? accessChain
+ : rewriter.create<spirv::BitcastOp>(
+ loc, vectorPtrType, accessChain);
+
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
castedAccessChain);
return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
- Value castedAccessChain =
- rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+ // For single element vectors, we don't need to bitcast the access chain to
+ // the original vector type. Both is going to be the same, a pointer
+ // to a scalar.
+ Value castedAccessChain = (vectorType.getNumElements() == 1)
+ ? accessChain
+ : rewriter.create<spirv::BitcastOp>(
+ loc, vectorPtrType, accessChain);
+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
adaptor.getValueToStore());
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
return %0: vector<4xf32>
}
+
+// CHECK-LABEL: @vector_load_single_elem
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+// CHECK: %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+// CHECK: return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+ %idx = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+ return %0: vector<1xf32>
+}
+
+
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_single_elem
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+ %idx = arith.constant 0 : index
+ vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+ return
+}
+
// CHECK-LABEL: @vector_store_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
More information about the Mlir-commits
mailing list