[Mlir-commits] [mlir] [mlir][spirv] Linearize ND vectors. (PR #80451)

Ivan Butygin llvmlistbot at llvm.org
Fri Feb 2 07:45:13 PST 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/80451

SPIR-V only supports 1-D vectors, so try to linearize N-D vectors in the type converter.

I'm not sure whether this is the right approach, or whether we should have a dedicated vector linearization pass instead.

>From 3ac69457bc32c4866e60db1200499ae7dfe7cbc3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 2 Feb 2024 16:36:10 +0100
Subject: [PATCH] [mlir][spirv] Linearize ND vectors.

SPIR-V only supports 1-D vectors, so try to linearize N-D vectors in the type converter.
---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp        | 9 +++++----
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp    | 4 ++++
 .../ArithToSPIRV/arith-to-spirv-unsupported.mlir         | 4 ++--
 mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir    | 2 ++
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f39..b55eda69f99ec 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final
     if (!srcType || srcType.getNumElements() == 1)
       return failure();
 
-    // arith.constant should only have vector or tenor types.
-    assert((isa<VectorType, RankedTensorType>(srcType)));
+    assert((isa<VectorType, RankedTensorType>(srcType) &&
+            "arith.constant should only have vector or tensor types"));
 
     Type dstType = getTypeConverter()->convertType(srcType);
     if (!dstType)
@@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final
                                             srcType.getElementType());
         dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
       } else {
-        // TODO: add support for large vectors.
-        return failure();
+        dstAttrType =
+            VectorType::get(srcType.getNumElements(), srcType.getElementType());
+        dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
       }
     }
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e8..1ce7dff8ff0e4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   if (type.getRank() <= 1 && type.getNumElements() == 1)
     return convertScalarType(targetEnv, options, scalarType, storageClass);
 
+  // Linearize ND vectors
+  if (type.getRank() > 1)
+    type = VectorType::get(type.getNumElements(), scalarType);
+
   if (!spirv::CompositeType::isValid(type)) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: not a valid composite type\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..551b036ba85e5 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -26,9 +26,9 @@ module attributes {
     #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
 } {
 
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) {
   // expected-error at +1 {{failed to legalize operation 'arith.muli'}}
-  %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+  %2 = arith.muli %arg0, %arg0: vector<3x5xi32>
   return
 }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..4a2ef1f0275c6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -556,6 +556,8 @@ func.func @constant() {
   %9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
   // CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
   %10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
+  // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
   return
 }
 



More information about the Mlir-commits mailing list