[Mlir-commits] [mlir] [mlir][vector] Add `use64bitIndex` option for VectorToSPIRVPass (PR #97061)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 28 07:34:07 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Artem Kroviakov (akroviakov)
<details>
<summary>Changes</summary>
This PR adds a `use64bitIndex` option to the Vector-to-SPIR-V conversion pass, so that index types can be lowered to 64-bit integers (e.g., `vector<...xindex>` to `vector<...xi64>`) instead of the current default lowering to i32.
---
Full diff: https://github.com/llvm/llvm-project/pull/97061.diff
4 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+5)
- (modified) mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h (+3)
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp (+3-1)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+43-13)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..44bc9b2e0a064 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1380,6 +1380,11 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
let summary = "Convert Vector dialect to SPIR-V dialect";
let constructor = "mlir::createConvertVectorToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
+ let options = [
+ Option<"use64bitIndex", "use-64bit-index",
+ "bool", /*default=*/"false",
+ "Use 64-bit integers to convert index types">,
+ ];
}
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..0df1afe196010 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -18,6 +18,9 @@
namespace mlir {
class SPIRVTypeConverter;
+#define GEN_PASS_DECL_CONVERTVECTORTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+
/// Appends to a pattern list additional patterns for translating Vector Ops to
/// SPIR-V ops.
void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 1932de1be603b..c9f8db36b4efd 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -40,7 +40,9 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
- SPIRVTypeConverter typeConverter(targetAttr);
+ SPIRVConversionOptions options;
+ options.use64bitIndex = this->use64bitIndex;
+ SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull in
// patterns for other dialects.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..dd34aa5ae0b33 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1,4 +1,6 @@
// RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv=use-64bit-index=false -verify-diagnostics %s -o - | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv=use-64bit-index=true -verify-diagnostics %s -o - | FileCheck %s --check-prefix=INDEX64
module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16], []>, #spirv.resource_limits<>> } {
@@ -182,12 +184,26 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
}
// -----
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+ // CHECK-LABEL: @insert_index_vector
+ // INDEX32-LABEL: @insert_index_vector
+ // INDEX64-LABEL: @insert_index_vector
+ // CHECK-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+ // INDEX32-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+ // INDEX64-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+ func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+ // CHECK: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi32>
+ // INDEX32: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi32>
+ // INDEX64: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi64>
-// CHECK-LABEL: @insert_index_vector
-// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
-func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
- %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
- return %1: vector<4xindex>
+ // CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+ // INDEX32: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+ // INDEX64: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i64 into vector<4xi64>
+ %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+ return %1: vector<4xindex>
+ }
}
// -----
@@ -411,14 +427,28 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
// -----
-// CHECK-LABEL: func @shuffle_index_vector
-// CHECK-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
-// CHECK-DAG: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-// CHECK-DAG: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
-// CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
-func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
- %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
- return %shuffle : vector<4xindex>
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+ // CHECK-LABEL: func @shuffle_index_vector
+ // INDEX32-LABEL: func @shuffle_index_vector
+ // INDEX64-LABEL: func @shuffle_index_vector
+ // CHECK-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+ // INDEX32-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+ // INDEX64-SAME: %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+ // CHECK-DAG: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK-DAG: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+ // INDEX32-DAG: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // INDEX32-DAG: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+ // INDEX64-DAG: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // INDEX64-DAG: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+ func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
+ // CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+ // INDEX32: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+ // INDEX64: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i64, i64, i64, i64) -> vector<4xi64>
+ %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
+ return %shuffle : vector<4xindex>
+ }
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/97061
More information about the Mlir-commits
mailing list