[Mlir-commits] [mlir] [mlir][vector][spirv] Lower vector.maskedload and vector.maskedstore to SPIR-V (PR #74834)

Hsiangkai Wang llvmlistbot at llvm.org
Fri Dec 8 05:19:46 PST 2023


https://github.com/Hsiangkai created https://github.com/llvm/llvm-project/pull/74834

Use spirv.mlir.loop and spirv.mlir.selection to lower vector.maskedload and vector.maskedstore.

>From 4a235552a1744263c120cc49a0f28e4422811518 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 30 Nov 2023 14:09:00 +0000
Subject: [PATCH] [mlir][vector][spirv] Lower vector.maskedload and
 vector.maskedstore to SPIR-V

Use spirv.mlir.loop and spirv.mlir.selection to lower vector.maskedload and
vector.maskedstore.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 325 +++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        | 111 ++++++
 2 files changed, 435 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e48f29a4f1702..b32c004e28a1e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -647,6 +647,328 @@ struct VectorStoreOpConverter final
   }
 };
 
+/// Creates a `spirv.mlir.loop` at the rewriter's current insertion point and
+/// gives its region four blocks, in order:
+///
+///   entry -> header -> continue -> merge
+///
+/// The entry and merge blocks come from `addEntryAndMergeBlock()`; the header
+/// and continue blocks are created empty for the caller to fill in and are
+/// meant to be fetched with `getHeaderBlock()` / `getContinueBlock()`.
+///
+/// The rewriter's insertion point is left right after the new loop op.
+mlir::spirv::LoopOp createSpirvLoop(ConversionPatternRewriter &rewriter,
+                                    Location loc) {
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  // `createBlock` moves the insertion point into the new block; restore the
+  // caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Region &loopBody = loopOp.getBody();
+  // Insert both blocks right before the merge block (the last block), so the
+  // region ends up as [entry, header, continue, merge]. Creating them through
+  // the rewriter (rather than `new Block()`) lets the conversion driver track
+  // them.
+  rewriter.createBlock(&loopBody, std::prev(loopBody.end())); // header
+  rewriter.createBlock(&loopBody, std::prev(loopBody.end())); // continue
+
+  return loopOp;
+}
+
+/// Creates a `spirv.mlir.selection` at the rewriter's current insertion point
+/// and gives its region three blocks, in order:
+///
+///   header -> if-true -> merge
+///
+/// The merge block is terminated with `spirv.mlir.merge`; the header and
+/// if-true blocks are left empty for the caller to fill in. The if-true block
+/// is the second block of the region.
+///
+/// The rewriter's insertion point is left right after the new selection op.
+mlir::spirv::SelectionOp
+createSpirvSelection(ConversionPatternRewriter &rewriter, Location loc) {
+  auto selectionOp =
+      rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+
+  // `createBlock` moves the insertion point into each new block; restore the
+  // caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Region &selectionBody = selectionOp.getBody();
+  rewriter.createBlock(&selectionBody, selectionBody.end()); // header
+  rewriter.createBlock(&selectionBody, selectionBody.end()); // if-true
+  rewriter.createBlock(&selectionBody, selectionBody.end()); // merge
+  // The insertion point now sits at the end of the merge block.
+  rewriter.create<spirv::MergeOp>(loc);
+
+  return selectionOp;
+}
+
+/// Adds `offset` to the innermost (last) index in `indices` with
+/// `spirv.IAdd` and returns the pointer to the addressed element of the memref
+/// `base`, as computed by `spirv::getElementPtr`.
+///
+/// `indices` is updated in place: on return its last element is the offset
+/// index.
+///
+/// NOTE(review): callers pass an i32 `offset`; presumably the converted memref
+/// indices are also i32 only when the type converter uses 32-bit indices —
+/// confirm before enabling 64-bit index lowering.
+Value addOffsetToIndices(ConversionPatternRewriter &rewriter, Location loc,
+                         SmallVectorImpl<Value> &indices, Value offset,
+                         const SPIRVTypeConverter &typeConverter,
+                         MemRefType memrefType, Value base) {
+  // A rank-0 memref has no index to offset; `indices.back()` would be UB.
+  assert(!indices.empty() && "expected at least one memref index");
+  indices.back() = rewriter.create<spirv::IAddOp>(loc, indices.back(), offset);
+  return spirv::getElementPtr(typeConverter, memrefType, base, indices, loc,
+                              rewriter);
+}
+
+/// Reads lane `offset` of `mask` as an i1 value using
+/// `spirv.VectorExtractDynamic`.
+Value extractMaskBit(ConversionPatternRewriter &rewriter, Location loc,
+                     Value mask, Value offset) {
+  Type bitType = rewriter.getI1Type();
+  auto bitOp = rewriter.create<spirv::VectorExtractDynamicOp>(loc, bitType,
+                                                              mask, offset);
+  return bitOp;
+}
+
+/// Reads lane `offset` of `vector` as a value of type `type` using
+/// `spirv.VectorExtractDynamic`.
+Value extractVectorElement(ConversionPatternRewriter &rewriter, Location loc,
+                           Type type, Value vector, Value offset) {
+  Value lane = rewriter.create<spirv::VectorExtractDynamicOp>(
+      loc, type, vector, offset);
+  return lane;
+}
+
+/// Materializes `value` as a signless 32-bit `spirv.Constant`.
+Value createConstantInteger(ConversionPatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  IntegerAttr constAttr = rewriter.getI32IntegerAttr(value);
+  return rewriter.create<spirv::ConstantOp>(loc, constAttr.getType(),
+                                            constAttr);
+}
+
+/// Convert vector.maskedload to spirv dialect.
+///
+/// Only 1-D masks on memrefs whose memory space is a SPIR-V storage class are
+/// handled; anything else makes the pattern fail to match.
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %buffer = spirv.Variable : !spirv.ptr<vector, Function>
+///   %zero_vec = spirv.Constant dense<0> : vector
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0, %zero_vec)
+///   ^bb1(%i: i32, %partial: vector):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     %p = spirv.VectorExtractDynamic %pass_thru[%i]
+///     %value = spirv.Load %base[%idx_0, %idx_1 + %i]
+///     %s = spirv.Select %m, %value, %p
+///     %v = spirv.VectorInsertDynamic %s, %partial[%i]
+///     spirv.Store %buffer, %v
+///     spirv.Branch ^bb2(%i, %v)
+///   ^bb2(%i: i32, %partial: vector):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1(%update_i, %partial), ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///   %ret = spirv.Load %buffer
+///
+/// NOTE(review): the element load in ^bb1 runs for *every* lane, and the mask
+/// is only applied afterwards via spirv.Select. The vector dialect allows the
+/// index of a masked-off lane to be out of bounds, so this unconditional load
+/// can touch memory outside the memref. Guarding the load with a
+/// spirv.mlir.selection (as the vector.maskedstore lowering does) would fix
+/// it, but changes the emitted IR and therefore the FileCheck test — TODO.
+///
+struct VectorMaskedLoadOpConverter final
+    : public OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = maskedLoadOp.getMemRefType();
+    // A memref without an explicit memory space has a null memory-space
+    // attribute, on which plain `isa<>` must not be called.
+    if (!isa_and_nonnull<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected spirv.storage_class memory space");
+
+    // Only 1-D masks are supported: the lane index is added directly to the
+    // innermost memref index.
+    VectorType maskVType = maskedLoadOp.getMaskVectorType();
+    if (maskVType.getRank() != 1)
+      return rewriter.notifyMatchFailure(maskedLoadOp, "expected 1-D mask");
+
+    // The loop op is created without results, so the vector assembled inside
+    // the loop is handed out through a Function-storage variable.
+    auto loc = maskedLoadOp.getLoc();
+    auto vectorType = maskedLoadOp.getVectorType();
+    auto pointerType =
+        spirv::PointerType::get(vectorType, spirv::StorageClass::Function);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+
+    // Loop bounds: %i runs over [0, maskLength) with step 1.
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    // Initial value of the loop-carried partial result; lanes [0, maskLength)
+    // are overwritten before the loop exits.
+    auto emptyVector = rewriter.create<spirv::ConstantOp>(
+        loc, vectorType, rewriter.getZeroAttr(vectorType));
+
+    // Construct a loop to go through the mask value.
+    auto loopOp = createSpirvLoop(rewriter, loc);
+
+    auto *headerBlock = loopOp.getHeaderBlock();
+    auto *continueBlock = loopOp.getContinueBlock();
+
+    // Header and continue blocks both carry (induction variable, partial
+    // result vector).
+    auto i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument partialVector = headerBlock->addArgument(vectorType, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+    BlockArgument continueVector = continueBlock->addArgument(vectorType, loc);
+
+    // Entry block: jump to the header with %i = 0 and the zero vector.
+    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+    rewriter.create<spirv::BranchOp>(loc, headerBlock,
+                                     ArrayRef<Value>({zero, emptyVector}));
+
+    // Header block: compute lane %i of the result.
+    rewriter.setInsertionPointToEnd(headerBlock);
+    auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    // NOTE(review): the unconverted memref element type is used for the
+    // scalar ops; presumably correct only when the type converter leaves the
+    // element type unchanged (e.g. f32) — confirm.
+    auto scalarType = memrefType.getElementType();
+    auto passThruValue = extractVectorElement(rewriter, loc, scalarType,
+                                              adaptor.getPassThru(), indVar);
+
+    // Load base[..., idx_last + %i]. This load is not guarded by the mask bit
+    // (see the NOTE(review) on the struct).
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    auto updatedAccessChain =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    auto loadScalar =
+        rewriter.create<spirv::LoadOp>(loc, scalarType, updatedAccessChain);
+
+    // Select the loaded value or pass-through according to the mask bit.
+    auto valueToInsert = rewriter.create<spirv::SelectOp>(
+        loc, scalarType, maskBit, loadScalar, passThruValue);
+
+    // Insert the selected value into the partial result and publish it to the
+    // variable; after the final iteration the variable holds the full result.
+    auto updatedVector = rewriter.create<spirv::VectorInsertDynamicOp>(
+        loc, vectorType, partialVector, valueToInsert, indVar);
+    rewriter.create<spirv::StoreOp>(loc, alloc, updatedVector);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar, updatedVector}));
+
+    // Continue block: %i += 1, and loop again while %i < maskLength.
+    rewriter.setInsertionPointToEnd(continueBlock);
+    auto updatedIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+    auto cmpOp =
+        rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+    auto *mergeBlock = loopOp.getMergeBlock();
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, cmpOp, headerBlock,
+        ArrayRef<Value>({updatedIndVar, continueVector}), mergeBlock,
+        std::nullopt);
+
+    // After the loop, the masked load's result is the variable's contents.
+    rewriter.setInsertionPointAfter(loopOp);
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(maskedLoadOp, alloc);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore to spirv dialect.
+///
+/// Only 1-D masks on memrefs whose memory space is a SPIR-V storage class are
+/// handled; anything else makes the pattern fail to match. Each lane's store
+/// sits inside a spirv.mlir.selection on its mask bit, so masked-off elements
+/// are never accessed.
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0)
+///   ^bb1(%i: i32):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     spirv.mlir.selection {
+///       spirv.BranchConditional %m, ^if_bb1, ^if_bb2
+///       ^if_bb1:
+///         %v = spirv.VectorExtractDynamic %value[%i]
+///         spirv.Store %base[%idx_0, %idx_1 + %i], %v
+///         spirv.Branch ^if_bb2
+///       ^if_bb2:
+///         spirv.mlir.merge
+///     }
+///     spirv.Branch ^bb2(%i)
+///   ^bb2(%i: i32):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1(%update_i), ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///
+struct VectorMaskedStoreOpConverter final
+    : public OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = maskedStoreOp.getMemRefType();
+    // A memref without an explicit memory space has a null memory-space
+    // attribute, on which plain `isa<>` must not be called.
+    if (!isa_and_nonnull<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return rewriter.notifyMatchFailure(
+          maskedStoreOp, "expected spirv.storage_class memory space");
+
+    // Only 1-D masks are supported: the lane index is added directly to the
+    // innermost memref index.
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getRank() != 1)
+      return rewriter.notifyMatchFailure(maskedStoreOp, "expected 1-D mask");
+
+    // Loop bounds: %i runs over [0, maskLength) with step 1.
+    auto loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    // Construct a loop to go through the mask value.
+    auto loopOp = createSpirvLoop(rewriter, loc);
+    auto *headerBlock = loopOp.getHeaderBlock();
+    auto *continueBlock = loopOp.getContinueBlock();
+
+    // Header and continue blocks both carry the induction variable.
+    auto i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+
+    // Entry block: jump to the header with %i = 0.
+    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+    rewriter.create<spirv::BranchOp>(loc, headerBlock, ArrayRef<Value>({zero}));
+
+    // Header block: read mask bit %i and store lane %i only if it is set.
+    rewriter.setInsertionPointToEnd(headerBlock);
+    auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    auto selectionOp = createSpirvSelection(rewriter, loc);
+    auto *selectionHeaderBlock = selectionOp.getHeaderBlock();
+    auto *selectionMergeBlock = selectionOp.getMergeBlock();
+    // createSpirvSelection lays out [header, if-true, merge].
+    auto *selectionTrueBlock = &(*std::next(selectionOp.getBody().begin(), 1));
+
+    // Selection header: enter the if-true block only when the bit is set.
+    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, maskBit, selectionTrueBlock, std::nullopt, selectionMergeBlock,
+        std::nullopt);
+
+    // If-true block: base[..., idx_last + %i] = value[%i].
+    rewriter.setInsertionPointToEnd(selectionTrueBlock);
+    // NOTE(review): the unconverted memref element type is used for the
+    // extracted scalar; presumably correct only when the type converter leaves
+    // the element type unchanged (e.g. f32) — confirm.
+    auto scalarType = memrefType.getElementType();
+    auto extractedStoreValue = extractVectorElement(
+        rewriter, loc, scalarType, adaptor.getValueToStore(), indVar);
+
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    auto updatedAccessChain =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    rewriter.create<spirv::StoreOp>(loc, updatedAccessChain,
+                                    extractedStoreValue);
+    rewriter.create<spirv::BranchOp>(loc, selectionMergeBlock, std::nullopt);
+
+    // Back in the loop header, after the selection: go to the continue block.
+    rewriter.setInsertionPointAfter(selectionOp);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar}));
+
+    // Continue block: %i += 1, and loop again while %i < maskLength.
+    rewriter.setInsertionPointToEnd(continueBlock);
+    auto updatedIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+    auto cmpOp =
+        rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+    auto *mergeBlock = loopOp.getMergeBlock();
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, cmpOp, headerBlock, ArrayRef<Value>({updatedIndVar}), mergeBlock,
+        std::nullopt);
+
+    // vector.maskedstore has no results; the loop now performs the store.
+    rewriter.eraseOp(maskedStoreOp);
+    return success();
+  }
+};
+
 struct VectorReductionToIntDotProd final
     : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -821,7 +1143,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc..bc9e92981644b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -805,4 +805,115 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageB
   return
 }
 
+// vector.maskedload with a 1-D mask that enables only lane 0
+// (vector.create_mask %idx_1) on a StorageBuffer memref. The lowering is a
+// spirv.mlir.loop over the 4 lanes that assembles the result in a
+// Function-storage spirv.Variable.
+// NOTE(review): the expected output loads every lane and applies the mask
+// afterwards with spirv.Select, so lanes 1-3 read %arg0[0, 5..7], which lies
+// outside the 5-wide innermost dimension.
+// CHECK-LABEL:  @vector_maskedload
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:    %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+//       CHECK:    %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[CST_F0:.*]] = arith.constant 0.000000e+00 : f32
+//       CHECK:    %[[S4:.*]] = spirv.CompositeConstruct %[[CST_F0]], %[[CST_F0]], %[[CST_F0]], %[[CST_F0]] : (f32, f32, f32, f32) -> vector<4xf32>
+//       CHECK:    %[[S5:.*]] = spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
+//       CHECK:    %[[C0_1:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C1_1:.*]] = spirv.Constant 1 : i32
+//       CHECK:    %[[C4_1:.*]] = spirv.Constant 4 : i32
+//       CHECK:    %[[CV0:.*]] = spirv.Constant dense<0.000000e+00> : vector<4xf32>
+//       CHECK:    spirv.mlir.loop {
+//       CHECK:      spirv.Branch ^bb1(%[[C0_1]], %[[CV0]] : i32, vector<4xf32>)
+//       CHECK:    ^bb1(%[[S7:.*]]: i32, %[[S8:.*]]: vector<4xf32>):  // 2 preds: ^bb0, ^bb2
+//       CHECK:      %[[S9:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S7]]] : vector<4xi1>, i32
+//       CHECK:      %[[S10:.*]] = spirv.VectorExtractDynamic %[[S4]][%[[S7]]] : vector<4xf32>, i32
+//       CHECK:      %[[S11:.*]] = spirv.IAdd %[[S2]], %[[S7]] : i32
+//       CHECK:      %[[C0_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:      %[[C0_3:.*]] = spirv.Constant 0 : i32
+//       CHECK:      %[[C5:.*]] = spirv.Constant 5 : i32
+//       CHECK:      %[[S12:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+//       CHECK:      %[[S13:.*]] = spirv.IAdd %[[C0_3]], %[[S12]] : i32
+//       CHECK:      %[[C1_2:.*]] = spirv.Constant 1 : i32
+//       CHECK:      %[[S14:.*]] = spirv.IMul %[[C1_2]], %[[S11]] : i32
+//       CHECK:      %[[S15:.*]] = spirv.IAdd %[[S13]], %[[S14]] : i32
+//       CHECK:      %[[S16:.*]] = spirv.AccessChain %[[S0]][%[[C0_2]], %[[S15]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:      %[[S17:.*]] = spirv.Load "StorageBuffer" %[[S16]] : f32
+//       CHECK:      %[[S18:.*]] = spirv.Select %[[S9]], %[[S17]], %[[S10]] : i1, f32
+//       CHECK:      %[[S19:.*]] = spirv.VectorInsertDynamic %[[S18]], %[[S8]][%[[S7]]] : vector<4xf32>, i32
+//       CHECK:      spirv.Store "Function" %[[S5]], %[[S19]] : vector<4xf32>
+//       CHECK:      spirv.Branch ^bb2(%[[S7]], %[[S19]] : i32, vector<4xf32>)
+//       CHECK:    ^bb2(%[[S20:.*]]: i32, %[[S21:.*]]: vector<4xf32>):  // pred: ^bb1
+//       CHECK:      %[[S22:.*]] = spirv.IAdd %[[S20]], %[[C1_1]] : i32
+//       CHECK:      %[[S23:.*]] = spirv.SLessThan %[[S22]], %[[C4_1]] : i32
+//       CHECK:      spirv.BranchConditional %[[S23]], ^bb1(%[[S22]], %[[S21]] : i32, vector<4xf32>), ^bb3
+//       CHECK:    ^bb3:  // pred: ^bb2
+//       CHECK:      spirv.mlir.merge
+//       CHECK:    }
+//       CHECK:    %[[S6:.*]] = spirv.Load "Function" %[[S5]] : vector<4xf32>
+//       CHECK:    return %[[S6]] : vector<4xf32>
+//       CHECK:  }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// vector.maskedstore with a 1-D mask that enables only lane 0
+// (vector.create_mask %idx_1) on a StorageBuffer memref. The lowering is a
+// spirv.mlir.loop over the 4 lanes. Inside it, each lane's spirv.Store is
+// guarded by a spirv.mlir.selection on that lane's mask bit.
+// CHECK-LABEL:  @vector_maskedstore
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+//       CHECK:    %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+//       CHECK:    %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[C0_1:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C1_1:.*]] = spirv.Constant 1 : i32
+//       CHECK:    %[[C4_1:.*]] = spirv.Constant 4 : i32
+//       CHECK:    spirv.mlir.loop {
+//       CHECK:      spirv.Branch ^bb1(%[[C0_1]] : i32)
+//       CHECK:    ^bb1(%[[S4:.*]]: i32):  // 2 preds: ^bb0, ^bb2
+//       CHECK:      %[[S5:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S4]]] : vector<4xi1>, i32
+//       CHECK:      spirv.mlir.selection {
+//       CHECK:        spirv.BranchConditional %[[S5]], ^bb1, ^bb2
+//       CHECK:      ^bb1:  // pred: ^bb0
+//       CHECK:        %[[S9:.*]] = spirv.VectorExtractDynamic %[[ARG1]][%[[S4]]] : vector<4xf32>, i32
+//       CHECK:        %[[S10:.*]] = spirv.IAdd %[[S2]], %[[S4]] : i32
+//       CHECK:        %[[C0_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:        %[[C0_3:.*]] = spirv.Constant 0 : i32
+//       CHECK:        %[[C5:.*]] = spirv.Constant 5 : i32
+//       CHECK:        %[[S11:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+//       CHECK:        %[[S12:.*]] = spirv.IAdd %[[C0_3]], %[[S11]] : i32
+//       CHECK:        %[[C1_2:.*]] = spirv.Constant 1 : i32
+//       CHECK:        %[[S13:.*]] = spirv.IMul %[[C1_2]], %[[S10]] : i32
+//       CHECK:        %[[S14:.*]] = spirv.IAdd %[[S12]], %[[S13]] : i32
+//       CHECK:        %[[S15:.*]] = spirv.AccessChain %[[S0]][%[[C0_2]], %[[S14]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:        spirv.Store "StorageBuffer" %[[S15]], %[[S9]] : f32
+//       CHECK:        spirv.Branch ^bb2
+//       CHECK:      ^bb2:  // 2 preds: ^bb0, ^bb1
+//       CHECK:        spirv.mlir.merge
+//       CHECK:      }
+//       CHECK:      spirv.Branch ^bb2(%[[S4]] : i32)
+//       CHECK:    ^bb2(%[[S6:.*]]: i32):  // pred: ^bb1
+//       CHECK:      %[[S7:.*]] = spirv.IAdd %[[S6]], %[[C1_1]] : i32
+//       CHECK:      %[[S8:.*]] = spirv.SLessThan %[[S7]], %[[C4_1]] : i32
+//       CHECK:      spirv.BranchConditional %[[S8]], ^bb1(%[[S7]] : i32), ^bb3
+//       CHECK:    ^bb3:  // pred: ^bb2
+//       CHECK:      spirv.mlir.merge
+//       CHECK:    }
+//       CHECK:    return
+//       CHECK:  }
+func.func @vector_maskedstore(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32>
+  return
+}
+
 } // end module



More information about the Mlir-commits mailing list