[Mlir-commits] [mlir] [mlir][vector][spirv] Lower vector.maskedload and vector.maskedstore to SPIR-V (PR #74834)

Hsiangkai Wang llvmlistbot at llvm.org
Mon Dec 11 08:56:19 PST 2023


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/74834

>From 82b3d7bc550c23a89d557cf160880f2403faec26 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 30 Nov 2023 14:09:00 +0000
Subject: [PATCH] [mlir][vector][spirv] Lower vector.maskedload and
 vector.maskedstore to SPIR-V

This patch lowers

vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru

to

%value = spirv.load %base[%idx_0, %idx_1]
spirv.select %mask, %value, %pass_thru

It also lowers

vector.maskedstore %base[%idx_0, %idx_1], %mask, %value

to

spirv.mlir.loop {
  spirv.Branch ^bb1(0)
^bb1(%i: i32):
  %m = spirv.VectorExtractDynamic %mask[%i]
  spirv.mlir.selection {
    spirv.BranchConditional %m, ^if_bb1, ^if_bb2
    ^if_bb1:
      %v = spirv.VectorExtractDynamic %value[%i]
      spirv.Store %base[%i], %v
      spirv.Branch ^if_bb2
    ^if_bb2:
      spirv.mlir.merge
  }
  spirv.Branch ^bb2(%i)
^bb2(%i: i32):
  %update_i = spirv.IAdd %i, 1
  %cond = spirv.SLessThan %update_i, %veclen
  spirv.BranchConditional %cond, ^bb1, ^bb3
^bb3:
  spirv.mlir.merge
}
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 240 +++++++++++++++++-
 .../VectorToSPIRV/vector-to-spirv.mlir        |  90 +++++++
 2 files changed, 329 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e48f29a4f1702..827126a68f6fc 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -647,6 +647,243 @@ struct VectorStoreOpConverter final
   }
 };
 
+mlir::spirv::LoopOp createSpirvLoop(ConversionPatternRewriter &rewriter,
+                                    Location loc) {
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  auto &loopBody = loopOp.getBody();
+  // Create header block.
+  loopBody.getBlocks().insert(std::next(loopBody.begin(), 1), new Block());
+  // Create continue block.
+  loopBody.getBlocks().insert(std::prev(loopBody.end(), 2), new Block());
+
+  return loopOp;
+}
+
+mlir::spirv::SelectionOp
+createSpirvSelection(ConversionPatternRewriter &rewriter, Location loc) {
+  auto selectionOp =
+      rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+  auto &loopBody = selectionOp.getBody();
+  // Create header block.
+  rewriter.createBlock(&loopBody, loopBody.end());
+  // Create if-true block.
+  rewriter.createBlock(&loopBody, loopBody.end());
+  // Create merge block.
+  rewriter.createBlock(&loopBody, loopBody.end());
+  rewriter.create<spirv::MergeOp>(loc);
+
+  return selectionOp;
+}
+
+Value addOffsetToIndices(ConversionPatternRewriter &rewriter, Location loc,
+                         SmallVectorImpl<Value> &indices, const Value offset,
+                         const SPIRVTypeConverter &typeConverter,
+                         const MemRefType memrefType, const Value base) {
+  indices.back() = rewriter.create<spirv::IAddOp>(loc, indices.back(), offset);
+  return spirv::getElementPtr(typeConverter, memrefType, base, indices, loc,
+                              rewriter);
+}
+
+Value extractMaskBit(ConversionPatternRewriter &rewriter, Location loc,
+                     Value mask, Value offset) {
+  return rewriter.create<spirv::VectorExtractDynamicOp>(
+      loc, rewriter.getI1Type(), mask, offset);
+}
+
+Value extractVectorElement(ConversionPatternRewriter &rewriter, Location loc,
+                           Type type, Value vector, Value offset) {
+  return rewriter.create<spirv::VectorExtractDynamicOp>(loc, type, vector,
+                                                        offset);
+}
+
+Value createConstantInteger(ConversionPatternRewriter &rewriter, Location loc,
+                            int32_t value) {
+  auto i32Type = rewriter.getI32Type();
+  return rewriter.create<spirv::ConstantOp>(loc, i32Type,
+                                            IntegerAttr::get(i32Type, value));
+}
+
+/// Convert vector.maskedload to spirv dialect.
+///
+/// Before:
+///
+///   vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
+///
+/// After:
+///
+///   %value = spirv.load %base[%idx_0, %idx_1]
+///   spirv.select %mask, %value, %pass_thru
+///
+struct VectorMaskedLoadOpConverter final
+    : public OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = maskedLoadOp.getMemRefType();
+    auto attr =
+        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
+    if (!attr)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "expected spirv.storage_class memory space");
+
+    auto loc = maskedLoadOp.getLoc();
+    auto vectorType = maskedLoadOp.getVectorType();
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Value accessChain =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
+                             adaptor.getIndices(), loc, rewriter);
+    if (!accessChain)
+      return rewriter.notifyMatchFailure(
+          maskedLoadOp, "failed to get memref element pointer");
+
+    spirv::StorageClass storageClass = attr.getValue();
+    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
+    Value castedAccessChain =
+        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+    auto load =
+        rewriter.create<spirv::LoadOp>(loc, vectorType, castedAccessChain);
+
+    auto loadedValue = rewriter.create<spirv::SelectOp>(
+        loc, adaptor.getMask(), load, adaptor.getPassThru());
+
+    rewriter.replaceOp(maskedLoadOp, loadedValue);
+
+    return success();
+  }
+};
+
+/// Convert vector.maskedstore to spirv dialect.
+///
+/// Before:
+///
+///   vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
+///
+/// After:
+///
+///   spirv.mlir.loop {
+///     spirv.Branch ^bb1(0)
+///   ^bb1(%i: i32):
+///     %m = spirv.VectorExtractDynamic %mask[%i]
+///     spirv.mlir.selection {
+///       spirv.BranchConditional %m, ^if_bb1, ^if_bb2
+///       ^if_bb1:
+///         %v = spirv.VectorExtractDynamic %value[%i]
+///         spirv.Store %base[%i], %v
+///         spirv.Branch ^if_bb2
+///       ^if_bb2:
+///         spirv.mlir.merge
+///     }
+///     spirv.Branch ^bb2(%i)
+///   ^bb2(%i: i32):
+///     %update_i = spirv.IAdd %i, 1
+///     %cond = spirv.SLessThan %update_i, %veclen
+///     spirv.BranchConditional %cond, ^bb1, ^bb3
+///   ^bb3:
+///     spirv.mlir.merge
+///   }
+///   return
+///
+struct VectorMaskedStoreOpConverter final
+    : public OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memrefType = maskedStoreOp.getMemRefType();
+    if (!isa<spirv::StorageClassAttr>(memrefType.getMemorySpace()))
+      return failure();
+
+    VectorType maskVType = maskedStoreOp.getMaskVectorType();
+    if (maskVType.getRank() != 1)
+      return failure();
+    if (maskVType.getShape().size() != 1)
+      return failure();
+
+    // Create constants.
+    auto loc = maskedStoreOp.getLoc();
+    Value zero = createConstantInteger(rewriter, loc, 0);
+    Value one = createConstantInteger(rewriter, loc, 1);
+    Value maskLength =
+        createConstantInteger(rewriter, loc, maskVType.getShape()[0]);
+
+    // Construct a loop to go through the mask value
+    auto loopOp = createSpirvLoop(rewriter, loc);
+    auto *headerBlock = loopOp.getHeaderBlock();
+    auto *continueBlock = loopOp.getContinueBlock();
+
+    auto i32Type = rewriter.getI32Type();
+    BlockArgument indVar = headerBlock->addArgument(i32Type, loc);
+    BlockArgument continueIndVar = continueBlock->addArgument(i32Type, loc);
+
+    // Insert code into loop entry block
+    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
+    rewriter.create<spirv::BranchOp>(loc, headerBlock, ArrayRef<Value>({zero}));
+
+    // Insert code into loop header block
+    rewriter.setInsertionPointToEnd(headerBlock);
+    auto maskBit = extractMaskBit(rewriter, loc, adaptor.getMask(), indVar);
+
+    auto selectionOp = createSpirvSelection(rewriter, loc);
+    auto *selectionHeaderBlock = selectionOp.getHeaderBlock();
+    auto *selectionMergeBlock = selectionOp.getMergeBlock();
+    auto *selectionTrueBlock = &(*std::next(selectionOp.getBody().begin(), 1));
+
+    // Insert code into selection header block
+    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, maskBit, selectionTrueBlock, std::nullopt, selectionMergeBlock,
+        std::nullopt);
+
+    // Insert code into selection true block
+    rewriter.setInsertionPointToEnd(selectionTrueBlock);
+    auto scalarType = memrefType.getElementType();
+    auto extractedStoreValue = extractVectorElement(
+        rewriter, loc, scalarType, adaptor.getValueToStore(), indVar);
+
+    // Store base[indVar]
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    auto indices = llvm::to_vector<4>(adaptor.getIndices());
+    auto updatedAccessChain =
+        addOffsetToIndices(rewriter, loc, indices, indVar, typeConverter,
+                           memrefType, adaptor.getBase());
+    rewriter.create<spirv::StoreOp>(loc, updatedAccessChain,
+                                    extractedStoreValue);
+    rewriter.create<spirv::BranchOp>(loc, selectionMergeBlock, std::nullopt);
+
+    // Insert code into loop header block
+    rewriter.setInsertionPointAfter(selectionOp);
+    rewriter.create<spirv::BranchOp>(loc, continueBlock,
+                                     ArrayRef<Value>({indVar}));
+
+    // Insert code into loop continue block
+    rewriter.setInsertionPointToEnd(continueBlock);
+
+    // Update induction variable.
+    auto updatedIndVar =
+        rewriter.create<spirv::IAddOp>(loc, continueIndVar, one);
+
+    // Check if the induction variable < length(mask)
+    auto cmpOp =
+        rewriter.create<spirv::SLessThanOp>(loc, updatedIndVar, maskLength);
+
+    auto *mergeBlock = loopOp.getMergeBlock();
+    rewriter.create<spirv::BranchConditionalOp>(
+        loc, cmpOp, headerBlock, ArrayRef<Value>({updatedIndVar}), mergeBlock,
+        std::nullopt);
+
+    // Insert code after loop
+    rewriter.setInsertionPointAfter(loopOp);
+    rewriter.replaceOp(maskedStoreOp, loopOp);
+
+    return success();
+  }
+};
+
 struct VectorReductionToIntDotProd final
     : OpRewritePattern<vector::ReductionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -821,7 +1058,8 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc..d1e3286e2ede6 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -805,4 +805,94 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageB
   return
 }
 
+// CHECK-LABEL:  @vector_maskedload
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:    %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+//       CHECK:    %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[CST_F0:.*]] = arith.constant 0.000000e+00 : f32
+//       CHECK:    %[[S4:.*]] = spirv.CompositeConstruct %[[CST_F0]], %[[CST_F0]], %[[CST_F0]], %[[CST_F0]] : (f32, f32, f32, f32) -> vector<4xf32>
+//       CHECK:    %[[C0_1:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C0_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C5:.*]] = spirv.Constant 5 : i32
+//       CHECK:    %[[S5:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+//       CHECK:    %[[S6:.*]] = spirv.IAdd %[[C0_2]], %[[S5]] : i32
+//       CHECK:    %[[C1_1:.*]] = spirv.Constant 1 : i32
+//       CHECK:    %[[S7:.*]] = spirv.IMul %[[C1_1]], %[[S2]] : i32
+//       CHECK:    %[[S8:.*]] = spirv.IAdd %[[S6]], %[[S7]] : i32
+//       CHECK:    %[[S9:.*]] = spirv.AccessChain %[[S0]][%[[C0_1]], %[[S8]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:    %[[S10:.*]] = spirv.Bitcast %[[S9]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
+//       CHECK:    %[[S11:.*]] = spirv.Load "StorageBuffer" %[[S10]] : vector<4xf32>
+//       CHECK:    %[[S12:.*]] = spirv.Select %[[S3]], %[[S11]], %[[S4]] : vector<4xi1>, vector<4xf32>
+//       CHECK:    return %[[S12]] : vector<4xf32>
+//       CHECK:  }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedstore
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+//       CHECK:    %[[S0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[S1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S2:.*]] = builtin.unrealized_conversion_cast %[[C4]] : index to i32
+//       CHECK:    %[[S3:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[C0_1:.*]] = spirv.Constant 0 : i32
+//       CHECK:    %[[C1_1:.*]] = spirv.Constant 1 : i32
+//       CHECK:    %[[C4_1:.*]] = spirv.Constant 4 : i32
+//       CHECK:    spirv.mlir.loop {
+//       CHECK:      spirv.Branch ^bb1(%[[C0_1]] : i32)
+//       CHECK:    ^bb1(%[[S4:.*]]: i32):  // 2 preds: ^bb0, ^bb2
+//       CHECK:      %[[S5:.*]] = spirv.VectorExtractDynamic %[[S3]][%[[S4]]] : vector<4xi1>, i32
+//       CHECK:      spirv.mlir.selection {
+//       CHECK:        spirv.BranchConditional %[[S5]], ^bb1, ^bb2
+//       CHECK:      ^bb1:  // pred: ^bb0
+//       CHECK:        %[[S9:.*]] = spirv.VectorExtractDynamic %[[ARG1]][%[[S4]]] : vector<4xf32>, i32
+//       CHECK:        %[[S10:.*]] = spirv.IAdd %[[S2]], %[[S4]] : i32
+//       CHECK:        %[[C0_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:        %[[C1_2:.*]] = spirv.Constant 0 : i32
+//       CHECK:        %[[C5:.*]] = spirv.Constant 5 : i32
+//       CHECK:        %[[S11:.*]] = spirv.IMul %[[C5]], %[[S1]] : i32
+//       CHECK:        %[[S12:.*]] = spirv.IAdd %[[C1_2]], %[[S11]] : i32
+//       CHECK:        %[[C1_3:.*]] = spirv.Constant 1 : i32
+//       CHECK:        %[[S13:.*]] = spirv.IMul %[[C1_3]], %[[S10]] : i32
+//       CHECK:        %[[S14:.*]] = spirv.IAdd %[[S12]], %[[S13]] : i32
+//       CHECK:        %[[S15:.*]] = spirv.AccessChain %[[S0]][%[[C0_2]], %[[S14]]] : !spirv.ptr<!spirv.struct<(!spirv.array<20 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:        spirv.Store "StorageBuffer" %[[S15]], %[[S9]] : f32
+//       CHECK:        spirv.Branch ^bb2
+//       CHECK:      ^bb2:  // 2 preds: ^bb0, ^bb1
+//       CHECK:        spirv.mlir.merge
+//       CHECK:      }
+//       CHECK:      spirv.Branch ^bb2(%[[S4]] : i32)
+//       CHECK:    ^bb2(%[[S6:.*]]: i32):  // pred: ^bb1
+//       CHECK:      %[[S7:.*]] = spirv.IAdd %[[S6]], %[[C1_1]] : i32
+//       CHECK:      %[[S8:.*]] = spirv.SLessThan %[[S7]], %[[C4_1]] : i32
+//       CHECK:      spirv.BranchConditional %[[S8]], ^bb1(%[[S7]] : i32), ^bb3
+//       CHECK:    ^bb3:  // pred: ^bb2
+//       CHECK:      spirv.mlir.merge
+//       CHECK:    }
+//       CHECK:    return
+//       CHECK:  }
+func.func @vector_maskedstore(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32>
+  return
+}
+
 } // end module



More information about the Mlir-commits mailing list