[Mlir-commits] [mlir] [mlir][spirv] Enable using 64 bit index when lowering cf to spirv (PR #77344)

Jungwook Park llvmlistbot at llvm.org
Mon Jan 8 09:06:58 PST 2024


https://github.com/jungpark-mlir created https://github.com/llvm/llvm-project/pull/77344

None

>From 7605a97e44f69e150143b35d7bf28f548a7d49a7 Mon Sep 17 00:00:00 2001
From: jungpark-mlir <jungwook at jungwook-22.04>
Date: Mon, 8 Jan 2024 16:17:18 +0000
Subject: [PATCH 1/2] Add 'use-64bit-index' option to the 'convert-cf-to-spirv'
 pass.

---
 mlir/include/mlir/Conversion/Passes.td                 |  5 ++++-
 .../ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp      |  1 +
 .../Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir | 10 ++++++++++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6193aeb545bc6b..ee8f845b510b16 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -340,7 +340,10 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
   ];
 }
 
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index a752b82eac7c34..a718df7cbe5c54 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.use64bitIndex = this->use64bitIndex;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // TODO: We should also take care of block argument type conversion.
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 4e2f6a12c8e7ee..3ea85a7efbf537 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt --split-input-file --convert-cf-to-spirv --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-cf-to-spirv='use-64bit-index=true' --verify-diagnostics %s | FileCheck %s -check-prefix=INDEX64
 
 //===----------------------------------------------------------------------===//
 // cf.br, cf.cond_br
@@ -41,19 +42,28 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
 
 // Handle blocks whose arguments require type conversion.
 
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+// INDEX64-LABEL: func.func @main_graph
 // CHECK-LABEL: func.func @main_graph
 func.func @main_graph(%arg0: index) {
   %c3 = arith.constant 1 : index
 // CHECK:  spirv.Branch ^bb1({{.*}} : i32)
+// INDEX64:  spirv.Branch ^bb1({{.*}} : i64)
   cf.br ^bb1(%arg0 : index)
 // CHECK:      ^bb1({{.*}}: i32):       // 2 preds: ^bb0, ^bb2
+// INDEX64:      ^bb1({{.*}}: i64):       // 2 preds: ^bb0, ^bb2
 ^bb1(%0: index):  // 2 preds: ^bb0, ^bb2
   %1 = arith.cmpi slt, %0, %c3 : index
 // CHECK:        spirv.BranchConditional {{.*}}, ^bb2, ^bb3
   cf.cond_br %1, ^bb2, ^bb3
 ^bb2:  // pred: ^bb1
 // CHECK:  spirv.Branch ^bb1({{.*}} : i32)
+// INDEX64:  spirv.Branch ^bb1({{.*}} : i64)
   cf.br ^bb1(%c3 : index)
 ^bb3:  // pred: ^bb1
   return
 }
+
+}

>From 9a67841a5fad75130f3430e68edc6529276e23a3 Mon Sep 17 00:00:00 2001
From: jungpark-mlir <jungwook at jungwook-22.04>
Date: Mon, 8 Jan 2024 16:20:48 +0000
Subject: [PATCH 2/2] Resolve UnrealizedConversionCastOp in the 'func-to-spirv'
 pass where possible, reusing the pattern from 'func-to-llvm'.

---
 .../Conversion/FuncToSPIRV/FuncToSPIRV.cpp    | 30 ++++++++++++++++++-
 .../FuncToSPIRV/func-ops-to-spirv.mlir        | 15 ++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
index a5755514b3e584..b937ecc31365b3 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
@@ -82,6 +82,32 @@ class CallOpPattern final : public OpConversionPattern<func::CallOp> {
   }
 };
 
+class UnrealizedConversionCastOpPattern final
+    : public OpConversionPattern<UnrealizedConversionCastOp> {
+  using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
+  /// Folds the cast when the converted result types match the remapped inputs.
+  LogicalResult
+  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> convertedTypes;
+    if (succeeded(getTypeConverter()->convertTypes(op.getOutputs().getTypes(),
+                                                   convertedTypes)) &&
+        convertedTypes == adaptor.getInputs().getTypes()) {
+      rewriter.replaceOp(op, adaptor.getInputs());
+      return success();
+    }
+    // Otherwise, fold when the converted input types match the result types.
+    convertedTypes.clear();
+    if (succeeded(getTypeConverter()->convertTypes(
+            adaptor.getInputs().getTypes(), convertedTypes)) &&
+        convertedTypes == op.getOutputs().getTypes()) {
+      rewriter.replaceOp(op, adaptor.getInputs());
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -92,5 +118,7 @@ void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
 
-  patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
+  patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
+  // Fold unrealized_conversion_cast ops made redundant by type conversion.
+  patterns.add<UnrealizedConversionCastOpPattern>(typeConverter, context);
 }
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index a09f1697fd7249..2f0a5f9a7ff002 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -46,6 +46,21 @@ func.func @call_functions(%arg0: index) -> index {
   return %0: index
 }
 
+// The index->i32 casts should fold away once the arguments are converted.
+// CHECK-LABEL: spirv.func @resolve_unrealized_conversion_cast
+// CHECK-NOT: builtin.unrealized_conversion_cast
+func.func @resolve_unrealized_conversion_cast(%arg0: index, %arg1: index) {
+  %end = builtin.unrealized_conversion_cast %arg1 : index to i32
+  %begin = builtin.unrealized_conversion_cast %arg0 : index to i32
+  spirv.Branch ^bb1(%begin : i32)
+^bb1(%iv: i32):  // 2 preds: ^bb0, ^bb2
+  %cond = spirv.SLessThan %iv, %end : i32
+  spirv.BranchConditional %cond, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  spirv.Branch ^bb1(%end : i32)
+^bb3:  // pred: ^bb1
+  return
+}
 }
 
 // -----



More information about the Mlir-commits mailing list