[Mlir-commits] [mlir] 2c46051 - [mlir][spirv] Fix vector type mismatch in UnifyAliasedResourcePass

Lei Zhang llvmlistbot at llvm.org
Wed Feb 1 12:26:45 PST 2023


Author: Lei Zhang
Date: 2023-02-01T20:26:29Z
New Revision: 2c46051aa9d30dc1740f2183ceb45a235b994cc3

URL: https://github.com/llvm/llvm-project/commit/2c46051aa9d30dc1740f2183ceb45a235b994cc3
DIFF: https://github.com/llvm/llvm-project/commit/2c46051aa9d30dc1740f2183ceb45a235b994cc3.diff

LOG: [mlir][spirv] Fix vector type mismatch in UnifyAliasedResourcePass

For the cases where we have aliases of `vector<4xf16>` and
`vector<4xf32>`, we need to do casting before composite
construction.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D143042

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 3e5b934be677c..3acfba23ddfd1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -485,11 +485,28 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
       // bitwidth element type. For spirv.bitcast, the lower-numbered components
       // of the vector map to lower-ordered bits of the larger bitwidth element
       // type.
+
       Type vectorType = srcElemType;
       if (!srcElemType.isa<VectorType>())
         vectorType = VectorType::get({ratio}, dstElemType);
+
+      // If both the source and destination are vector types, we need to make
+      // sure the scalar type is the same for composite construction later.
+      if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
+        if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+          if (srcElemVecType.getElementType() !=
+              dstElemVecType.getElementType()) {
+            int64_t count =
+                dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
+            auto castType =
+                VectorType::get({count}, srcElemVecType.getElementType());
+            for (auto &c : components)
+              c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
+          }
+        }
       Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
           loc, vectorType, components);
+
       if (!srcElemType.isa<VectorType>())
         vectorValue =
             rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index a456016e9320e..1d532f3f13812 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -448,3 +448,34 @@ spirv.module Logical GLSL450 {
 // CHECK: spirv.GlobalVariable @var01_i16 bind(0, 1) {aliased}
 
 // CHECK: spirv.func @scalar_type_bitwidth_smaller_than_vector
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.GlobalVariable @var00_v4f32 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  spirv.GlobalVariable @var00_v4f16 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf16>, stride=8> [0])>, StorageBuffer>
+
+  spirv.func @vector_type_same_size_different_element_type(%i0: i32) -> vector<4xf32> "None" {
+    %c0 = spirv.Constant 0 : i32
+
+    %addr = spirv.mlir.addressof @var00_v4f32 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+    %ac = spirv.AccessChain %addr[%c0, %i0] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
+    %val = spirv.Load "StorageBuffer" %ac :  vector<4xf32>
+
+    spirv.ReturnValue %val : vector<4xf32>
+  }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK: spirv.GlobalVariable @var00_v4f16 bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf16>, stride=8> [0])>, StorageBuffer>
+
+// CHECK: spirv.func @vector_type_same_size_different_element_type
+
+// CHECK:   %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16>
+// CHECK:   %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16>
+// CHECK:   %[[BC0:.+]] = spirv.Bitcast %[[LD0]] : vector<4xf16> to vector<2xf32>
+// CHECK:   %[[BC1:.+]] = spirv.Bitcast %[[LD1]] : vector<4xf16> to vector<2xf32>
+// CHECK:   %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
+// CHECK:   spirv.ReturnValue %[[CC]]
+


        


More information about the Mlir-commits mailing list