[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 14 08:20:40 PST 2024


================
@@ -75,83 +77,137 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
----------------
banach-space wrote:

Add a TODO - vector.splat ;-)

https://github.com/llvm/llvm-project/pull/116064


More information about the Mlir-commits mailing list