[Mlir-commits] [mlir] 07f9a0d - [mlir][spirv] Do not introduce vector<1xT> in UnifyAliasedResource
Jakub Kuderski
llvmlistbot at llvm.org
Tue Jul 25 08:31:21 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-25T11:30:18-04:00
New Revision: 07f9a0ddd69b48bfab4fe1fc209789202fdc8209
URL: https://github.com/llvm/llvm-project/commit/07f9a0ddd69b48bfab4fe1fc209789202fdc8209
DIFF: https://github.com/llvm/llvm-project/commit/07f9a0ddd69b48bfab4fe1fc209789202fdc8209.diff
LOG: [mlir][spirv] Do not introduce vector<1xT> in UnifyAliasedResource
1-element vectors are not valid in SPIR-V and fail `Bitcast` op verification.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D156207
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index ea856c7486777c..49382856a64cb4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -506,9 +506,14 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
dstElemVecType.getElementType()) {
int64_t count =
dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
- auto castType =
- VectorType::get({count}, srcElemVecType.getElementType());
- for (auto &c : components)
+
+ // Make sure not to create 1-element vectors, which are illegal in
+ // SPIR-V.
+ Type castType = srcElemVecType.getElementType();
+ if (count > 1)
+ castType = VectorType::get({count}, castType);
+
+ for (Value &c : components)
c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
}
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index 4f3df8639364d5..ac9589ba24323e 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -508,6 +508,47 @@ spirv.module Logical GLSL450 {
// -----
+spirv.module Logical GLSL450 {
+ spirv.GlobalVariable @var01_v2f16 bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+ spirv.GlobalVariable @var01_v2f32 bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+
+ spirv.func @aliased(%index: i32) -> vector<3xf32> "None" {
+ %c0 = spirv.Constant 0 : i32
+ %v0 = spirv.Constant dense<0.0> : vector<3xf32>
+ %addr0 = spirv.mlir.addressof @var01_v2f16 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+ %ac0 = spirv.AccessChain %addr0[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>, i32, i32
+ %value0 = spirv.Load "StorageBuffer" %ac0 : vector<2xf16>
+
+ %addr1 = spirv.mlir.addressof @var01_v2f32 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+ %ac1 = spirv.AccessChain %addr1[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>, i32, i32
+ %value1 = spirv.Load "StorageBuffer" %ac1 : vector<2xf32>
+
+ %val0_as_f32 = spirv.Bitcast %value0 : vector<2xf16> to f32
+
+ %res = spirv.CompositeConstruct %val0_as_f32, %value1 : (f32, vector<2xf32>) -> vector<3xf32>
+
+ spirv.ReturnValue %res : vector<3xf32>
+ }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK: spirv.GlobalVariable @var01_v2f16 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK: spirv.func @aliased
+
+// CHECK: %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+// CHECK: %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+// CHECK: %[[LD2:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+
+// CHECK-DAG: %[[ELEM0:.+]] = spirv.Bitcast %[[LD0]] : vector<2xf16> to f32
+// CHECK-DAG: %[[ELEM1:.+]] = spirv.Bitcast %[[LD1]] : vector<2xf16> to f32
+// CHECK-DAG: %[[ELEM2:.+]] = spirv.Bitcast %[[LD2]] : vector<2xf16> to f32
+
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[ELEM0]], %{{.+}} : (f32, vector<2xf32>) -> vector<3xf32>
+// CHECK: spirv.ReturnValue %[[RES]] : vector<3xf32>
+
+// -----
+
// Make sure we do not crash on function arguments.
spirv.module Logical GLSL450 {
More information about the Mlir-commits
mailing list