[Mlir-commits] [mlir] [mlir][vector] Add `use64bitIndex` option for VectorToSPIRVPass (PR #97061)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 28 07:34:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Artem Kroviakov (akroviakov)

<details>
<summary>Changes</summary>

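This PR adds a `use64bitIndex` option to the Vector-to-SPIR-V conversion pass so that index types can be lowered to 64-bit integers (e.g., `vector<...xindex>` to `vector<...xi64>`). The option defaults to `false`, so the existing default lowering of index types to `i32` is unchanged.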

---
Full diff: https://github.com/llvm/llvm-project/pull/97061.diff


4 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+5) 
- (modified) mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h (+3) 
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp (+3-1) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+43-13) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..44bc9b2e0a064 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1380,6 +1380,11 @@ def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let summary = "Convert Vector dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertVectorToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">,
+  ];
 }
 
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index f8c02c54066b8..0df1afe196010 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -18,6 +18,9 @@
 namespace mlir {
 class SPIRVTypeConverter;
 
+#define GEN_PASS_DECL_CONVERTVECTORTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+
 /// Appends to a pattern list additional patterns for translating Vector Ops to
 /// SPIR-V ops.
 void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 1932de1be603b..c9f8db36b4efd 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -40,7 +40,9 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
-  SPIRVTypeConverter typeConverter(targetAttr);
+  SPIRVConversionOptions options;
+  options.use64bitIndex = this->use64bitIndex;
+  SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..dd34aa5ae0b33 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1,4 +1,6 @@
 // RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv=use-64bit-index=false -verify-diagnostics %s -o - | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv=use-64bit-index=true -verify-diagnostics %s -o - | FileCheck %s --check-prefix=INDEX64
 
 module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16], []>, #spirv.resource_limits<>> } {
 
@@ -182,12 +184,26 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 }
 
 // -----
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: @insert_index_vector
+  // INDEX32-LABEL: @insert_index_vector
+  // INDEX64-LABEL: @insert_index_vector
+  // CHECK-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+  // INDEX32-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+  // INDEX64-SAME: %[[IN_VEC:.*]]: vector<4xindex>
+  func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
+    // CHECK: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi32> 
+    // INDEX32: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi32> 
+    // INDEX64: builtin.unrealized_conversion_cast %[[IN_VEC]] : vector<4xindex> to vector<4xi64> 
 
-// CHECK-LABEL: @insert_index_vector
-//       CHECK:   spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
-func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {
-  %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
-  return %1: vector<4xindex>
+    // CHECK:   spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+    // INDEX32: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
+    // INDEX64: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i64 into vector<4xi64>
+    %1 = vector.insert %arg1, %arg0[2] : index into vector<4xindex>
+    return %1: vector<4xindex>
+  }
 }
 
 // -----
@@ -411,14 +427,28 @@ func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
 
 // -----
 
-// CHECK-LABEL:  func @shuffle_index_vector
-//  CHECK-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
-//   CHECK-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
-//   CHECK-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
-//       CHECK:    spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
-func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {
-  %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
-  return %shuffle : vector<4xindex>
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL:  func @shuffle_index_vector
+  // INDEX32-LABEL:  func @shuffle_index_vector
+  // INDEX64-LABEL:  func @shuffle_index_vector
+  //  CHECK-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+  //  INDEX32-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+  //  INDEX64-SAME:  %[[ARG0:.+]]: vector<1xindex>, %[[ARG1:.+]]: vector<1xindex>
+  //   CHECK-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   CHECK-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+  //   INDEX32-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   INDEX32-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+  //   INDEX64-DAG:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  //   INDEX64-DAG:    %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+  func.func @shuffle_index_vector(%v0 : vector<1xindex>, %v1: vector<1xindex>) -> vector<4xindex> {    
+    //  CHECK: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+    //  INDEX32: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i32, i32, i32, i32) -> vector<4xi32>
+    //  INDEX64: spirv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (i64, i64, i64, i64) -> vector<4xi64>
+    %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xindex>, vector<1xindex>
+    return %shuffle : vector<4xindex>
+  }
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/97061


More information about the Mlir-commits mailing list