[Mlir-commits] [mlir] fce33e1 - [mlir][spirv] Consider target when converting one-element vector

Lei Zhang llvmlistbot at llvm.org
Tue Oct 18 22:55:55 PDT 2022


Author: Lei Zhang
Date: 2022-10-19T05:49:32Z
New Revision: fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d

URL: https://github.com/llvm/llvm-project/commit/fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d
DIFF: https://github.com/llvm/llvm-project/commit/fce33e1140bbf5ddf2afa5c3be89433ed2a70e4d.diff

LOG: [mlir][spirv] Consider target when converting one-element vector

Vectors with just one element will be converted into scalars.
However, we cannot just return the element type and assume it
is supported in the target environment; we need to convert the
element type again, factoring in those considerations.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D136226

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 083b997aeaef5..2514cfe0301a1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
@@ -239,8 +240,9 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
                               const SPIRVConversionOptions &options,
                               VectorType type,
                               Optional<spirv::StorageClass> storageClass = {}) {
+  auto scalarType = type.getElementType().cast<spirv::ScalarType>();
   if (type.getRank() <= 1 && type.getNumElements() == 1)
-    return type.getElementType();
+    return convertScalarType(targetEnv, options, scalarType, storageClass);
 
   if (!spirv::CompositeType::isValid(type)) {
     // TODO: Vector types with more than four elements can be translated into
@@ -260,9 +262,8 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
     return type;
 
-  auto elementType = convertScalarType(
-      targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
-      storageClass);
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
   if (elementType)
     return VectorType::get(type.getShape(), elementType);
   return nullptr;

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 799d8c34aea18..4f1cd09efec30 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -207,6 +207,10 @@ func.func @float_vector(
   %arg1: vector<3xf64>
 ) { return }
 
+// CHECK-LABEL: spirv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func.func @one_element_vector(%arg0: vector<1xi8>) { return }
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list