[Mlir-commits] [mlir] 2c46051 - [mlir][spirv] Fix vector type mismatch in UnifyAliasedResourcePass
Lei Zhang
llvmlistbot at llvm.org
Wed Feb 1 12:26:45 PST 2023
Author: Lei Zhang
Date: 2023-02-01T20:26:29Z
New Revision: 2c46051aa9d30dc1740f2183ceb45a235b994cc3
URL: https://github.com/llvm/llvm-project/commit/2c46051aa9d30dc1740f2183ceb45a235b994cc3
DIFF: https://github.com/llvm/llvm-project/commit/2c46051aa9d30dc1740f2183ceb45a235b994cc3.diff
LOG: [mlir][spirv] Fix vector type mismatch in UnifyAliasedResourcePass
For the cases where we have aliases of `vector<4xf16>` and
`vector<4xf32>`, we need to do casting before composite
construction.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D143042
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 3e5b934be677c..3acfba23ddfd1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -485,11 +485,28 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
// bitwidth element type. For spirv.bitcast, the lower-numbered components
// of the vector map to lower-ordered bits of the larger bitwidth element
// type.
+
Type vectorType = srcElemType;
if (!srcElemType.isa<VectorType>())
vectorType = VectorType::get({ratio}, dstElemType);
+
+ // If both the source and destination are vector types, we need to make
+ // sure the scalar type is the same for composite construction later.
+ if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
+ if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+ if (srcElemVecType.getElementType() !=
+ dstElemVecType.getElementType()) {
+ int64_t count =
+ dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
+ auto castType =
+ VectorType::get({count}, srcElemVecType.getElementType());
+ for (auto &c : components)
+ c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
+ }
+ }
Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
loc, vectorType, components);
+
if (!srcElemType.isa<VectorType>())
vectorValue =
rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index a456016e9320e..1d532f3f13812 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -448,3 +448,34 @@ spirv.module Logical GLSL450 {
// CHECK: spirv.GlobalVariable @var01_i16 bind(0, 1) {aliased}
// CHECK: spirv.func @scalar_type_bitwidth_smaller_than_vector
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.GlobalVariable @var00_v4f32 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ spirv.GlobalVariable @var00_v4f16 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf16>, stride=8> [0])>, StorageBuffer>
+
+ spirv.func @vector_type_same_size_different_element_type(%i0: i32) -> vector<4xf32> "None" {
+ %c0 = spirv.Constant 0 : i32
+
+ %addr = spirv.mlir.addressof @var00_v4f32 : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ %ac = spirv.AccessChain %addr[%c0, %i0] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
+ %val = spirv.Load "StorageBuffer" %ac : vector<4xf32>
+
+ spirv.ReturnValue %val : vector<4xf32>
+ }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK: spirv.GlobalVariable @var00_v4f16 bind(0, 0) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf16>, stride=8> [0])>, StorageBuffer>
+
+// CHECK: spirv.func @vector_type_same_size_different_element_type
+
+// CHECK: %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16>
+// CHECK: %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16>
+// CHECK: %[[BC0:.+]] = spirv.Bitcast %[[LD0]] : vector<4xf16> to vector<2xf32>
+// CHECK: %[[BC1:.+]] = spirv.Bitcast %[[LD1]] : vector<4xf16> to vector<2xf32>
+// CHECK: %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
+// CHECK: spirv.ReturnValue %[[CC]]
+
More information about the Mlir-commits
mailing list