[Mlir-commits] [mlir] dfee4c7 - [mlir][spirv] Fix scf.yield pattern conversion

Jakub Kuderski llvmlistbot at llvm.org
Tue Mar 14 15:49:34 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-14T18:47:34-04:00
New Revision: dfee4c7fb0fb02cf05e9a16b7fe058557f33eb15

URL: https://github.com/llvm/llvm-project/commit/dfee4c7fb0fb02cf05e9a16b7fe058557f33eb15
DIFF: https://github.com/llvm/llvm-project/commit/dfee4c7fb0fb02cf05e9a16b7fe058557f33eb15.diff

LOG: [mlir][spirv] Fix scf.yield pattern conversion

Only rewrite `scf.yield` when the parent op is supported by
scf-to-spirv.

Fixes: #61380, #61107, #61148

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D146080

Added: 
    mlir/test/Conversion/SCFToSPIRV/unsupported.mlir

Modified: 
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 2572bbcbd6bb5..81a6378fc7e49 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -291,18 +291,28 @@ struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
                   ConversionPatternRewriter &rewriter) const override {
     ValueRange operands = adaptor.getOperands();
 
-    // If the region is return values, store each value into the associated
+    Operation *parent = terminatorOp->getParentOp();
+
+    // TODO: Implement conversion for the remaining `scf` ops.
+    if (parent->getDialect()->getNamespace() ==
+            scf::SCFDialect::getDialectNamespace() &&
+        !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
+      return rewriter.notifyMatchFailure(
+          terminatorOp,
+          llvm::formatv("conversion not supported for parent op: '{0}'",
+                        parent->getName()));
+
+    // If the region return values, store each value into the associated
     // VariableOp created during lowering of the parent region.
     if (!operands.empty()) {
-      auto &allocas =
-          scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
+      auto &allocas = scfToSPIRVContext->outputVars[parent];
       if (allocas.size() != operands.size())
         return failure();
 
       auto loc = terminatorOp.getLoc();
       for (unsigned i = 0, e = operands.size(); i < e; i++)
         rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
-      if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
+      if (isa<spirv::LoopOp>(parent)) {
         // For loops we also need to update the branch jumping back to the
         // header.
         auto br = cast<spirv::BranchOp>(

diff  --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
new file mode 100644
index 0000000000000..6f388f366f744
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+// `scf.parallel` conversion is not supported yet.
+// Make sure that we do not accidentally invalidate this function by removing
+// `scf.yield`.
+// CHECK-LABEL: func.func @func
+// CHECK:         scf.parallel
+// CHECK-NEXT:      spirv.Constant
+// CHECK-NEXT:      memref.store
+// CHECK-NEXT:      scf.yield
+// CHECK:         spirv.Return
+func.func @func(%arg0: i64) {
+  %0 = arith.index_cast %arg0 : i64 to index
+  %alloc = memref.alloc() : memref<16xf32>
+  scf.parallel (%arg1) = (%0) to (%0) step (%0) {
+    %cst = arith.constant 1.000000e+00 : f32
+    memref.store %cst, %alloc[%arg1] : memref<16xf32>
+    scf.yield
+  }
+  return
+}


        


More information about the Mlir-commits mailing list