[Mlir-commits] [mlir] a16fbff - [mlir][spirv] Create a pass for testing SCFToSPIRV patterns

Lei Zhang llvmlistbot at llvm.org
Wed Dec 23 11:32:08 PST 2020


Author: Lei Zhang
Date: 2020-12-23T14:31:55-05:00
New Revision: a16fbff17d329c3f2cc1e49d501f61b3996e9b8a

URL: https://github.com/llvm/llvm-project/commit/a16fbff17d329c3f2cc1e49d501f61b3996e9b8a
DIFF: https://github.com/llvm/llvm-project/commit/a16fbff17d329c3f2cc1e49d501f61b3996e9b8a.diff

LOG: [mlir][spirv] Create a pass for testing SCFToSPIRV patterns

Previously all SCF to SPIR-V conversion patterns were tested via
the -convert-gpu-to-spirv pass, which obscured the intended
structure. This commit fixes that by adding a dedicated
-convert-scf-to-spirv pass for testing them.

Reviewed By: ThomasRaoux, hanchung

Differential Revision: https://reviews.llvm.org/D93488

Added: 
    mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
    mlir/test/Conversion/GPUToSPIRV/entry-point.mlir
    mlir/test/Conversion/SCFToSPIRV/for.mlir
    mlir/test/Conversion/SCFToSPIRV/if.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
    mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt

Removed: 
    mlir/test/Conversion/GPUToSPIRV/if.mlir
    mlir/test/Conversion/GPUToSPIRV/loop.mlir
    mlir/test/Conversion/GPUToSPIRV/test_spirv_entry_point.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index cc4f59c12496..21b35804ab36 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b364700bd849..2dc438534a44 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -230,6 +230,23 @@ def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> {
   let dependentDialects = ["omp::OpenMPDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// SCFToSPIRV
+//===----------------------------------------------------------------------===//
+
+def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
+  let summary = "Convert SCF dialect to SPIR-V dialect.";
+  let description = [{
+    This pass converts SCF ops into SPIR-V structured control flow ops.
+    SPIR-V structured control flow ops do not support yielding values.
+    So for SCF ops yielding values, SPIR-V variables are created for
+    holding the values and load/store operations are emitted for updating
+    them.
+  }];
+  let constructor = "mlir::createConvertSCFToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // SCFToStandard
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
new file mode 100644
index 000000000000..94705d2200c4
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
@@ -0,0 +1,21 @@
+//===- SCFToSPIRVPass.h - SCF to SPIR-V Conversion Pass ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRVPASS_H
+#define MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+/// Creates a module pass that converts SCF ops into SPIR-V ops.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToSPIRVPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRVPASS_H

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 973aa3d79bd8..2c2a47fcc5ed 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -12,12 +12,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
+
 #include "../PassDetail.h"
 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
-#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -58,10 +57,8 @@ void GPUToSPIRVPass::runOnOperation() {
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
-  ScfToSPIRVContext scfContext;
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
-  populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
   if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
index d77a759d5f38..d9400716a0b2 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRSCFToSPIRV
   SCFToSPIRV.cpp
+  SCFToSPIRVPass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToSPIRV
@@ -8,12 +9,10 @@ add_mlir_conversion_library(MLIRSCFToSPIRV
   MLIRConversionPassIncGen
 
   LINK_LIBS PUBLIC
-  MLIRAffine
-  MLIRAffineToStandard
   MLIRSPIRV
   MLIRSPIRVConversion
+  MLIRStandardToSPIRVTransforms
   MLIRIR
-  MLIRLinalg
   MLIRPass
   MLIRStandard
   MLIRSupport

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
new file mode 100644
index 000000000000..5c74cc3397be
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -0,0 +1,61 @@
+//===- SCFToSPIRVPass.cpp - SCF to SPIR-V Dialect Conversion Pass ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert SCF dialect into SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Module pass lowering SCF ops to SPIR-V. Besides the SCF patterns it also
+/// applies the Standard and builtin func to SPIR-V patterns.
+struct SCFToSPIRVPass : public SCFToSPIRVBase<SCFToSPIRVPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void SCFToSPIRVPass::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  MLIRContext *ctx = &getContext();
+
+  // Both legality and type mapping come from the module's SPIR-V target
+  // environment (presumably a default one if the module has none).
+  auto envAttr = spirv::lookupTargetEnvOrDefault(moduleOp);
+  std::unique_ptr<ConversionTarget> conversionTarget =
+      spirv::SPIRVConversionTarget::get(envAttr);
+  SPIRVTypeConverter converter(envAttr);
+
+  // State shared by the SCF patterns for this run only; presumably it records
+  // the SPIR-V variables that carry SCF op results.
+  ScfToSPIRVContext scfState;
+  OwningRewritePatternList patternList;
+  populateSCFToSPIRVPatterns(ctx, converter, scfState, patternList);
+  populateStandardToSPIRVPatterns(ctx, converter, patternList);
+  populateBuiltinFuncToSPIRVPatterns(ctx, converter, patternList);
+
+  // Partial conversion: ops the target does not explicitly mark illegal may
+  // remain unconverted without failing the pass.
+  if (failed(applyPartialConversion(moduleOp, *conversionTarget,
+                                    std::move(patternList))))
+    signalPassFailure();
+}
+
+/// Factory declared in SCFToSPIRVPass.h; each call yields a fresh instance.
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToSPIRVPass() {
+  return std::make_unique<SCFToSPIRVPass>();
+}

diff  --git a/mlir/test/Conversion/GPUToSPIRV/test_spirv_entry_point.mlir b/mlir/test/Conversion/GPUToSPIRV/entry-point.mlir
similarity index 100%
rename from mlir/test/Conversion/GPUToSPIRV/test_spirv_entry_point.mlir
rename to mlir/test/Conversion/GPUToSPIRV/entry-point.mlir

diff  --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir
deleted file mode 100644
index e5e4a4fc7e97..000000000000
--- a/mlir/test/Conversion/GPUToSPIRV/if.mlir
+++ /dev/null
@@ -1,167 +0,0 @@
-// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
-
-module attributes {
-  gpu.container_module,
-  spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
-} {
-  func @main(%arg0 : memref<10xf32>, %arg1 : i1) {
-    %c0 = constant 1 : index
-    gpu.launch_func @kernels::@kernel_simple_selection
-        blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0)
-        args(%arg0 : memref<10xf32>, %arg1 : i1)
-    return
-  }
-
-  gpu.module @kernels {
-    // CHECK-LABEL: @kernel_simple_selection
-    gpu.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) kernel
-    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
-      %value = constant 0.0 : f32
-      %i = constant 0 : index
-
-      // CHECK:       spv.selection {
-      // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]]
-      // CHECK-NEXT:  [[TRUE]]:
-      // CHECK:         spv.Branch [[MERGE]]
-      // CHECK-NEXT:  [[MERGE]]:
-      // CHECK-NEXT:    spv.mlir.merge
-      // CHECK-NEXT:  }
-      // CHECK-NEXT:  spv.Return
-
-      scf.if %arg3 {
-        store %value, %arg2[%i] : memref<10xf32>
-      }
-      gpu.return
-    }
-
-    // CHECK-LABEL: @kernel_nested_selection
-    gpu.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) kernel
-    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
-      %i = constant 0 : index
-      %j = constant 9 : index
-
-      // CHECK:       spv.selection {
-      // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]]
-      // CHECK-NEXT:  [[TRUE_TOP]]:
-      // CHECK-NEXT:    spv.selection {
-      // CHECK-NEXT:      spv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]], [[FALSE_NESTED_TRUE_PATH:\^.*]]
-      // CHECK-NEXT:    [[TRUE_NESTED_TRUE_PATH]]:
-      // CHECK:           spv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]]
-      // CHECK-NEXT:    [[FALSE_NESTED_TRUE_PATH]]:
-      // CHECK:           spv.Branch [[MERGE_NESTED_TRUE_PATH]]
-      // CHECK-NEXT:    [[MERGE_NESTED_TRUE_PATH]]:
-      // CHECK-NEXT:      spv.mlir.merge
-      // CHECK-NEXT:    }
-      // CHECK-NEXT:    spv.Branch [[MERGE_TOP:\^.*]]
-      // CHECK-NEXT:  [[FALSE_TOP]]:
-      // CHECK-NEXT:    spv.selection {
-      // CHECK-NEXT:      spv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]]
-      // CHECK-NEXT:    [[TRUE_NESTED_FALSE_PATH]]:
-      // CHECK:           spv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]]
-      // CHECK-NEXT:    [[FALSE_NESTED_FALSE_PATH]]:
-      // CHECK:           spv.Branch [[MERGE_NESTED_FALSE_PATH]]
-      // CHECK:         [[MERGE_NESTED_FALSE_PATH]]:
-      // CHECK-NEXT:      spv.mlir.merge
-      // CHECK-NEXT:    }
-      // CHECK-NEXT:    spv.Branch [[MERGE_TOP]]
-      // CHECK-NEXT:  [[MERGE_TOP]]:
-      // CHECK-NEXT:    spv.mlir.merge
-      // CHECK-NEXT:  }
-      // CHECK-NEXT:  spv.Return
-
-      scf.if %arg5 {
-        scf.if %arg6 {
-          %value = load %arg3[%i] : memref<10xf32>
-          store %value, %arg4[%i] : memref<10xf32>
-        } else {
-          %value = load %arg4[%i] : memref<10xf32>
-          store %value, %arg3[%i] : memref<10xf32>
-        }
-      } else {
-        scf.if %arg6 {
-          %value = load %arg3[%j] : memref<10xf32>
-          store %value, %arg4[%j] : memref<10xf32>
-        } else {
-          %value = load %arg4[%j] : memref<10xf32>
-          store %value, %arg3[%j] : memref<10xf32>
-        }
-      }
-      gpu.return
-    }
-    // CHECK-LABEL: @simple_if_yield
-    gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel
-    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
-      // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
-      // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
-      // CHECK:       spv.selection {
-      // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
-      // CHECK-NEXT:  [[TRUE]]:
-      // CHECK:         %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
-      // CHECK:         %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
-      // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
-      // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
-      // CHECK:         spv.Branch ^[[MERGE:.*]]
-      // CHECK-NEXT:  [[FALSE]]:
-      // CHECK:         %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
-      // CHECK:         %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
-      // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
-      // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
-      // CHECK:         spv.Branch ^[[MERGE]]
-      // CHECK-NEXT:  ^[[MERGE]]:
-      // CHECK:         spv.mlir.merge
-      // CHECK-NEXT:  }
-      // CHECK-DAG:   %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
-      // CHECK-DAG:   %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
-      // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
-      // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
-      // CHECK:       spv.Return
-      %0:2 = scf.if %arg3 -> (f32, f32) {
-        %c0 = constant 0.0 : f32
-        %c1 = constant 1.0 : f32
-        scf.yield %c0, %c1 : f32, f32
-      } else {
-        %c0 = constant 2.0 : f32
-        %c1 = constant 3.0 : f32
-        scf.yield %c1, %c0 : f32, f32
-      }
-      %i = constant 0 : index
-      %j = constant 1 : index
-      store %0#0, %arg2[%i] : memref<10xf32>
-      store %0#1, %arg2[%j] : memref<10xf32>
-      gpu.return
-    }
-    // TODO: The transformation should only be legal if
-    // VariablePointer capability is supported. This test is still useful to
-    // make sure we can handle scf op result with type change.
-    // CHECK-LABEL: @simple_if_yield_type_change
-    // CHECK:       %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
-    // CHECK:       spv.selection {
-    // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
-    // CHECK-NEXT:  [[TRUE]]:
-    // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK:         spv.Branch ^[[MERGE:.*]]
-    // CHECK-NEXT:  [[FALSE]]:
-    // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK:         spv.Branch ^[[MERGE]]
-    // CHECK-NEXT:  ^[[MERGE]]:
-    // CHECK:         spv.mlir.merge
-    // CHECK-NEXT:  }
-    // CHECK:       %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK:       %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
-    // CHECK:       spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
-    // CHECK:       spv.Return
-    gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel
-    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
-      %i = constant 0 : index
-      %value = constant 0.0 : f32
-      %0 = scf.if %arg4 -> (memref<10xf32>) {
-        scf.yield %arg2 : memref<10xf32>
-      } else {
-        scf.yield %arg3 : memref<10xf32>
-      }
-      store %value, %0[%i] : memref<10xf32>
-      gpu.return
-    }
-  }
-}

diff  --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
deleted file mode 100644
index 812007c1fd64..000000000000
--- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ /dev/null
@@ -1,98 +0,0 @@
-// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
-
-module attributes {
-  gpu.container_module,
-  spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
-} {
-  func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) {
-    %c0 = constant 1 : index
-    gpu.launch_func @kernels::@loop_kernel
-        blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0)
-        args(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>)
-    return
-  }
-
-  gpu.module @kernels {
-    gpu.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
-    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
-      // CHECK: %[[LB:.*]] = spv.constant 4 : i32
-      %lb = constant 4 : index
-      // CHECK: %[[UB:.*]] = spv.constant 42 : i32
-      %ub = constant 42 : index
-      // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
-      %step = constant 2 : index
-      // CHECK:      spv.loop {
-      // CHECK-NEXT:   spv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
-      // CHECK:      ^[[HEADER]](%[[INDVAR:.*]]: i32):
-      // CHECK:        %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
-      // CHECK:        spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
-      // CHECK:      ^[[BODY]]:
-      // CHECK:        %[[ZERO1:.*]] = spv.constant 0 : i32
-      // CHECK:        %[[OFFSET1:.*]] = spv.constant 0 : i32
-      // CHECK:        %[[STRIDE1:.*]] = spv.constant 1 : i32
-      // CHECK:        %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
-      // CHECK:        %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
-      // CHECK:        spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
-      // CHECK:        %[[ZERO2:.*]] = spv.constant 0 : i32
-      // CHECK:        %[[OFFSET2:.*]] = spv.constant 0 : i32
-      // CHECK:        %[[STRIDE2:.*]] = spv.constant 1 : i32
-      // CHECK:        %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
-      // CHECK:        %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
-      // CHECK:        spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
-      // CHECK:        %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
-      // CHECK:        spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
-      // CHECK:      ^[[MERGE]]
-      // CHECK:        spv.mlir.merge
-      // CHECK:      }
-      scf.for %arg4 = %lb to %ub step %step {
-        %1 = load %arg2[%arg4] : memref<10xf32>
-        store %1, %arg3[%arg4] : memref<10xf32>
-      }
-      gpu.return
-    }
-
-
-    // CHECK-LABEL: @loop_yield
-    gpu.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
-    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
-      // CHECK: %[[LB:.*]] = spv.constant 4 : i32
-      %lb = constant 4 : index
-      // CHECK: %[[UB:.*]] = spv.constant 42 : i32
-      %ub = constant 42 : index
-      // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
-      %step = constant 2 : index
-      // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32
-      %s0 = constant 0.0 : f32
-      // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32
-      %s1 = constant 1.0 : f32
-      // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
-      // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
-      // CHECK: spv.loop {
-      // CHECK:   spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
-      // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
-      // CHECK:   %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
-      // CHECK:   spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
-      // CHECK: ^[[BODY]]:
-      // CHECK:   %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
-      // CHECK-DAG:   %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
-      // CHECK-DAG:   spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
-      // CHECK-DAG:   spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
-      // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
-      // CHECK: ^[[MERGE]]:
-      // CHECK:   spv.mlir.merge
-      // CHECK: }
-      %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
-        %sn = addf %si, %si : f32
-        scf.yield %sn, %sn : f32, f32
-      }
-      // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
-      // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
-      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
-      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
-      store %result#0, %arg3[%lb] : memref<10xf32>
-      store %result#1, %arg3[%ub] : memref<10xf32>
-      gpu.return
-    }
-  }
-}

diff  --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
new file mode 100644
index 000000000000..3e4545b1b1b5
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
+} {
+
+// CHECK-LABEL: @loop_kernel
+func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+  // CHECK: %[[LB:.*]] = spv.constant 4 : i32
+  %lb = constant 4 : index
+  // CHECK: %[[UB:.*]] = spv.constant 42 : i32
+  %ub = constant 42 : index
+  // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
+  %step = constant 2 : index
+  // CHECK:      spv.loop {
+  // CHECK-NEXT:   spv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
+  // CHECK:      ^[[HEADER]](%[[INDVAR:.*]]: i32):
+  // CHECK:        %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
+  // CHECK:        spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+  // CHECK:      ^[[BODY]]:
+  // CHECK:        %[[ZERO1:.*]] = spv.constant 0 : i32
+  // CHECK:        %[[OFFSET1:.*]] = spv.constant 0 : i32
+  // CHECK:        %[[STRIDE1:.*]] = spv.constant 1 : i32
+  // CHECK:        %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
+  // CHECK:        %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
+  // CHECK:        spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
+  // CHECK:        %[[ZERO2:.*]] = spv.constant 0 : i32
+  // CHECK:        %[[OFFSET2:.*]] = spv.constant 0 : i32
+  // CHECK:        %[[STRIDE2:.*]] = spv.constant 1 : i32
+  // CHECK:        %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
+  // CHECK:        %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
+  // CHECK:        spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
+  // CHECK:        %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
+  // CHECK:        spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
+  // CHECK:      ^[[MERGE]]
+  // CHECK:        spv.mlir.merge
+  // CHECK:      }
+  scf.for %arg4 = %lb to %ub step %step {
+    %1 = load %arg2[%arg4] : memref<10xf32>
+    store %1, %arg3[%arg4] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: @loop_yield
+func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+  // CHECK: %[[LB:.*]] = spv.constant 4 : i32
+  %lb = constant 4 : index
+  // CHECK: %[[UB:.*]] = spv.constant 42 : i32
+  %ub = constant 42 : index
+  // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
+  %step = constant 2 : index
+  // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32
+  %s0 = constant 0.0 : f32
+  // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32
+  %s1 = constant 1.0 : f32
+  // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK: spv.loop {
+  // CHECK:   spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
+  // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
+  // CHECK:   %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
+  // CHECK:   spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]]:
+  // CHECK:   %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
+  // CHECK-DAG:   %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
+  // CHECK-DAG:   spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
+  // CHECK-DAG:   spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
+  // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK:   spv.mlir.merge
+  // CHECK: }
+  %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+    %sn = addf %si, %si : f32
+    scf.yield %sn, %sn : f32, f32
+  }
+  // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+  // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+  // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+  // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+  store %result#0, %arg3[%lb] : memref<10xf32>
+  store %result#1, %arg3[%ub] : memref<10xf32>
+  return
+}
+
+} // end module

diff  --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir
new file mode 100644
index 000000000000..d7c048f517ab
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
+} {
+
+// CHECK-LABEL: @kernel_simple_selection
+func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) {
+  %value = constant 0.0 : f32
+  %i = constant 0 : index
+
+  // CHECK:       spv.selection {
+  // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[MERGE:\^.*]]
+  // CHECK-NEXT:  [[TRUE]]:
+  // CHECK:         spv.Branch [[MERGE]]
+  // CHECK-NEXT:  [[MERGE]]:
+  // CHECK-NEXT:    spv.mlir.merge
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  spv.Return
+
+  scf.if %arg3 {
+    store %value, %arg2[%i] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: @kernel_nested_selection
+func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) {
+  %i = constant 0 : index
+  %j = constant 9 : index
+
+  // CHECK:       spv.selection {
+  // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE_TOP:\^.*]], [[FALSE_TOP:\^.*]]
+  // CHECK-NEXT:  [[TRUE_TOP]]:
+  // CHECK-NEXT:    spv.selection {
+  // CHECK-NEXT:      spv.BranchConditional {{%.*}}, [[TRUE_NESTED_TRUE_PATH:\^.*]], [[FALSE_NESTED_TRUE_PATH:\^.*]]
+  // CHECK-NEXT:    [[TRUE_NESTED_TRUE_PATH]]:
+  // CHECK:           spv.Branch [[MERGE_NESTED_TRUE_PATH:\^.*]]
+  // CHECK-NEXT:    [[FALSE_NESTED_TRUE_PATH]]:
+  // CHECK:           spv.Branch [[MERGE_NESTED_TRUE_PATH]]
+  // CHECK-NEXT:    [[MERGE_NESTED_TRUE_PATH]]:
+  // CHECK-NEXT:      spv.mlir.merge
+  // CHECK-NEXT:    }
+  // CHECK-NEXT:    spv.Branch [[MERGE_TOP:\^.*]]
+  // CHECK-NEXT:  [[FALSE_TOP]]:
+  // CHECK-NEXT:    spv.selection {
+  // CHECK-NEXT:      spv.BranchConditional {{%.*}}, [[TRUE_NESTED_FALSE_PATH:\^.*]], [[FALSE_NESTED_FALSE_PATH:\^.*]]
+  // CHECK-NEXT:    [[TRUE_NESTED_FALSE_PATH]]:
+  // CHECK:           spv.Branch [[MERGE_NESTED_FALSE_PATH:\^.*]]
+  // CHECK-NEXT:    [[FALSE_NESTED_FALSE_PATH]]:
+  // CHECK:           spv.Branch [[MERGE_NESTED_FALSE_PATH]]
+  // CHECK:         [[MERGE_NESTED_FALSE_PATH]]:
+  // CHECK-NEXT:      spv.mlir.merge
+  // CHECK-NEXT:    }
+  // CHECK-NEXT:    spv.Branch [[MERGE_TOP]]
+  // CHECK-NEXT:  [[MERGE_TOP]]:
+  // CHECK-NEXT:    spv.mlir.merge
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  spv.Return
+
+  scf.if %arg5 {
+    scf.if %arg6 {
+      %value = load %arg3[%i] : memref<10xf32>
+      store %value, %arg4[%i] : memref<10xf32>
+    } else {
+      %value = load %arg4[%i] : memref<10xf32>
+      store %value, %arg3[%i] : memref<10xf32>
+    }
+  } else {
+    scf.if %arg6 {
+      %value = load %arg3[%j] : memref<10xf32>
+      store %value, %arg4[%j] : memref<10xf32>
+    } else {
+      %value = load %arg4[%j] : memref<10xf32>
+      store %value, %arg3[%j] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: @simple_if_yield
+func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) {
+  // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK:       spv.selection {
+  // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+  // CHECK-NEXT:  [[TRUE]]:
+  // CHECK:         %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
+  // CHECK:         %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
+  // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
+  // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
+  // CHECK:         spv.Branch ^[[MERGE:.*]]
+  // CHECK-NEXT:  [[FALSE]]:
+  // CHECK:         %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
+  // CHECK:         %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
+  // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
+  // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
+  // CHECK:         spv.Branch ^[[MERGE]]
+  // CHECK-NEXT:  ^[[MERGE]]:
+  // CHECK:         spv.mlir.merge
+  // CHECK-NEXT:  }
+  // CHECK-DAG:   %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+  // CHECK-DAG:   %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+  // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+  // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+  // CHECK:       spv.Return
+  %0:2 = scf.if %arg3 -> (f32, f32) {
+    %c0 = constant 0.0 : f32
+    %c1 = constant 1.0 : f32
+    scf.yield %c0, %c1 : f32, f32
+  } else {
+    %c0 = constant 2.0 : f32
+    %c1 = constant 3.0 : f32
+    scf.yield %c1, %c0 : f32, f32
+  }
+  %i = constant 0 : index
+  %j = constant 1 : index
+  store %0#0, %arg2[%i] : memref<10xf32>
+  store %0#1, %arg2[%j] : memref<10xf32>
+  return
+}
+
+// TODO: The transformation should only be legal if VariablePointer capability
+// is supported. This test is still useful to make sure we can handle scf op
+// result with type change.
+// CHECK-LABEL: @simple_if_yield_type_change
+func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) {
+  // CHECK:       %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
+  // CHECK:       spv.selection {
+  // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+  // CHECK-NEXT:  [[TRUE]]:
+  // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK:         spv.Branch ^[[MERGE:.*]]
+  // CHECK-NEXT:  [[FALSE]]:
+  // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK:         spv.Branch ^[[MERGE]]
+  // CHECK-NEXT:  ^[[MERGE]]:
+  // CHECK:         spv.mlir.merge
+  // CHECK-NEXT:  }
+  // CHECK:       %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK:       %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK:       spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
+  // CHECK:       spv.Return
+  %i = constant 0 : index
+  %value = constant 0.0 : f32
+  %0 = scf.if %arg4 -> (memref<10xf32>) {
+    scf.yield %arg2 : memref<10xf32>
+  } else {
+    scf.yield %arg3 : memref<10xf32>
+  }
+  store %value, %0[%i] : memref<10xf32>
+  return
+}
+
+} // end module


        


More information about the Mlir-commits mailing list