[Mlir-commits] [mlir] [mlir][spirv] Enable using 64 bit index when lowering cf to spirv (PR #77344)
Jungwook Park
llvmlistbot at llvm.org
Mon Jan 8 09:06:58 PST 2024
https://github.com/jungpark-mlir created https://github.com/llvm/llvm-project/pull/77344
None
>From 7605a97e44f69e150143b35d7bf28f548a7d49a7 Mon Sep 17 00:00:00 2001
From: jungpark-mlir <jungwook at jungwook-22.04>
Date: Mon, 8 Jan 2024 16:17:18 +0000
Subject: [PATCH 1/2] Add 'use-64bit-index' option to the 'convert-cf-to-spirv'
pass.
---
mlir/include/mlir/Conversion/Passes.td | 5 ++++-
.../ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 1 +
.../Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir | 10 ++++++++++
3 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6193aeb545bc6b..ee8f845b510b16 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -340,7 +340,10 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"use64bitIndex", "use-64bit-index",
+ "bool", /*default=*/"false",
+ "Use 64-bit integers to convert index types">
];
}
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index a752b82eac7c34..a718df7cbe5c54 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
index 4e2f6a12c8e7ee..3ea85a7efbf537 100644
--- a/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/ControlFlowToSPIRV/cf-ops-to-spirv.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file --convert-cf-to-spirv --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --convert-cf-to-spirv='use-64bit-index=true' --verify-diagnostics %s | FileCheck %s -check-prefix=INDEX64
//===----------------------------------------------------------------------===//
// cf.br, cf.cond_br
@@ -41,19 +42,30 @@ func.func @simple_loop(%begin: i32, %end: i32, %step: i32) {
// Handle blocks whose arguments require type conversion.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int64], []>, #spirv.resource_limits<>>
+} {
+
// CHECK-LABEL: func.func @main_graph
+// INDEX64-LABEL: func.func @main_graph
func.func @main_graph(%arg0: index) {
%c3 = arith.constant 1 : index
// CHECK: spirv.Branch ^bb1({{.*}} : i32)
+// INDEX64: spirv.Branch ^bb1({{.*}} : i64)
cf.br ^bb1(%arg0 : index)
// CHECK: ^bb1({{.*}}: i32): // 2 preds: ^bb0, ^bb2
+// INDEX64: ^bb1({{.*}}: i64): // 2 preds: ^bb0, ^bb2
^bb1(%0: index): // 2 preds: ^bb0, ^bb2
%1 = arith.cmpi slt, %0, %c3 : index
// CHECK: spirv.BranchConditional {{.*}}, ^bb2, ^bb3
+// INDEX64: spirv.BranchConditional {{.*}}, ^bb2, ^bb3
cf.cond_br %1, ^bb2, ^bb3
^bb2: // pred: ^bb1
// CHECK: spirv.Branch ^bb1({{.*}} : i32)
+// INDEX64: spirv.Branch ^bb1({{.*}} : i64)
cf.br ^bb1(%c3 : index)
^bb3: // pred: ^bb1
return
}
+
+}
>From 9a67841a5fad75130f3430e68edc6529276e23a3 Mon Sep 17 00:00:00 2001
From: jungpark-mlir <jungwook at jungwook-22.04>
Date: Mon, 8 Jan 2024 16:20:48 +0000
Subject: [PATCH 2/2] Resolve UnrealizedConversionCastOp in the 'func-to-spirv'
pass when possible. The pattern is borrowed from 'func-to-llvm'.
---
.../Conversion/FuncToSPIRV/FuncToSPIRV.cpp | 30 ++++++++++++++++++-
.../FuncToSPIRV/func-ops-to-spirv.mlir | 15 ++++++++++
2 files changed, 44 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
index a5755514b3e584..b937ecc31365b3 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp
@@ -82,6 +82,37 @@ class CallOpPattern final : public OpConversionPattern<func::CallOp> {
}
};
+/// Removes a `builtin.unrealized_conversion_cast` that is a no-op under the
+/// type converter by forwarding its remapped inputs to the cast's users.
+class UnrealizedConversionCastOpPattern final
+    : public OpConversionPattern<UnrealizedConversionCastOp> {
+public:
+  using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type> convertedTypes;
+    // Converting the result types yields exactly the remapped input types.
+    if (succeeded(getTypeConverter()->convertTypes(op.getOutputs().getTypes(),
+                                                   convertedTypes)) &&
+        convertedTypes == adaptor.getInputs().getTypes()) {
+      rewriter.replaceOp(op, adaptor.getInputs());
+      return success();
+    }
+
+    // Otherwise, converting the remapped input types yields the result types.
+    convertedTypes.clear();
+    if (succeeded(getTypeConverter()->convertTypes(
+            adaptor.getInputs().getTypes(), convertedTypes)) &&
+        convertedTypes == op.getOutputs().getTypes()) {
+      rewriter.replaceOp(op, adaptor.getInputs());
+      return success();
+    }
+    return failure();
+  }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -92,5 +118,7 @@ void mlir::populateFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
+  // Fold casts that become no-ops once their types have been converted.
+  patterns.add<UnrealizedConversionCastOpPattern>(typeConverter, context);
}
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index a09f1697fd7249..2f0a5f9a7ff002 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -46,6 +46,22 @@ func.func @call_functions(%arg0: index) -> index {
return %0: index
}
+// CHECK-LABEL: spirv.func @resolve_unrealized_conversion_cast
+func.func @resolve_unrealized_conversion_cast(%arg0: index, %arg1: index) {
+  %0 = builtin.unrealized_conversion_cast %arg1 : index to i32
+  %1 = builtin.unrealized_conversion_cast %arg0 : index to i32
+// CHECK-NOT: builtin.unrealized_conversion_cast
+// CHECK: spirv.Branch ^bb1(%{{.*}} : i32)
+  spirv.Branch ^bb1(%1 : i32)
+^bb1(%2: i32): // 2 preds: ^bb0, ^bb2
+  %3 = spirv.SLessThan %2, %0 : i32
+  spirv.BranchConditional %3, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+  spirv.Branch ^bb1(%0 : i32)
+^bb3: // pred: ^bb1
+  return
+}
+
}
// -----
More information about the Mlir-commits
mailing list