[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)

Han-Chung Wang llvmlistbot at llvm.org
Thu Nov 14 17:00:43 PST 2024


================
@@ -75,83 +75,133 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
 
-  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+  auto numDestElems = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  // TODO: add support to `vector.splat`.
----------------
hanhanW wrote:

Should we move this TODO to l.86 (i.e., right before the while loop)?

https://github.com/llvm/llvm-project/pull/116064
