[Mlir-commits] [mlir] fce33e1 - [mlir][spirv] Consider target when converting one-element vector
Lei Zhang
llvmlistbot at llvm.org
Tue Oct 18 22:55:55 PDT 2022
Author: Lei Zhang
Date: 2022-10-19T05:49:32Z
New Revision: fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d
URL: https://github.com/llvm/llvm-project/commit/fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d
DIFF: https://github.com/llvm/llvm-project/commit/fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d.diff
LOG: [mlir][spirv] Consider target when converting one-element vector
Vectors with just one element will be converted into scalars.
However, we cannot just return the element type and assume it
is supported in the target environment; we need to convert the
element type again, factoring in those considerations.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D136226
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 083b997aeaef5..2514cfe0301a1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
@@ -239,8 +240,9 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options,
VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
+ auto scalarType = type.getElementType().cast<spirv::ScalarType>();
if (type.getRank() <= 1 && type.getNumElements() == 1)
- return type.getElementType();
+ return convertScalarType(targetEnv, options, scalarType, storageClass);
if (!spirv::CompositeType::isValid(type)) {
// TODO: Vector types with more than four elements can be translated into
@@ -260,9 +262,8 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
return type;
- auto elementType = convertScalarType(
- targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
- storageClass);
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
if (elementType)
return VectorType::get(type.getShape(), elementType);
return nullptr;
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 799d8c34aea18..4f1cd09efec30 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -207,6 +207,10 @@ func.func @float_vector(
%arg1: vector<3xf64>
) { return }
+// CHECK-LABEL: spirv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func.func @one_element_vector(%arg0: vector<1xi8>) { return }
+
} // end module
// -----
More information about the Mlir-commits
mailing list