[Mlir-commits] [mlir] dfee4c7 - [mlir][spirv] Fix scf.yield pattern conversion
Jakub Kuderski
llvmlistbot at llvm.org
Tue Mar 14 15:49:34 PDT 2023
Author: Jakub Kuderski
Date: 2023-03-14T18:47:34-04:00
New Revision: dfee4c7fb0fb02cf05e9a16b7fe058557f33eb15
URL: https://github.com/llvm/llvm-project/commit/dfee4c7fb0fb02cf05e9a16b7fe058557f33eb15
DIFF: https://github.com/llvm/llvm-project/commit/dfee4c7fb0fb02cf05e9a16b7fe058557f33eb15.diff
LOG: [mlir][spirv] Fix scf.yield pattern conversion
Only rewrite `scf.yield` when the parent op is supported by
scf-to-spirv.
Fixes: #61380, #61107, #61148
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D146080
Added:
mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
Modified:
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 2572bbcbd6bb5..81a6378fc7e49 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -291,18 +291,28 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
ConversionPatternRewriter &rewriter) const override {
ValueRange operands = adaptor.getOperands();
- // If the region is return values, store each value into the associated
+ Operation *parent = terminatorOp->getParentOp();
+
+ // TODO: Implement conversion for the remaining `scf` ops.
+ if (parent->getDialect()->getNamespace() ==
+ scf::SCFDialect::getDialectNamespace() &&
+ !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
+ return rewriter.notifyMatchFailure(
+ terminatorOp,
+ llvm::formatv("conversion not supported for parent op: '{0}'",
+ parent->getName()));
+
+ // If the region returns values, store each value into the associated
// VariableOp created during lowering of the parent region.
if (!operands.empty()) {
- auto &allocas =
- scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
+ auto &allocas = scfToSPIRVContext->outputVars[parent];
if (allocas.size() != operands.size())
return failure();
auto loc = terminatorOp.getLoc();
for (unsigned i = 0, e = operands.size(); i < e; i++)
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
- if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
+ if (isa<spirv::LoopOp>(parent)) {
// For loops we also need to update the branch jumping back to the
// header.
auto br = cast<spirv::BranchOp>(
diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
new file mode 100644
index 0000000000000..6f388f366f744
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+// `scf.parallel` conversion is not supported yet.
+// Make sure that we do not accidentally invalidate this function by removing
+// `scf.yield`.
+// CHECK-LABEL: func.func @func
+// CHECK: scf.parallel
+// CHECK-NEXT: spirv.Constant
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: scf.yield
+// CHECK: spirv.Return
+func.func @func(%arg0: i64) {
+ %0 = arith.index_cast %arg0 : i64 to index
+ %alloc = memref.alloc() : memref<16xf32>
+ scf.parallel (%arg1) = (%0) to (%0) step (%0) {
+ %cst = arith.constant 1.000000e+00 : f32
+ memref.store %cst, %alloc[%arg1] : memref<16xf32>
+ scf.yield
+ }
+ return
+}
More information about the Mlir-commits
mailing list