[Mlir-commits] [mlir] 07f9a0d - [mlir][spirv] Do not introduce vector<1xT> in UnifyAliasedResource

Jakub Kuderski llvmlistbot at llvm.org
Tue Jul 25 08:31:21 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-25T11:30:18-04:00
New Revision: 07f9a0ddd69b48bfab4fe1fc209789202fdc8209

URL: https://github.com/llvm/llvm-project/commit/07f9a0ddd69b48bfab4fe1fc209789202fdc8209
DIFF: https://github.com/llvm/llvm-project/commit/07f9a0ddd69b48bfab4fe1fc209789202fdc8209.diff

LOG: [mlir][spirv] Do not introduce vector<1xT> in UnifyAliasedResource

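When converting a load whose component width equals exactly one source element, the pass built a `vector<1xT>` bitcast result type. One-element vectors are not valid in SPIR-V and fail `Bitcast` op verification, so the scalar element type is now used whenever the element count is 1.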

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D156207

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index ea856c7486777c..49382856a64cb4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -506,9 +506,14 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
               dstElemVecType.getElementType()) {
             int64_t count =
                 dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
-            auto castType =
-                VectorType::get({count}, srcElemVecType.getElementType());
-            for (auto &c : components)
+
+            // Make sure not to create 1-element vectors, which are illegal in
+            // SPIR-V.
+            Type castType = srcElemVecType.getElementType();
+            if (count > 1)
+              castType = VectorType::get({count}, castType);
+
+            for (Value &c : components)
               c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
           }
         }

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index 4f3df8639364d5..ac9589ba24323e 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -508,6 +508,47 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+spirv.module Logical GLSL450 {
+  spirv.GlobalVariable @var01_v2f16 bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+  spirv.GlobalVariable @var01_v2f32 bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+
+  spirv.func @aliased(%index: i32) -> vector<3xf32> "None" {
+    %c0 = spirv.Constant 0 : i32
+    %v0 = spirv.Constant dense<0.0> : vector<3xf32>
+    %addr0 = spirv.mlir.addressof @var01_v2f16 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+    %ac0 = spirv.AccessChain %addr0[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value0 = spirv.Load "StorageBuffer" %ac0 : vector<2xf16>
+
+    %addr1 = spirv.mlir.addressof @var01_v2f32 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+    %ac1 = spirv.AccessChain %addr1[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>, i32, i32
+    %value1 = spirv.Load "StorageBuffer" %ac1 : vector<2xf32>
+
+    %val0_as_f32 = spirv.Bitcast %value0 : vector<2xf16> to f32
+
+    %res = spirv.CompositeConstruct %val0_as_f32, %value1 : (f32, vector<2xf32>) -> vector<3xf32>
+
+    spirv.ReturnValue %res : vector<3xf32>
+  }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK: spirv.GlobalVariable @var01_v2f16 bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK: spirv.func @aliased
+
+// CHECK:     %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+// CHECK:     %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+// CHECK:     %[[LD2:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<2xf16>
+
+// CHECK-DAG: %[[ELEM0:.+]] = spirv.Bitcast %[[LD0]] : vector<2xf16> to f32
+// CHECK-DAG: %[[ELEM1:.+]] = spirv.Bitcast %[[LD1]] : vector<2xf16> to f32
+// CHECK-DAG: %[[ELEM2:.+]] = spirv.Bitcast %[[LD2]] : vector<2xf16> to f32
+
+// CHECK:     %[[RES:.+]] = spirv.CompositeConstruct %[[ELEM0]], %{{.+}} : (f32, vector<2xf32>) -> vector<3xf32>
+// CHECK:     spirv.ReturnValue %[[RES]] : vector<3xf32>
+
+// -----
+
 // Make sure we do not crash on function arguments.
 
 spirv.module Logical GLSL450 {


        


More information about the Mlir-commits mailing list