[Mlir-commits] [mlir] [vector][linearize] Refactor code to push target bit width out of patterns (PR #136581)

James Newling llvmlistbot at llvm.org
Tue Apr 22 10:00:57 PDT 2025


================
@@ -518,24 +441,103 @@ struct LinearizeVectorBitCast final
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type.");
 
-    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
-      return rewriter.notifyMatchFailure(
-          loc, "Can't flatten since targetBitWidth <= OpSize");
-
     rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
                                                    adaptor.getSource());
     return mlir::success();
   }
-
-private:
-  unsigned targetVectorBitWidth;
 };
 
 } // namespace
 
-void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target, unsigned targetBitWidth) {
+/// If `type` is VectorType with trailing dimension of (bit) size greater than
+/// or equal to `targetBitWidth`, its defining op is considered legal.
+static bool legalBecauseOfBitwidth(Type type, unsigned targetBitWidth) {
+
+  VectorType vecType = dyn_cast<VectorType>(type);
+
+  if (!vecType)
+    return true;
+
+  // The width of the type 'index' is unbounded (and therefore potentially above
+  // the target width).
+  if (vecType.getElementType().isIndex())
+    return true;
+
+  unsigned finalDimSize =
+      vecType.getRank() == 0 ? 0 : vecType.getShape().back();
+
+  unsigned trailingVecDimBitWidth =
+      finalDimSize * vecType.getElementTypeBitWidth();
+
+  return trailingVecDimBitWidth >= targetBitWidth;
+}
+
+static SmallVector<std::pair<Type, unsigned>>
+getChecksForBitwidth(Operation *op, unsigned targetBitWidth) {
+
+  if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
+    auto w = targetBitWidth < std::numeric_limits<unsigned>::max()
+                 ? targetBitWidth + 1
+                 : targetBitWidth;
+    return {{insertOp.getValueToStoreType(), w}};
+  }
+  auto resultTypes = op->getResultTypes();
+  SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
+  resultsWithBitWidth.reserve(resultTypes.size());
+  for (Type type : resultTypes) {
+    resultsWithBitWidth.push_back({type, targetBitWidth});
+  }
+  return resultsWithBitWidth;
+}
+
+/// Return true if the operation `op` does not support scalable vectors and
+/// has at least 1 scalable vector result.
+static bool legalBecauseScalable(Operation *op) {
+
+  bool scalableSupported = op->hasTrait<OpTrait::ConstantLike>() ||
+                           op->hasTrait<OpTrait::Vectorizable>() ||
+                           isa<vector::BitCastOp>(op);
----------------
newling wrote:

I was just trying to do a pure NFC (no functional change) refactor, and these were the only ops that support scalable vectors. Flipping the logic is a good idea and will make this clearer.

https://github.com/llvm/llvm-project/pull/136581


More information about the Mlir-commits mailing list