[Mlir-commits] [mlir] [mlir][vector] Add support for vector.maskedstore sub-type emulation. (PR #73871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 29 15:58:20 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

The idea is similar to vector.maskedload + vector.store emulation. What the emulation does is:

1. Get a compressed mask and load the data from the destination.
2. Bitcast the data to the original vector type.
3. Select values between `op.valueToStore` and the loaded data using the original mask.
4. Bitcast the new value and store it to the destination using the compressed mask.

---
Full diff: https://github.com/llvm/llvm-project/pull/73871.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+168-62) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+72) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6aea0343bfc9327..05c98b89e8a94c1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -32,6 +32,78 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+/// Returns a compressed mask. The mask value is set only if any mask is present
+/// in the scale range. E.g., if `scale` equals 2, the following mask:
+///
+///   %mask = [1, 1, 1, 0, 0, 0]
+///
+/// will return the following new compressed mask:
+///
+///   %mask = [1, 1, 0]
+static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
+                                                  Location loc, Value mask,
+                                                  int origElements, int scale) {
+  auto numElements = (origElements + scale - 1) / scale;
+
+  auto maskOp = mask.getDefiningOp();
+  SmallVector<vector::ExtractOp, 2> extractOps;
+  // Finding the mask creation operation.
+  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+      maskOp = extractOp.getVector().getDefiningOp();
+      extractOps.push_back(extractOp);
+    }
+  }
+  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+  if (!createMaskOp && !constantMaskOp)
+    return failure();
+
+  // Computing the "compressed" mask. All the emulation logic (i.e. computing
+  // new mask index) only happens on the last dimension of the vectors.
+  Operation *newMask = nullptr;
+  auto shape = llvm::to_vector(
+      maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+  shape.push_back(numElements);
+  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+  if (createMaskOp) {
+    auto maskOperands = createMaskOp.getOperands();
+    auto numMaskOperands = maskOperands.size();
+    AffineExpr s0;
+    bindSymbols(rewriter.getContext(), s0);
+    s0 = s0 + scale - 1;
+    s0 = s0.floorDiv(scale);
+    OpFoldResult origIndex =
+        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+    OpFoldResult maskIndex =
+        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+    auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+    newMaskOperands.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                    newMaskOperands);
+  } else if (constantMaskOp) {
+    auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+    auto numMaskOperands = maskDimSizes.size();
+    auto origIndex =
+        cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+    auto maskIndex =
+        rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+    auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
+    newMaskDimSizes.push_back(maskIndex);
+    newMask = rewriter.create<vector::ConstantMaskOp>(
+        loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+  }
+
+  while (!extractOps.empty()) {
+    newMask = rewriter.create<vector::ExtractOp>(
+        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+    extractOps.pop_back();
+  }
+
+  return newMask;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -99,6 +171,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+    : OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    int scale = dstBits / srcBits;
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+    OpFoldResult linearizedIndicesOfr;
+    std::tie(std::ignore, linearizedIndicesOfr) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+    Value linearizedIndices =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load the whole data and use arith.select to handle the corner cases.
+    // E.g., given these input values:
+    //
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have
+    //
+    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //
+    //    %new_mask = [1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x0]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+    //    %packed_data = [0x78, 0x94, 0x00]
+    //
+    // Using the new mask to store %packed_data results in expected output.
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    if (failed(newMask))
+      return failure();
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
+    auto passThru = rewriter.create<arith::ConstantOp>(
+        loc, newType, rewriter.getZeroAttr(newType));
+
+    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+        loc, newType, adaptor.getBase(), linearizedIndices,
+        newMask.value()->getResult(0), passThru);
+
+    Value valueToStore = rewriter.create<vector::BitCastOp>(
+        loc, op.getValueToStore().getType(), newLoad);
+    valueToStore = rewriter.create<arith::SelectOp>(
+        loc, op.getMask(), op.getValueToStore(), valueToStore);
+    valueToStore =
+        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
+
+    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
+        valueToStore);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
@@ -236,7 +396,6 @@ struct ConvertVectorMaskedLoad final
     // TODO: Currently, only the even number of elements loading is supported.
     // To deal with the odd number of elements, one has to extract the
     // subvector at the proper offset after bit-casting.
-
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
     if (origElements % scale != 0)
@@ -244,7 +403,6 @@ struct ConvertVectorMaskedLoad final
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
-
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
@@ -254,66 +412,13 @@ struct ConvertVectorMaskedLoad final
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
-    auto newType = VectorType::get(numElements, newElementType);
-
-    auto maskOp = op.getMask().getDefiningOp();
-    SmallVector<vector::ExtractOp, 2> extractOps;
-    // Finding the mask creation operation.
-    while (maskOp &&
-           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
-      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
-        maskOp = extractOp.getVector().getDefiningOp();
-        extractOps.push_back(extractOp);
-      }
-    }
-    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-    if (!createMaskOp && !constantMaskOp)
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    if (failed(newMask))
       return failure();
 
-    // Computing the "compressed" mask. All the emulation logic (i.e. computing
-    // new mask index) only happens on the last dimension of the vectors.
-    Operation *newMask = nullptr;
-    auto shape = llvm::to_vector(
-        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
-    shape.push_back(numElements);
-    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-    if (createMaskOp) {
-      auto maskOperands = createMaskOp.getOperands();
-      auto numMaskOperands = maskOperands.size();
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      s0 = s0 + scale - 1;
-      s0 = s0.floorDiv(scale);
-      OpFoldResult origIndex =
-          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-      OpFoldResult maskIndex =
-          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
-      newMaskOperands.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                      newMaskOperands);
-    } else if (constantMaskOp) {
-      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
-      auto numMaskOperands = maskDimSizes.size();
-      auto origIndex =
-          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
-      auto maskIndex =
-          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
-      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
-      newMaskDimSizes.push_back(maskIndex);
-      newMask = rewriter.create<vector::ConstantMaskOp>(
-          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
-    }
-
-    while (!extractOps.empty()) {
-      newMask = rewriter.create<vector::ExtractOp>(
-          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
-      extractOps.pop_back();
-    }
-
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
 
@@ -321,7 +426,7 @@ struct ConvertVectorMaskedLoad final
     auto newLoad = rewriter.create<vector::MaskedLoadOp>(
         loc, newType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
-        newMask->getResult(0), newPassThru);
+        newMask.value()->getResult(0), newPassThru);
 
     // Setting the part that originally was not effectively loaded from memory
     // to pass through.
@@ -821,7 +926,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
   // Populate `vector.*` conversion patterns.
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
-               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index af0f98a1c447de0..cba299b2a1d9567 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -428,3 +428,75 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 //      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
 //      CHECK32: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi32>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+  %0 = memref.alloc() : memref<3x8xi8>
+  %mask = vector.create_mask %arg2 : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  return
+}
+// Expect no conversions, i8 is supported.
+//      CHECK: func @vector_maskedstore_i8(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT:   %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK-NEXT:   vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT:   return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32:     func @vector_maskedstore_i8(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK32-SAME:     %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32:        %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32:        %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK32:        %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32:        %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK32:        %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<2xi1>
+// CHECK32:        %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32:        %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32:        %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32:        %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32:        %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32:        vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+  %0 = memref.alloc() : memref<3x8xi8>
+  %mask = vector.constant_mask [4] : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  return
+}
+// Expect no conversions, i8 is supported.
+//      CHECK: func @vector_cst_maskedstore_i8(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT:   %[[MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK-NEXT:   vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT:   return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32:     func @vector_cst_maskedstore_i8(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME:     %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32:        %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32:        %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32:        %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32:        %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<2xi1>
+// CHECK32:        %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32:        %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32:        %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32:        %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32:        %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32:        vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/73871


More information about the Mlir-commits mailing list