[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)

Diego Caballero llvmlistbot at llvm.org
Wed Nov 13 17:39:01 PST 2024


================
@@ -75,83 +77,137 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
   // new mask index) only happens on the last dimension of the vectors.
-  Operation *newMask = nullptr;
   SmallVector<int64_t> shape(
       cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                ArrayRef<int64_t> maskDimSizes =
+                    constantMaskOp.getMaskDimSizes();
+                size_t numMaskOperands = maskDimSizes.size();
+                int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+                int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+                int64_t maskIndex = llvm::divideCeil(
+                    numFrontPadElems + origIndex, numSrcElemsPerDest);
+
+                // TODO: we only want the mask between [startIndex, maskIndex]
+                // to be true, the rest are false.
+                if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+                  return std::nullopt;
+
+                SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+                newMaskDimSizes.push_back(maskIndex);
+
+                if (numFrontPadElems == 0) {
+                  return rewriter.create<vector::ConstantMaskOp>(
+                      loc, newMaskType, newMaskDimSizes);
+                } else {
----------------
dcaballe wrote:

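nit: no `else` after `return` — per the LLVM Coding Standards, since the `if` branch already returns, drop the `else` and un-indent its body.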

https://github.com/llvm/llvm-project/pull/116064


More information about the Mlir-commits mailing list