[Mlir-commits] [mlir] [mlir][spirv] Linearize ND vectors. (PR #80451)
Ivan Butygin
llvmlistbot at llvm.org
Fri Feb 2 07:45:13 PST 2024
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/80451
SPIR-V only supports 1D vectors, so try to linearize ND vectors in the type converter.
Not sure if this is the right approach, or whether we should have a dedicated vector linearization pass instead.
>From 3ac69457bc32c4866e60db1200499ae7dfe7cbc3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 2 Feb 2024 16:36:10 +0100
Subject: [PATCH] [mlir][spirv] Linearize ND vectors.
SPIR-V only supports 1D vectors, try to linearize vector in type converter.
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 +++++----
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 4 ++++
.../ArithToSPIRV/arith-to-spirv-unsupported.mlir | 4 ++--
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 2 ++
4 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f39..b55eda69f99ec 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final
if (!srcType || srcType.getNumElements() == 1)
return failure();
- // arith.constant should only have vector or tenor types.
- assert((isa<VectorType, RankedTensorType>(srcType)));
+ assert((isa<VectorType, RankedTensorType>(srcType) &&
+ "arith.constant should only have vector or tensor types"));
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
@@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final
srcType.getElementType());
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
} else {
- // TODO: add support for large vectors.
- return failure();
+ dstAttrType =
+ VectorType::get(srcType.getNumElements(), srcType.getElementType());
+ dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
}
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e8..1ce7dff8ff0e4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
if (type.getRank() <= 1 && type.getNumElements() == 1)
return convertScalarType(targetEnv, options, scalarType, storageClass);
+ // Linearize ND vectors
+ if (type.getRank() > 1)
+ type = VectorType::get(type.getNumElements(), scalarType);
+
if (!spirv::CompositeType::isValid(type)) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: not a valid composite type\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..551b036ba85e5 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -26,9 +26,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) {
// expected-error @+1 {{failed to legalize operation 'arith.muli'}}
- %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+ %2 = arith.muli %arg0, %arg0: vector<3x5xi32>
return
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..4a2ef1f0275c6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -556,6 +556,8 @@ func.func @constant() {
%9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
// CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
%10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
+ // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
return
}
More information about the Mlir-commits
mailing list