[Mlir-commits] [mlir] 0670f85 - [mlir][spirv] Add support for lowering scf.for scf/if with return value

Thomas Raoux llvmlistbot at llvm.org
Wed Jul 1 17:11:45 PDT 2020


Author: Thomas Raoux
Date: 2020-07-01T17:08:08-07:00
New Revision: 0670f855a7d8f48a86d67d83e6be45fab016f080

URL: https://github.com/llvm/llvm-project/commit/0670f855a7d8f48a86d67d83e6be45fab016f080
DIFF: https://github.com/llvm/llvm-project/commit/0670f855a7d8f48a86d67d83e6be45fab016f080.diff

LOG: [mlir][spirv] Add support for lowering scf.for scf/if with return value

This allows lowering to support scf.for and scf.if with results. Because spv
region operations currently don't have return values, the results are demoted
to Function memory. We create one allocation per result right before the
region and store the yield values in it. Then we can load the value back from
the allocation to be able to use the results.

Differential Revision: https://reviews.llvm.org/D82246

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/test/Conversion/GPUToSPIRV/if.mlir
    mlir/test/Conversion/GPUToSPIRV/loop.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
index 95173717ceec..59fa0082e0cf 100644
--- a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -21,11 +21,23 @@ class Pass;
 // Owning list of rewriting patterns.
 class OwningRewritePatternList;
 class SPIRVTypeConverter;
+struct ScfToSPIRVContextImpl;
+
+struct ScfToSPIRVContext {
+  ScfToSPIRVContext();
+  ~ScfToSPIRVContext();
+
+  ScfToSPIRVContextImpl *getImpl() { return impl.get(); }
+
+private:
+  std::unique_ptr<ScfToSPIRVContextImpl> impl;
+};
 
 /// Collects a set of patterns to lower from scf.for, scf.if, and
 /// loop.terminator to CFG operations within the SPIR-V dialect.
 void populateSCFToSPIRVPatterns(MLIRContext *context,
                                 SPIRVTypeConverter &typeConverter,
+                                ScfToSPIRVContext &scfToSPIRVContext,
                                 OwningRewritePatternList &patterns);
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index c3bda25f0347..ccaecd9c9189 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -58,9 +58,10 @@ void GPUToSPIRVPass::runOnOperation() {
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
+  ScfToSPIRVContext scfContext;
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
-  populateSCFToSPIRVPatterns(context, typeConverter, patterns);
+  populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
   if (failed(applyFullConversion(kernelModules, *target, patterns)))

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index a6a08b10fc63..b8eb87c80368 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -18,12 +18,44 @@
 
 using namespace mlir;
 
+namespace mlir {
+struct ScfToSPIRVContextImpl {
+  // Map between the spirv region control flow operation (spv.loop or
+  // spv.selection) to the VariableOp created to store the region results. The
+  // order of the VariableOp matches the order of the results.
+  DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
+};
+} // namespace mlir
+
+/// We use ScfToSPIRVContext to store information about the lowering of the scf
+/// region that need to be used later on. When we lower scf.for/scf.if we create
+/// VariableOp to store the results. We need to keep track of the VariableOp
+/// created as we need to insert stores into them when lowering Yield. Those
+/// StoreOp cannot be created earlier as they may use a different type than
+/// yield operands.
+ScfToSPIRVContext::ScfToSPIRVContext() {
+  impl = std::make_unique<ScfToSPIRVContextImpl>();
+}
+ScfToSPIRVContext::~ScfToSPIRVContext() = default;
+
 namespace {
+/// Common class for all SCF to SPIR-V patterns.
+template <typename OpTy>
+class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
+public:
+  SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
+                          ScfToSPIRVContextImpl *scfToSPIRVContext)
+      : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
+        scfToSPIRVContext(scfToSPIRVContext) {}
+
+protected:
+  ScfToSPIRVContextImpl *scfToSPIRVContext;
+};
 
 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
+class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
 public:
-  using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
+  using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
@@ -32,29 +64,54 @@ class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
 
 /// Pattern to convert a scf::IfOp within kernel functions into
 /// spirv::SelectionOp.
-class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
+class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
 public:
-  using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
+  using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Pattern to erase a scf::YieldOp.
-class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
+class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
 public:
-  using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
+  using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.eraseOp(terminatorOp);
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
+/// Helper function to replace SCF op outputs with SPIR-V variable loads.
+/// We create VariableOp to handle the results value of the control flow region.
+/// spv.loop/spv.selection currently don't yield value. Right after the loop
+/// we load the value from the allocation and use it as the SCF op result.
+template <typename ScfOp, typename OpTy>
+static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
+                                  SPIRVTypeConverter &typeConverter,
+                                  ConversionPatternRewriter &rewriter,
+                                  ScfToSPIRVContextImpl *scfToSPIRVContext) {
+
+  Location loc = scfOp.getLoc();
+  auto &allocas = scfToSPIRVContext->outputVars[newOp];
+  SmallVector<Value, 8> resultValue;
+  for (Value result : scfOp.results()) {
+    auto convertedType = typeConverter.convertType(result.getType());
+    auto pointerType =
+        spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
+    rewriter.setInsertionPoint(newOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+    allocas.push_back(alloc);
+    rewriter.setInsertionPointAfter(newOp);
+    Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+    resultValue.push_back(loadResult);
+  }
+  rewriter.replaceOp(scfOp, resultValue);
+}
+
 //===----------------------------------------------------------------------===//
 // scf::ForOp.
 //===----------------------------------------------------------------------===//
@@ -83,6 +140,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   // Create the new induction variable to use.
   BlockArgument newIndVar =
       header->addArgument(forOperands.lowerBound().getType());
+  for (Value arg : forOperands.initArgs())
+    header->addArgument(arg.getType());
   Block *body = forOp.getBody();
 
   // Apply signature conversion to the body of the forOp. It has a single block,
@@ -91,29 +150,28 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   TypeConverter::SignatureConversion signatureConverter(
       body->getNumArguments());
   signatureConverter.remapInput(0, newIndVar);
-  FailureOr<Block *> newBody = rewriter.convertRegionTypes(
-      &forOp.getLoopBody(), typeConverter, &signatureConverter);
-  if (failed(newBody))
-    return failure();
-  body = *newBody;
-
-  // Delete the loop terminator.
-  rewriter.eraseOp(body->getTerminator());
+  for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
+    signatureConverter.remapInput(i, header->getArgument(i));
+  body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
+                                           signatureConverter);
 
   // Move the blocks from the forOp into the loopOp. This is the body of the
   // loopOp.
   rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
                               std::next(loopOp.body().begin(), 2));
 
+  SmallVector<Value, 8> args(1, forOperands.lowerBound());
+  args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
   // Branch into it from the entry.
   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
-  rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+  rewriter.create<spirv::BranchOp>(loc, header, args);
 
   // Generate the rest of the loop header.
   rewriter.setInsertionPointToEnd(header);
   auto *mergeBlock = loopOp.getMergeBlock();
   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
       loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+
   rewriter.create<spirv::BranchConditionalOp>(
       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
 
@@ -127,7 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
       loc, newIndVar.getType(), newIndVar, forOperands.step());
   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
 
-  rewriter.eraseOp(forOp);
+  replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
+                        scfToSPIRVContext);
   return success();
 }
 
@@ -179,13 +238,45 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
                                               thenBlock, ArrayRef<Value>(),
                                               elseBlock, ArrayRef<Value>());
 
-  rewriter.eraseOp(ifOp);
+  replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
+                        scfToSPIRVContext);
+  return success();
+}
+
+/// Yield is lowered to stores to the VariableOp created during lowering of the
+/// parent region. For loops we also need to update the branch looping back to
+/// the header with the loop carried values.
+LogicalResult TerminatorOpConversion::matchAndRewrite(
+    scf::YieldOp terminatorOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // If the region returns values, store each value into the associated
+  // VariableOp created during lowering of the parent region.
+  if (!operands.empty()) {
+    auto loc = terminatorOp.getLoc();
+    auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()];
+    assert(allocas.size() == operands.size());
+    for (unsigned i = 0, e = operands.size(); i < e; i++)
+      rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
+    if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) {
+      // For loops we also need to update the branch jumping back to the header.
+      auto br =
+          cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
+      SmallVector<Value, 8> args(br.getBlockArguments());
+      args.append(operands.begin(), operands.end());
+      rewriter.setInsertionPoint(br);
+      rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
+                                       args);
+      rewriter.eraseOp(br);
+    }
+  }
+  rewriter.eraseOp(terminatorOp);
   return success();
 }
 
 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
                                       SPIRVTypeConverter &typeConverter,
+                                      ScfToSPIRVContext &scfToSPIRVContext,
                                       OwningRewritePatternList &patterns) {
   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
-      context, typeConverter);
+      context, typeConverter, scfToSPIRVContext.getImpl());
 }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 963f5393c572..03ce62807a8b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -589,9 +589,6 @@ StorageClass PointerType::getStorageClass() const {
 
 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                 Optional<StorageClass> storage) {
-  if (storage)
-    assert(*storage == getStorageClass() && "inconsistent storage class!");
-
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   getPointeeType().cast<SPIRVType>().getExtensions(extensions,
@@ -604,9 +601,6 @@ void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
 void PointerType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     Optional<StorageClass> storage) {
-  if (storage)
-    assert(*storage == getStorageClass() && "inconsistent storage class!");
-
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,

diff  --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir
index 1e8164cb310d..81a7f6d32b91 100644
--- a/mlir/test/Conversion/GPUToSPIRV/if.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir
@@ -89,5 +89,79 @@ module attributes {
       }
       gpu.return
     }
+    // CHECK-LABEL: @simple_if_yield
+    gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel
+    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+      // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK:       spv.selection {
+      // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+      // CHECK-NEXT:  [[TRUE]]:
+      // CHECK:         %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
+      // CHECK:         %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
+      // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
+      // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
+      // CHECK:         spv.Branch ^[[MERGE:.*]]
+      // CHECK-NEXT:  [[FALSE]]:
+      // CHECK:         %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
+      // CHECK:         %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
+      // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
+      // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
+      // CHECK:         spv.Branch ^[[MERGE]]
+      // CHECK-NEXT:  ^[[MERGE]]:
+      // CHECK:         spv._merge
+      // CHECK-NEXT:  }
+      // CHECK-DAG:   %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+      // CHECK-DAG:   %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+      // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+      // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+      // CHECK:       spv.Return
+      %0:2 = scf.if %arg3 -> (f32, f32) {
+        %c0 = constant 0.0 : f32
+        %c1 = constant 1.0 : f32
+        scf.yield %c0, %c1 : f32, f32
+      } else {
+        %c0 = constant 2.0 : f32
+        %c1 = constant 3.0 : f32
+        scf.yield %c1, %c0 : f32, f32
+      }
+      %i = constant 0 : index
+      %j = constant 1 : index
+      store %0#0, %arg2[%i] : memref<10xf32>
+      store %0#1, %arg2[%j] : memref<10xf32>
+      gpu.return
+    }
+    // TODO(thomasraoux): The transformation should only be legal if
+    // VariablePointer capability is supported. This test is still useful to
+    // make sure we can handle scf op result with type change.
+    // CHECK-LABEL: @simple_if_yield_type_change
+    // CHECK:       %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>, Function>
+    // CHECK:       spv.selection {
+    // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+    // CHECK-NEXT:  [[TRUE]]:
+    // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK:         spv.Branch ^[[MERGE:.*]]
+    // CHECK-NEXT:  [[FALSE]]:
+    // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK:         spv.Branch ^[[MERGE]]
+    // CHECK-NEXT:  ^[[MERGE]]:
+    // CHECK:         spv._merge
+    // CHECK-NEXT:  }
+    // CHECK:       %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK:       %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK:       spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
+    // CHECK:       spv.Return
+    gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel
+    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+      %i = constant 0 : index
+      %value = constant 0.0 : f32
+      %0 = scf.if %arg4 -> (memref<10xf32>) {
+        scf.yield %arg2 : memref<10xf32>
+      } else {
+        scf.yield %arg3 : memref<10xf32>
+      }
+      store %value, %0[%i] : memref<10xf32>
+      gpu.return
+    }
   }
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
index 7c5df798438f..2205c60f875f 100644
--- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -51,5 +51,48 @@ module attributes {
       }
       gpu.return
     }
+
+
+    // CHECK-LABEL: @loop_yield
+    gpu.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
+    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+      // CHECK: %[[LB:.*]] = spv.constant 4 : i32
+      %lb = constant 4 : index
+      // CHECK: %[[UB:.*]] = spv.constant 42 : i32
+      %ub = constant 42 : index
+      // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
+      %step = constant 2 : index
+      // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32
+      %s0 = constant 0.0 : f32
+      // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32
+      %s1 = constant 1.0 : f32
+      // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: spv.loop {
+      // CHECK:   spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
+      // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
+      // CHECK:   %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
+      // CHECK:   spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+      // CHECK: ^[[BODY]]:
+      // CHECK:   %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
+      // CHECK-DAG:   %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
+      // CHECK-DAG:   spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
+      // CHECK-DAG:   spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
+      // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
+      // CHECK: ^[[MERGE]]:
+      // CHECK:   spv._merge
+      // CHECK: }
+      %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+        %sn = addf %si, %si : f32
+        scf.yield %sn, %sn : f32, f32
+      }
+      // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+      // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+      store %result#0, %arg3[%lb] : memref<10xf32>
+      store %result#1, %arg3[%ub] : memref<10xf32>
+      gpu.return
+    }
   }
 }


        


More information about the Mlir-commits mailing list