[Mlir-commits] [mlir] [mlir][vector] Add support for vector.maskedstore sub-type emulation. (PR #73871)

Diego Caballero llvmlistbot at llvm.org
Thu Nov 30 03:08:51 PST 2023


================
@@ -32,6 +32,78 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+/// Returns a compressed mask. A mask value is set only if any mask bit is set
+/// in its scale range. E.g., if `scale` equals 2, the following mask:
+///
+///   %mask = [1, 1, 1, 0, 0, 0]
+///
+/// will return the following new compressed mask:
+///
+///   %mask = [1, 1, 0]
+static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
+                                                  Location loc, Value mask,
+                                                  int origElements, int scale) {
+  auto numElements = (origElements + scale - 1) / scale;
+
+  auto maskOp = mask.getDefiningOp();
+  SmallVector<vector::ExtractOp, 2> extractOps;
+  // Finding the mask creation operation.
+  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+      maskOp = extractOp.getVector().getDefiningOp();
+      extractOps.push_back(extractOp);
+    }
+  }
+  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+  if (!createMaskOp && !constantMaskOp)
+    return failure();
+
+  // Computing the "compressed" mask. All the emulation logic (i.e. computing
+  // new mask index) only happens on the last dimension of the vectors.
+  Operation *newMask = nullptr;
+  auto shape = llvm::to_vector(
+      maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+  shape.push_back(numElements);
+  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+  if (createMaskOp) {
+    auto maskOperands = createMaskOp.getOperands();
+    auto numMaskOperands = maskOperands.size();
+    AffineExpr s0;
+    bindSymbols(rewriter.getContext(), s0);
+    s0 = s0 + scale - 1;
+    s0 = s0.floorDiv(scale);
+    OpFoldResult origIndex =
+        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+    OpFoldResult maskIndex =
+        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+    auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+    newMaskOperands.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                    newMaskOperands);
+  } else if (constantMaskOp) {
+    auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+    auto numMaskOperands = maskDimSizes.size();
+    auto origIndex =
+        cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+    auto maskIndex =
+        rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+    auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
----------------
dcaballe wrote:

ditto

https://github.com/llvm/llvm-project/pull/73871


More information about the Mlir-commits mailing list