[Mlir-commits] [mlir] [mlir][vector] Add more patterns to Vector Linearize transformation (PR #136193)

Nishant Patel llvmlistbot at llvm.org
Wed Apr 23 09:49:32 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/136193

>From 8e39c56b6f39cc03002ba9c5e6662fa29d478016 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 15 Apr 2025 22:34:54 +0000
Subject: [PATCH 1/6] Add more patterns to Vector Linearize Pass

---
 .../Vector/Transforms/VectorLinearize.cpp     | 407 +++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 335 ++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |   3 +-
 3 files changed, 741 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@
 using namespace mlir;
 
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  // For BW-0, all operations are legal
+  if (targetBitWidth == 0) {
+    return false;
+  }
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,89 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
+/// source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// is converted to :
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+/// The source must be a statically shaped 2-D vector inserted with unit
+/// strides into a statically shaped destination of any rank (>= 2). The
+/// source is aligned with the two innermost destination dimensions, so the
+/// offsets are linearized with the full row-major strides of the destination.
+struct LinearizeVectorInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<
+      vector::InsertStridedSliceOp>::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSourceVectorType();
+    auto dstTy = op.getDestVectorType();
+
+    if (op.hasNonUnitStrides()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports unit strides.");
+    }
+
+    if (srcTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports 2D source.");
+    }
+
+    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linerization only supports static shapes.");
+    }
+
+    // Row-major strides of the destination: dstStrides[d] is the product of
+    // all dimension sizes after d. (Reusing the shape shifted by one is only
+    // correct for 2-D destinations.)
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t> dstStrides(dstShape.size(), 1);
+    for (int64_t d = static_cast<int64_t>(dstShape.size()) - 2; d >= 0; --d)
+      dstStrides[d] = dstStrides[d + 1] * dstShape[d + 1];
+
+    // Flat position of the slice origin inside the linearized destination.
+    int64_t linearizedOffset = 0;
+    for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides))
+      linearizedOffset += getConstantIntValue(off).value() * stride;
+
+    // Extract each source row and insert it into the destination. Consecutive
+    // source rows are one destination row apart, i.e. dstShape.back()
+    // elements in the linearized destination.
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    Value dstValue = adaptor.getDest();
+    for (int64_t i = 0; i < srcShape[0]; ++i) {
+      int64_t srcOffset = i * srcShape[1];
+      auto row = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+      int64_t dstOffset = linearizedOffset + i * dstShape.back();
+      dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, row, dstValue, dstOffset, 1);
+    }
+
+    rewriter.replaceOp(op, dstValue);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
 /// vectors to a ShuffleOp that works on linearized vectors.
 /// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if result is not a vector type
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,357 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// where `%indices + i` advances only the row (second-innermost) memref index
+/// by `i`. This unrolls the 2D vector load into multiple 1D vector loads and
+/// inserts them into the result vector. The pattern supports only fixed-size
+/// 2D vectors loaded from memrefs of scalars with rank >= 2.
+struct LinearizeVectorLoad final
+    : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+    // A scalable row count is not known at compile time, so it cannot be
+    // unrolled.
+    if (vecType.isScalable()) {
+      return rewriter.notifyMatchFailure(loc,
+                                         "Cannot linearize scalable vectors.");
+    }
+    // Rows are addressed through the second-innermost memref index, which
+    // requires a memref of scalars with at least two dimensions.
+    MemRefType memRefType = loadOp.getMemRefType();
+    if (isa<VectorType>(memRefType.getElementType()) ||
+        memRefType.getRank() < 2) {
+      return rewriter.notifyMatchFailure(
+          loc, "Can only linearize loads from rank >= 2 memrefs of scalars.");
+    }
+    int64_t unrollCount = shape[0];
+    int64_t vecSize = shape[1];
+    auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
+
+    // vector.load maps the vector dimensions onto the innermost memref
+    // dimensions, so row i starts at the second-innermost index + i.
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    // Construct the 2D vector.
+    Value resultVec = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(vecType));
+    // Emit unrolled loads for each 1D vector slice.
+    for (int64_t i = 0; i < unrollCount; i++) {
+      Value rowIndex = rowBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, increment);
+      }
+      indices[rowDim] = rowIndex;
+      auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+                                                 adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %slice_0 = vector.extract %source[0] : vector<4xf32>
+///   vector.store %slice_0, %base[%indices] : vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32>
+///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+///   ...
+/// where `%indices + i` advances only the row (second-innermost) memref index
+/// by `i`. This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern supports only fixed-size 2D vectors stored into
+/// memrefs of scalars with rank >= 2.
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+    // A scalable row count is not known at compile time, so it cannot be
+    // unrolled.
+    if (vecType.isScalable()) {
+      return rewriter.notifyMatchFailure(loc,
+                                         "Cannot linearize scalable vectors.");
+    }
+    // Rows are addressed through the second-innermost memref index, which
+    // requires a memref of scalars with at least two dimensions.
+    MemRefType memRefType = storeOp.getMemRefType();
+    if (isa<VectorType>(memRefType.getElementType()) ||
+        memRefType.getRank() < 2) {
+      return rewriter.notifyMatchFailure(
+          loc, "Can only linearize stores to rank >= 2 memrefs of scalars.");
+    }
+
+    int64_t unrollCount = shape[0];
+    // vector.store maps the vector dimensions onto the innermost memref
+    // dimensions, so row i starts at the second-innermost index + i.
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    // The adaptor value has the converted (linearized) type; cast it back to
+    // 2D so that individual rows can be extracted.
+    auto vec = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    for (int64_t i = 0; i < unrollCount; i++) {
+      auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+      Value rowIndex = rowBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, increment);
+      }
+      indices[rowDim] = rowIndex;
+      rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+                                       indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// The result type is obtained from the type converter; the pattern fails if
+/// that type cannot be converted. Bit-width legality is decided by the
+/// conversion target, not by this pattern.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(
+        splatOp, adaptor.getInput(), dstTy);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a linearized vector by
+/// replacing it with a CreateMaskOp of the converted type. The pattern
+/// currently supports only 2D masks with a unit outer dimension.
+/// Following,
+///   vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %dims : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // NOTE(review): only the inner bound (last operand) is forwarded and the
+    // outer bound is dropped. create_mask yields an all-false mask when any
+    // bound is <= 0, so this is only correct when the outer bound is >= 1;
+    // e.g. `create_mask %c0, %c20 : vector<1x16xi1>` is all-false but becomes
+    // an all-true vector<16xi1> here. A fix should zero the size when the
+    // outer bound is not positive, together with updating test_create_mask.
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        createMaskOp, dstTy, adaptor.getOperands().back());
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to use linearized vector types. It replaces the operands with the
+/// converted ones, converts the region block-argument types, and inserts
+/// shape casts in front of terminators whose operands still have a type that
+/// is not legal. Results whose type is not legal are retyped to the
+/// converted type, and a shape cast back to the original type is inserted
+/// for their existing users. The pattern fails (without modifying the op)
+/// when neither the operands nor the result types change.
+struct LinearizeRegionBranchOp final
+    : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+  using OpInterfaceConversionPattern<
+      RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+  LinearizeRegionBranchOp(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto converter = getTypeConverter();
+
+    llvm::SmallVector<Type> convertedTypes;
+    for (Type ty : op->getResultTypes())
+      convertedTypes.push_back(converter->convertType(ty));
+
+    // Bail out before starting the in-place modification: a started
+    // modification must be finalized or cancelled, never abandoned.
+    if (llvm::equal(convertedTypes, op->getResultTypes()) &&
+        llvm::equal(op->getOperands(), operands))
+      return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.startOpModification(op);
+    op->setOperands(operands);
+
+    // Convert region types (block arguments and yields)
+    for (Region &region : op->getRegions()) {
+      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
+        rewriter.cancelOpModification(op);
+        return failure();
+      }
+
+      // Cast terminator operands whose type is still not legal.
+      for (Block &block : region) {
+        // Block::getTerminator() asserts when there is no terminator, so
+        // check first instead of relying on a null result.
+        if (!block.mightHaveTerminator())
+          continue;
+        Operation *terminator = block.getTerminator();
+        for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+          Value value = yieldOperand.get();
+          Type type = value.getType();
+          if (converter->isLegal(type))
+            continue;
+          Type newTy = converter->convertType(type);
+          rewriter.setInsertionPoint(terminator);
+          Value newValue =
+              rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+          yieldOperand.set(newValue);
+        }
+      }
+    }
+
+    // Retype illegal results and cast them back to the original type for
+    // their existing users.
+    // NOTE(review): setType and replaceAllUsesExcept below mutate the IR
+    // directly instead of going through the rewriter, so the conversion
+    // driver may be unable to roll them back; confirm this is acceptable.
+    rewriter.setInsertionPointAfter(op);
+    for (Value result : op->getResults()) {
+      Type oldTy = result.getType();
+      if (converter->isLegal(oldTy))
+        continue;
+      result.setType(converter->convertType(oldTy));
+      Operation *castOp =
+          rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+      result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+    }
+
+    rewriter.finalizeOpModification(op);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
 
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+  target.addLegalOp<mlir::vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp,
+                 vector::StoreOp, vector::CreateMaskOp,
+                 RegionBranchOpInterface, vector::SplatOp>(op) ||
              op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+               LinearizeVectorBitCast, LinearizeVectorLoad,
+               LinearizeVectorStore, LinearizeVectorSplat,
+               LinearizeVectorCreateMask, LinearizeRegionBranchOp
+               >(typeConverter, patterns.getContext(),
                                        targetBitWidth);
 }
 
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                               .getRank() == 1)
                    : true;
       });
+
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+    [=](vector::InsertStridedSliceOp op) -> bool {
+      if(isLessThanTargetBitWidth(op, targetBitWidth)) {
+        auto srcTy = op.getSourceVectorType();
+        auto dstTy = op.getDestVectorType();
+        if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+            srcTy.hasStaticShape() && dstTy.hasStaticShape())
+          return false;
+      }
+      return true;
+    });
+
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+  // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: return %[[CAST]] : vector<4x4xf16>
+  // BW-128: return %[[CAST]] : vector<4x4xf16>
+  // With the BW-0 prefix the 2-D vector.load is expected to stay unchanged.
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return %[[LOAD]] : vector<4x4xf16>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_store
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) {
+func.func @test_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16>
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16>
+  // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1_0:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index
+  // DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2_1:.*]] = arith.constant 2 : index
+  // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index
+  // DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16>
+  // DEFAULT: %[[C3:.*]] = arith.constant 3 : index
+  // BW-128: %[[C3:.*]] = arith.constant 3 : index
+  // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index
+  // DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16>
+  // DEFAULT: return
+  // BW-128: return
+  // With the BW-0 prefix the 2-D vector.store is expected to stay unchanged.
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C2:.*]] = arith.constant 2 : index
+  // BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16>
+  // BW-0: return
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// -----
+// ALL-LABEL: test_create_mask
+func.func @test_create_mask() -> vector<1x16xi1> {
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
+  // BW-128: %[[C20:.*]] = arith.constant 20 : index
+  // DEFAULT: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // BW-128: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
+  // NOTE(review): create_mask %c0, %c20 is all-false (outer bound 0), yet the linearized mask of %c20 expected above is all-true.
+  // BW-0: %[[C0:.*]] = arith.constant 0 : index
+  // BW-0: %[[C20:.*]] = arith.constant 20 : index
+  // BW-0: %[[MASK:.*]] = vector.create_mask %[[C0]], %[[C20]] : vector<1x16xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+// ALL-LABEL: test_loop
+func.func @test_loop() -> vector<2x4xf16> {
+  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C4:.*]] = arith.constant 4 : index
+  // BW-128: %[[C4:.*]] = arith.constant 4 : index
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
+  // BW-128: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
+  // DEFAULT: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
+  // BW-128: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
+  // DEFAULT: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
+  // BW-128: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
+  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
+  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
+  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
+  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
+  // DEFAULT: scf.yield %[[CAST1]] : vector<8xf16>
+  // BW-128: scf.yield %[[CAST1]] : vector<8xf16>
+  // DEFAULT: }
+  // BW-128: }
+  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
+  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
+  // DEFAULT: return %[[CAST2]] : vector<2x4xf16>
+  // BW-128: return %[[CAST2]] : vector<2x4xf16>
+  // With the BW-0 prefix the scf.for keeps its 2-D iter_args and result type.
+  // BW-0: %[[C0:.*]] = arith.constant 0 : index
+  // BW-0: %[[C1:.*]] = arith.constant 1 : index
+  // BW-0: %[[C4:.*]] = arith.constant 4 : index
+  // BW-0: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf16>
+  // BW-0: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<2x4xf16>) {
+  // BW-0: %[[ADD:.*]] = arith.addf %[[CST]], %[[ARG1]] : vector<2x4xf16>
+  // BW-0: scf.yield %[[ADD]] : vector<2x4xf16>
+  // BW-0: }
+  // BW-0: return %[[FOR]] : vector<2x4xf16>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %1 = arith.constant dense<1.0> : vector<2x4xf16>
+  %r = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %1) -> (vector<2x4xf16>) {
+    %2 = arith.addf %1, %arg1 : vector<2x4xf16>
+    scf.yield %2 : vector<2x4xf16>
+  }
+  return %r : vector<2x4xf16>
+}
+
+// -----
+// Unit-stride 2-D insert_strided_slice into a constant destination. With the
+// default target bit width the source is flattened to vector<32xf16> and each
+// of its four 8-element rows is shuffled out and inserted into the flattened
+// vector<128xf16> destination at offsets 0, 16, 32 and 48 (row index times the
+// destination row length of 16). With a 128-bit target or a 0 target the op
+// is expected to be left unchanged.
+// ALL-LABEL: test_vector_insert_2d_idx
+// ALL-SAME: (%[[ARG:.*]]: vector<4x8xf16>) -> vector<8x16xf16>
+func.func @test_vector_insert_2d_idx(%arg0: vector<4x8xf16>) -> vector<8x16xf16> {
+  // DEFAULT: %[[V0:.*]] = vector.shape_cast %[[ARG]] : vector<4x8xf16> to vector<32xf16>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+  // DEFAULT: %[[V1:.*]] = vector.shuffle %[[V0]], %[[V0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V2:.*]] = vector.insert_strided_slice %[[V1]], %[[CST]] {offsets = [0], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V3:.*]] = vector.shuffle %[[V0]], %[[V0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V4:.*]] = vector.insert_strided_slice %[[V3]], %[[V2]] {offsets = [16], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V5:.*]] = vector.shuffle %[[V0]], %[[V0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V6:.*]] = vector.insert_strided_slice %[[V5]], %[[V4]] {offsets = [32], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V7:.*]] = vector.shuffle %[[V0]], %[[V0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+  // DEFAULT: %[[V8:.*]] = vector.insert_strided_slice %[[V7]], %[[V6]] {offsets = [48], strides = [1]} : vector<8xf16> into vector<128xf16>
+  // DEFAULT: %[[V9:.*]] = vector.shape_cast %[[V8]] : vector<128xf16> to vector<8x16xf16>
+  // DEFAULT: return %[[V9]] : vector<8x16xf16>
+
+  // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf16>
+  // BW-128: %[[V0:.*]] = vector.insert_strided_slice %[[ARG]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  // BW-128: return %[[V0]] : vector<8x16xf16>
+
+  // BW-0: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf16>
+  // BW-0: %[[V0:.*]] = vector.insert_strided_slice %[[ARG]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  // BW-0: return %[[V0]] : vector<8x16xf16>
+  %cst = arith.constant dense <0.0> : vector<8x16xf16>
+  %0 = vector.insert_strided_slice %arg0, %cst {offsets = [0, 0], strides = [1, 1]} : vector<4x8xf16> into vector<8x16xf16>
+  return %0 : vector<8x16xf16>
+}
+
+// -----
+// scf.if yielding a vector<16x1xi32>. With the default and 128-bit targets the
+// if is rewritten to yield vector<16xi32>, with shape_casts around the yielded
+// values and the result, and the addi/subi of the splat constant are expected
+// to fold to dense<6> and dense<0>. With a 0 target nothing is rewritten, so
+// the original if (including its use by the return) must survive verbatim.
+// ALL-LABEL: test_if_single_vector
+func.func @test_if_single_vector() -> vector<16x1xi32> {
+  // DEFAULT: %[[COND:.*]] = arith.constant false
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
+  // DEFAULT: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
+  // DEFAULT:   %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
+  // DEFAULT:   %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
+  // DEFAULT:   %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
+  // DEFAULT:   scf.yield %[[V3]] : vector<16xi32>
+  // DEFAULT: } else {
+  // DEFAULT:   %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
+  // DEFAULT:   %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
+  // DEFAULT:   %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
+  // DEFAULT:   scf.yield %[[V5]] : vector<16xi32>
+  // DEFAULT: }
+  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
+  // DEFAULT: return %[[V1]] : vector<16x1xi32>
+
+  // BW-128: %[[COND:.*]] = arith.constant false
+  // BW-128: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
+  // BW-128: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
+  // BW-128:   %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
+  // BW-128:   %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
+  // BW-128:   %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
+  // BW-128:   scf.yield %[[V3]] : vector<16xi32>
+  // BW-128: } else {
+  // BW-128:   %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
+  // BW-128:   %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
+  // BW-128:   %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
+  // BW-128:   scf.yield %[[V5]] : vector<16xi32>
+  // BW-128: }
+  // BW-128: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
+  // BW-128: return %[[V1]] : vector<16x1xi32>
+
+  // BW-0: %[[COND:.*]] = arith.constant false
+  // BW-0: %[[V:.*]] = arith.constant dense<3> : vector<16x1xi32>
+  // BW-0: %[[R:.*]] = scf.if %[[COND]] -> (vector<16x1xi32>) {
+  // BW-0:   %[[ADD:.*]] = arith.addi %[[V]], %[[V]] : vector<16x1xi32>
+  // BW-0:   scf.yield %[[ADD]] : vector<16x1xi32>
+  // BW-0: } else {
+  // BW-0:   %[[SUB:.*]] = arith.subi %[[V]], %[[V]] : vector<16x1xi32>
+  // BW-0:   scf.yield %[[SUB]] : vector<16x1xi32>
+  // BW-0: }
+  // BW-0: return %[[R]] : vector<16x1xi32>
+  %cond = arith.constant 0 : i1
+  %v = arith.constant dense<3> : vector<16x1xi32>
+  %r = scf.if %cond -> (vector<16x1xi32>) {
+    %add = arith.addi %v, %v : vector<16x1xi32>
+    scf.yield %add : vector<16x1xi32>
+  } else {
+    %sub = arith.subi %v, %v : vector<16x1xi32>
+    scf.yield %sub : vector<16x1xi32>
+  }
+  return %r : vector<16x1xi32>
+}
+
+// -----
+// scf.while carrying a vector<2x4xf32>. With the default target the
+// loop-carried value, the before/after region arguments and the yielded
+// values become vector<8xf32>, with shape_casts inserted where the original
+// 2-D type is still needed. With a 128-bit or a 0 target the loop is expected
+// to be unchanged.
+// ALL-LABEL: test_while
+func.func @test_while() -> vector<2x4xf32> {
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
+  // DEFAULT: %[[V0:.*]] = scf.while (%[[ARG0:.*]] = %[[CST]]) : (vector<8xf32>) -> vector<8xf32> {
+  // DEFAULT:   %[[V2:.*]] = vector.shape_cast %[[ARG0]] : vector<8xf32> to vector<2x4xf32>
+  // DEFAULT:   %[[C0:.*]] = arith.constant 0 : i32
+  // DEFAULT:   %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
+  // DEFAULT:   %[[V4:.*]] = vector.shape_cast %[[V2]] : vector<2x4xf32> to vector<8xf32>
+  // DEFAULT:   scf.condition(%[[COND]]) %[[V4]] : vector<8xf32>
+  // DEFAULT: } do {
+  // DEFAULT: ^bb0(%[[ARG1:.*]]: vector<8xf32>):
+  // DEFAULT:   %[[V2:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<8xf32>
+  // DEFAULT:   %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<8xf32> to vector<2x4xf32>
+  // DEFAULT:   %[[V4:.*]] = vector.shape_cast %[[V3]] : vector<2x4xf32> to vector<8xf32>
+  // DEFAULT:   scf.yield %[[V4]] : vector<8xf32>
+  // DEFAULT: }
+  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<8xf32> to vector<2x4xf32>
+  // DEFAULT: return %[[V1]] : vector<2x4xf32>
+
+  // BW-128: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
+  // BW-128: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
+  // BW-128:   %[[C0:.*]] = arith.constant 0 : i32
+  // BW-128:   %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
+  // BW-128:   scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
+  // BW-128: } do {
+  // BW-128: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
+  // BW-128:   %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
+  // BW-128:   scf.yield %[[ADD]] : vector<2x4xf32>
+  // BW-128: }
+  // BW-128: return %[[RESULT]] : vector<2x4xf32>
+
+  // BW-0: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
+  // BW-0: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
+  // BW-0:   %[[C0:.*]] = arith.constant 0 : i32
+  // BW-0:   %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
+  // BW-0:   scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
+  // BW-0: } do {
+  // BW-0: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
+  // BW-0:   %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
+  // BW-0:   scf.yield %[[ADD]] : vector<2x4xf32>
+  // BW-0: }
+  // BW-0: return %[[RESULT]] : vector<2x4xf32>
+  %v = arith.constant dense<1.0> : vector<2x4xf32>
+  %result = scf.while (%arg0 = %v) : (vector<2x4xf32>) -> vector<2x4xf32> {
+    %c0 = arith.constant 0 : i32
+    %cond = arith.cmpi slt, %c0, %c0 : i32
+    scf.condition(%cond) %arg0 : vector<2x4xf32>
+  } do {
+  ^bb0(%arg1: vector<2x4xf32>):
+    %add = arith.addf %arg1, %arg1 : vector<2x4xf32>
+    scf.yield %add : vector<2x4xf32>
+  }
+  return %result : vector<2x4xf32>
+}
+
+// -----
+// vector.splat producing a 2-D result. With the default and 128-bit targets it
+// becomes a splat to vector<8xi32> followed by a shape_cast back to
+// vector<4x2xi32>; with a 0 target the splat is left unchanged.
+// ALL-LABEL: test_vector_splat
+// ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @test_vector_splat(%arg0: i32) -> vector<4x2xi32> {
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // DEFAULT: return %[[CAST]] : vector<4x2xi32>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
+  // BW-128: return %[[CAST]] : vector<4x2xi32>
+
+  // BW-0: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4x2xi32>
+  // BW-0: return %[[SPLAT]] : vector<4x2xi32>
+  %0 = vector.splat %arg0 : vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..40b0a2321a2b2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -851,7 +851,8 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
+    registry.insert<vector::VectorDialect, memref::MemRefDialect,
+                    arith::ArithDialect, scf::SCFDialect>();
   }
 
   Option<unsigned> targetVectorBitwidth{

>From a76f02d7beb790cb30df34d42f5c0f0047be7a10 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 17 Apr 2025 20:39:51 +0000
Subject: [PATCH 2/6] Run Clang-format

---
 .../Vector/Transforms/VectorLinearize.cpp     | 214 +++++++++---------
 1 file changed, 108 insertions(+), 106 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 6de5d0c5a101e..d97eed7aea008 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -282,22 +282,24 @@ struct LinearizeVectorExtractStridedSlice final
 /// source vector using ExtractStridedSliceOp and inserting them into the
 /// destination vector using InsertStridedSliceOp.
 /// Following,
-///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0], strides=[1, 1]}
+///     : vector<2x4xf32> into vector<4x4xf32>
 /// is converted to :
-///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
-///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
+///          : vector<8xf32> to vector<4xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
+///          : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]}
+///          : vector<8xf32> to vector<4xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]}
+///          : vector<4xf32> into vector<16xf32>
 struct LinearizeVectorInsertStridedSlice final
     : public OpConversionPattern<vector::InsertStridedSliceOp> {
-  using OpConversionPattern<
-      vector::InsertStridedSliceOp>::OpConversionPattern;
-      LinearizeVectorInsertStridedSlice(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
@@ -345,8 +347,9 @@ struct LinearizeVectorInsertStridedSlice final
     rewriter.replaceOp(op, dstValue);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -619,22 +622,22 @@ struct LinearizeVectorBitCast final
 /// is converted to :
 ///   %result = arith.constant dense<0.0> : vector<4x4xf32>
 ///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
-///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
-///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
-///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0]
+///                 : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1]
+///                 : vector<4xf32> into vector<4x4xf32>
 ///   ...
 /// This unrolls the 2D vector load into multiple 1D vector loads and inserts
 /// them into the result vector. The pattern currently supports only 2D vectors
-struct LinearizeVectorLoad final
-    : public OpConversionPattern<vector::LoadOp> {
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
 
   LinearizeVectorLoad(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
@@ -648,35 +651,33 @@ struct LinearizeVectorLoad final
     }
     auto unrollCount = shape[0];
     auto vecSize = shape[1];
-    auto newVecType =
-        VectorType::get({vecSize}, vecType.getElementType());
+    auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
 
     llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
     Value xBaseIndex = indices[0];
 
     // Construct the 2D vector.
-    Value resultVec = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(vecType));
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
     // Emit unrolled loads for each 1D vector slice.
     for (auto i = 0; i < unrollCount; i++) {
       Value xIndex = xBaseIndex;
       if (i) {
         auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
-        xIndex =
-            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
       }
       indices[0] = xIndex;
-      auto vec = rewriter.create<vector::LoadOp>(
-          loc, newVecType, adaptor.getBase(), indices);
-      resultVec =
-          rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+      auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+                                                 adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
     }
 
     rewriter.replaceOp(loadOp, resultVec);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
@@ -689,19 +690,19 @@ struct LinearizeVectorLoad final
 ///   %slice_1 = vector.extract %source[1] : vector<4xf32>
 ///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
 ///   ...
-/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
-/// slices from the source vector and storing them into the destination.
-/// The pattern currently supports only 2D vectors
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors
 struct LinearizeVectorStore final
     : public OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
 
   LinearizeVectorStore(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
@@ -718,26 +719,26 @@ struct LinearizeVectorStore final
     llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
     Value xBaseIndex = indices[0];
 
-    auto vec = rewriter.create<vector::ShapeCastOp>(
-        loc, vecType, adaptor.getValueToStore());
+    auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
+                                                    adaptor.getValueToStore());
 
     for (auto i = 0; i < unrollCount; i++) {
       auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
       Value xIndex = xBaseIndex;
       if (i) {
         auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
-        xIndex =
-            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+        xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
       }
       indices[0] = xIndex;
       rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
-                                             indices);
+                                       indices);
     }
     rewriter.eraseOp(storeOp);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the SplatOp to work on a linearized vector.
@@ -754,11 +755,11 @@ struct LinearizeVectorSplat final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorSplat(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
@@ -766,12 +767,13 @@ struct LinearizeVectorSplat final
     auto dstTy = getTypeConverter()->convertType(splatOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
-    rewriter.replaceOpWithNewOp<vector::SplatOp>(
-        splatOp, adaptor.getInput(), dstTy);
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts the CreateMaskOp to work on a
@@ -789,11 +791,11 @@ struct LinearizeVectorCreateMask final
   using OpConversionPattern::OpConversionPattern;
 
   LinearizeVectorCreateMask(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
   matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
@@ -816,8 +818,9 @@ struct LinearizeVectorCreateMask final
         createMaskOp, dstTy, adaptor.getOperands().back());
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 /// This pattern converts operations implementing the RegionBranchOpInterface
@@ -835,15 +838,14 @@ struct LinearizeRegionBranchOp final
       RegionBranchOpInterface>::OpInterfaceConversionPattern;
 
   LinearizeRegionBranchOp(
-    const TypeConverter &typeConverter, MLIRContext *context,
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-    PatternBenefit benefit = 1)
-    : OpInterfaceConversionPattern(typeConverter, context, benefit),
-      targetVectorBitWidth(targetVectBitWidth) {}
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
-  matchAndRewrite(RegionBranchOpInterface op,
-                  ArrayRef<Value> operands,
+  matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto converter = getTypeConverter();
@@ -907,8 +909,9 @@ struct LinearizeRegionBranchOp final
     rewriter.finalizeOpModification(op);
     return success();
   }
-  private:
-    unsigned targetVectorBitWidth;
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 } // namespace
@@ -937,26 +940,25 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.addLegalOp<mlir::vector::ShapeCastOp>();
-  target.markUnknownOpDynamicallyLegal(
-      [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp, vector::LoadOp,
-                 vector::StoreOp, vector::CreateMaskOp,
-                 RegionBranchOpInterface, vector::SplatOp>(op) ||
-             op->hasTrait<OpTrait::ConstantLike>() ||
-             op->hasTrait<OpTrait::Vectorizable>())) {
-          return (isLessThanTargetBitWidth(op, targetBitWidth)
-                      ? typeConverter.isLegal(op)
-                      : true);
-        }
-        return std::nullopt;
-      });
+  target.markUnknownOpDynamicallyLegal([=](Operation *op)
+                                           -> std::optional<bool> {
+    if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+             vector::CreateMaskOp, RegionBranchOpInterface, vector::SplatOp>(
+             op) ||
+         op->hasTrait<OpTrait::ConstantLike>() ||
+         op->hasTrait<OpTrait::Vectorizable>())) {
+      return (isLessThanTargetBitWidth(op, targetBitWidth)
+                  ? typeConverter.isLegal(op)
+                  : true);
+    }
+    return std::nullopt;
+  });
 
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorLoad,
-               LinearizeVectorStore, LinearizeVectorSplat,
-               LinearizeVectorCreateMask, LinearizeRegionBranchOp
-               >(typeConverter, patterns.getContext(),
-                                       targetBitWidth);
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorLoad, LinearizeVectorStore, LinearizeVectorSplat,
+           LinearizeVectorCreateMask, LinearizeRegionBranchOp>(
+          typeConverter, patterns.getContext(), targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
@@ -972,16 +974,16 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       });
 
   target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
-    [=](vector::InsertStridedSliceOp op) -> bool {
-      if(isLessThanTargetBitWidth(op, targetBitWidth)) {
-        auto srcTy = op.getSourceVectorType();
-        auto dstTy = op.getDestVectorType();
-        if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
-            srcTy.hasStaticShape() && dstTy.hasStaticShape())
-          return false;
-      }
-      return true;
-    });
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        if (isLessThanTargetBitWidth(op, targetBitWidth)) {
+          auto srcTy = op.getSourceVectorType();
+          auto dstTy = op.getDestVectorType();
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+              srcTy.hasStaticShape() && dstTy.hasStaticShape())
+            return false;
+        }
+        return true;
+      });
 
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
                LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,

>From 63088b4671d1bf70c42803346ca7be60b4fb3a65 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 17 Apr 2025 21:28:49 +0000
Subject: [PATCH 3/6] Remove RegionBranchOp pattern and address comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 172 +++++-------------
 mlir/test/Dialect/Vector/linearize.mlir       | 160 ----------------
 .../Dialect/Vector/TestVectorTransforms.cpp   |   2 +-
 3 files changed, 42 insertions(+), 292 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index d97eed7aea008..06ba40da3b0b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,7 +11,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -29,9 +28,9 @@ using namespace mlir;
 
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   // For BW-0, all operations are legal
-  if (targetBitWidth == 0) {
+  if (targetBitWidth == 0)
     return false;
-  }
+
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -302,32 +301,37 @@ struct LinearizeVectorInsertStridedSlice final
         targetVectorBitWidth(targetVectBitWidth) {}
 
   LogicalResult
-  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto srcTy = op.getSourceVectorType();
-    auto dstTy = op.getDestVectorType();
+    auto loc = insertOp.getLoc();
+    auto srcTy = insertOp.getSourceVectorType();
+    auto dstTy = insertOp.getDestVectorType();
 
-    if (op.hasNonUnitStrides()) {
+    if (insertOp.hasNonUnitStrides())
       return rewriter.notifyMatchFailure(
-          op, "InsertStridedSliceOp linearization only supports unit strides.");
-    }
+          insertOp,
+          "InsertStridedSliceOp linearization only supports unit strides.");
 
-    if (srcTy.getRank() != 2) {
+    if (srcTy.getRank() != 2)
       return rewriter.notifyMatchFailure(
-          op, "InsertStridedSliceOp linearization only supports 2D source.");
-    }
+          insertOp,
+          "InsertStridedSliceOp linearization only supports 2D source.");
 
-    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
-          op, "InsertStridedSliceOp linerization only supports static shapes.");
-    }
+          insertOp,
+          "InsertStridedSliceOp linerization only supports static shapes.");
+
+    if (srcTy.isScalable() || dstTy.isScalable())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalable vectors are not supported.");
 
     auto dstShape = dstTy.getShape();
     auto dstStrides = dstShape.drop_front().vec();
     dstStrides.push_back(1);
     int64_t linearizedOffset = 0;
-    for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+    for (auto [off, stride] :
+         llvm::zip_equal(insertOp.getOffsets(), dstStrides)) {
       linearizedOffset += getConstantIntValue(off).value() * stride;
     }
 
@@ -344,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final
           loc, value, dstValue, dstOffset, 1);
     }
 
-    rewriter.replaceOp(op, dstValue);
+    rewriter.replaceOp(insertOp, dstValue);
     return success();
   }
 
@@ -535,12 +539,11 @@ struct LinearizeVectorInsert final
     auto srcTy = insertOp.getValueToStoreType();
     auto srcAsVec = dyn_cast<VectorType>(srcTy);
     uint64_t srcSize = 0;
-    if (srcAsVec) {
+    if (srcAsVec)
       srcSize = srcAsVec.getNumElements();
-    } else {
+    else
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalars are not supported.");
-    }
 
     auto dstShape = insertOp.getDestVectorType().getShape();
     const auto dstSize = insertOp.getDestVectorType().getNumElements();
@@ -646,9 +649,9 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
     auto vecType = loadOp.getVectorType();
     auto shape = vecType.getShape();
 
-    if (shape.size() != 2) {
+    if (shape.size() != 2)
       return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
-    }
+
     auto unrollCount = shape[0];
     auto vecSize = shape[1];
     auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
@@ -711,9 +714,8 @@ struct LinearizeVectorStore final
     auto vecType = storeOp.getVectorType();
     auto shape = vecType.getShape();
 
-    if (shape.size() != 2) {
+    if (shape.size() != 2)
       return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
-    }
 
     auto unrollCount = shape[0];
     llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
@@ -823,97 +825,6 @@ struct LinearizeVectorCreateMask final
   unsigned targetVectorBitWidth;
 };
 
-/// This pattern converts operations implementing the RegionBranchOpInterface
-/// to ensure compatibility with linearized vector types. It updates the
-/// operands, result types, and region types (block arguments and yields) to
-/// match the converted types. Additionally, it processes yields within each
-/// region to ensure that the types of yielded values are compatible with the
-/// target vector bit width. If the result types of the operation are updated,
-/// shape cast operations are inserted to maintain compatibility with the
-/// original types. This pattern ensures that operations with regions are
-/// properly linearized and remain valid after type conversion.
-struct LinearizeRegionBranchOp final
-    : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
-  using OpInterfaceConversionPattern<
-      RegionBranchOpInterface>::OpInterfaceConversionPattern;
-
-  LinearizeRegionBranchOp(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
-      : OpInterfaceConversionPattern(typeConverter, context, benefit),
-        targetVectorBitWidth(targetVectBitWidth) {}
-
-  LogicalResult
-  matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto converter = getTypeConverter();
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.startOpModification(op);
-
-    llvm::SmallVector<Type> convertedTypes;
-    for (Type ty : op->getResultTypes()) {
-      convertedTypes.push_back(converter->convertType(ty));
-    }
-
-    if (convertedTypes == op->getResultTypes() &&
-        op->getOperands() == operands) {
-      return failure();
-    }
-
-    op->setOperands(operands);
-
-    // Convert region types (block arguments and yields)
-    for (Region &region : op->getRegions()) {
-      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
-        return failure();
-      }
-
-      // Process yields within each region
-      for (Block &block : region) {
-        if (auto *terminator = block.getTerminator()) {
-          for (OpOperand &yieldOperand : terminator->getOpOperands()) {
-            Value value = yieldOperand.get();
-            Type type = value.getType();
-            if (!converter->isLegal(type)) {
-              Type newTy = converter->convertType(type);
-              rewriter.setInsertionPoint(terminator);
-              Value newValue =
-                  rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
-              yieldOperand.set(newValue);
-            }
-          }
-        }
-      }
-    }
-
-    // Update result types
-    rewriter.setInsertionPointAfter(op);
-    llvm::SmallVector<Value> newResults;
-    for (Value result : op->getResults()) {
-      Type oldTy = result.getType();
-      if (!converter->isLegal(oldTy)) {
-        Type newTy = converter->convertType(oldTy);
-        result.setType(newTy);
-        Operation *castOp =
-            rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
-        result.replaceAllUsesExcept(castOp->getResult(0), castOp);
-        newResults.push_back(castOp->getResult(0));
-      } else {
-        newResults.push_back(result);
-      }
-    }
-
-    rewriter.finalizeOpModification(op);
-    return success();
-  }
-
-private:
-  unsigned targetVectorBitWidth;
-};
-
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -940,25 +851,24 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.addLegalOp<mlir::vector::ShapeCastOp>();
-  target.markUnknownOpDynamicallyLegal([=](Operation *op)
-                                           -> std::optional<bool> {
-    if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
-             vector::CreateMaskOp, RegionBranchOpInterface, vector::SplatOp>(
-             op) ||
-         op->hasTrait<OpTrait::ConstantLike>() ||
-         op->hasTrait<OpTrait::Vectorizable>())) {
-      return (isLessThanTargetBitWidth(op, targetBitWidth)
-                  ? typeConverter.isLegal(op)
-                  : true);
-    }
-    return std::nullopt;
-  });
+  target.markUnknownOpDynamicallyLegal(
+      [=](Operation *op) -> std::optional<bool> {
+        if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+                 vector::CreateMaskOp, vector::SplatOp>(op) ||
+             op->hasTrait<OpTrait::ConstantLike>() ||
+             op->hasTrait<OpTrait::Vectorizable>())) {
+          return (isLessThanTargetBitWidth(op, targetBitWidth)
+                      ? typeConverter.isLegal(op)
+                      : true);
+        }
+        return std::nullopt;
+      });
 
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorLoad, LinearizeVectorStore, LinearizeVectorSplat,
-           LinearizeVectorCreateMask, LinearizeRegionBranchOp>(
-          typeConverter, patterns.getContext(), targetBitWidth);
+           LinearizeVectorCreateMask>(typeConverter, patterns.getContext(),
+                                      targetBitWidth);
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e47e7c4a84d68..2ea4751393ebf 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -529,54 +529,6 @@ func.func @test_create_mask() -> vector<1x16xi1> {
   return %0 : vector<1x16xi1>
 }
 
-// -----
-// ALL-LABEL: test_loop
-func.func @test_loop() -> vector<2x4xf16> {
-  // DEFAULT: %[[C0:.*]] = arith.constant 0 : index
-  // BW-128: %[[C0:.*]] = arith.constant 0 : index
-  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
-  // BW-128: %[[C1:.*]] = arith.constant 1 : index
-  // DEFAULT: %[[C4:.*]] = arith.constant 4 : index
-  // BW-128: %[[C4:.*]] = arith.constant 4 : index
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
-  // BW-128: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf16>
-  // DEFAULT: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
-  // BW-128: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<8xf16>) {
-  // DEFAULT: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
-  // BW-128: %[[ADD:.*]] = arith.addf %[[ARG1]], %[[CST]] : vector<8xf16>
-  // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
-  // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ADD]] : vector<8xf16> to vector<2x4xf16>
-  // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
-  // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<2x4xf16> to vector<8xf16>
-  // DEFAULT: scf.yield %[[CAST1]] : vector<8xf16>
-  // BW-128: scf.yield %[[CAST1]] : vector<8xf16>
-  // DEFAULT: }
-  // BW-128: }
-  // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
-  // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[FOR]] : vector<8xf16> to vector<2x4xf16>
-  // DEFAULT: return %[[CAST2]] : vector<2x4xf16>
-  // BW-128: return %[[CAST2]] : vector<2x4xf16>
-
-  // BW-0: %[[C0:.*]] = arith.constant 0 : index
-  // BW-0: %[[C1:.*]] = arith.constant 1 : index
-  // BW-0: %[[C4:.*]] = arith.constant 4 : index
-  // BW-0: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf16>
-  // BW-0: %[[FOR:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ARG1:.*]] = %[[CST]]) -> (vector<2x4xf16>) {
-  // BW-0: %[[ADD:.*]] = arith.addf %[[CST]], %[[ARG1]] : vector<2x4xf16>
-  // BW-0: scf.yield %[[ADD]] : vector<2x4xf16>
-  // BW-0: }
-  // BW-0: return %[[FOR]] : vector<2x4xf16>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
-  %1 = arith.constant dense<1.0> : vector<2x4xf16>
-  %r = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %1) -> (vector<2x4xf16>) {
-    %2 = arith.addf %1, %arg1 : vector<2x4xf16>
-    scf.yield %2 : vector<2x4xf16>
-  }
-  return %r : vector<2x4xf16>
-}
-
 // -----
 // ALL-LABEL: test_vector_insert_2d_idx
 // ALL-SAME: (%[[ARG:.*]]: vector<4x8xf16>) -> vector<8x16xf16>
@@ -606,118 +558,6 @@ func.func @test_vector_insert_2d_idx(%arg0: vector<4x8xf16>) -> vector<8x16xf16>
   return %0 : vector<8x16xf16>
 }
 
-// -----
-// ALL-LABEL: test_if_single_vector
-func.func @test_if_single_vector() -> vector<16x1xi32> {
-  // DEFAULT: %[[COND:.*]] = arith.constant false
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
-  // DEFAULT: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
-  // DEFAULT:   %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
-  // DEFAULT:   %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
-  // DEFAULT:   %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
-  // DEFAULT:   scf.yield %[[V3]] : vector<16xi32>
-  // DEFAULT: } else {
-  // DEFAULT:   %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
-  // DEFAULT:   %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
-  // DEFAULT:   %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
-  // DEFAULT:   scf.yield %[[V5]] : vector<16xi32>
-  // DEFAULT: }
-  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
-  // DEFAULT: return %[[V1]] : vector<16x1xi32>
-
-  // BW-128: %[[COND:.*]] = arith.constant false
-  // BW-128: %[[CST:.*]] = arith.constant dense<3> : vector<16xi32>
-  // BW-128: %[[V0:.*]] = scf.if %[[COND]] -> (vector<16xi32>) {
-  // BW-128:   %[[CST_THEN:.*]] = arith.constant dense<6> : vector<16xi32>
-  // BW-128:   %[[V2:.*]] = vector.shape_cast %[[CST_THEN]] : vector<16xi32> to vector<16x1xi32>
-  // BW-128:   %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<16x1xi32> to vector<16xi32>
-  // BW-128:   scf.yield %[[V3]] : vector<16xi32>
-  // BW-128: } else {
-  // BW-128:   %[[CST_ELSE:.*]] = arith.constant dense<0> : vector<16xi32>
-  // BW-128:   %[[V4:.*]] = vector.shape_cast %[[CST_ELSE]] : vector<16xi32> to vector<16x1xi32>
-  // BW-128:   %[[V5:.*]] = vector.shape_cast %[[V4]] : vector<16x1xi32> to vector<16xi32>
-  // BW-128:   scf.yield %[[V5]] : vector<16xi32>
-  // BW-128: }
-  // BW-128: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<16xi32> to vector<16x1xi32>
-  // BW-128: return %[[V1]] : vector<16x1xi32>
-
-  // BW-0: %[[COND:.*]] = arith.constant false
-  // BW-0: %[[V:.*]] = arith.constant dense<3> : vector<16x1xi32>
-  // BW-0: %[[R:.*]] = scf.if %[[COND]] -> (vector<16x1xi32>) {
-  // BW-0:   %[[ADD:.*]] = arith.addi %[[V]], %[[V]] : vector<16x1xi32>
-  // BW-0:   scf.yield %[[ADD]] : vector<16x1xi32>
-  // BW-0: } else {
-  // BW-0:   %[[SUB:.*]] = arith.subi %[[V]], %[[V]] : vector<16x1xi32>
-  // BW-0:   scf.yield %[[SUB]] : vector<16x1xi32>
-  // BW-0: }
-  %cond = arith.constant 0 : i1
-  %v = arith.constant dense<3> : vector<16x1xi32>
-  %r = scf.if %cond -> (vector<16x1xi32>) {
-    %add = arith.addi %v, %v : vector<16x1xi32>
-    scf.yield %add : vector<16x1xi32>
-  } else {
-    %sub = arith.subi %v, %v : vector<16x1xi32>
-    scf.yield %sub : vector<16x1xi32>
-  }
-  return %r : vector<16x1xi32>
-}
-
-// -----
-// ALL-LABEL: test_while
-func.func @test_while() -> vector<2x4xf32> {
-  // DEFAULT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32>
-  // DEFAULT: %[[V0:.*]] = scf.while (%[[ARG0:.*]] = %[[CST]]) : (vector<8xf32>) -> vector<8xf32> {
-  // DEFAULT:   %[[V2:.*]] = vector.shape_cast %[[ARG0]] : vector<8xf32> to vector<2x4xf32>
-  // DEFAULT:   %[[C0:.*]] = arith.constant 0 : i32
-  // DEFAULT:   %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
-  // DEFAULT:   %[[V4:.*]] = vector.shape_cast %[[V2]] : vector<2x4xf32> to vector<8xf32>
-  // DEFAULT:   scf.condition(%[[COND]]) %[[V4]] : vector<8xf32>
-  // DEFAULT: } do {
-  // DEFAULT: ^bb0(%[[ARG1:.*]]: vector<8xf32>):
-  // DEFAULT:   %[[V2:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<8xf32>
-  // DEFAULT:   %[[V3:.*]] = vector.shape_cast %[[V2]] : vector<8xf32> to vector<2x4xf32>
-  // DEFAULT:   %[[V4:.*]] = vector.shape_cast %[[V3]] : vector<2x4xf32> to vector<8xf32>
-  // DEFAULT:   scf.yield %[[V4]] : vector<8xf32>
-  // DEFAULT: }
-  // DEFAULT: %[[V1:.*]] = vector.shape_cast %[[V0]] : vector<8xf32> to vector<2x4xf32>
-  // DEFAULT: return %[[V1]] : vector<2x4xf32>
-
-  // BW-128: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
-  // BW-128: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
-  // BW-128:   %[[C0:.*]] = arith.constant 0 : i32
-  // BW-128:   %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
-  // BW-128:   scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
-  // BW-128: } do {
-  // BW-128: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
-  // BW-128:   %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
-  // BW-128:   scf.yield %[[ADD]] : vector<2x4xf32>
-  // BW-128: }
-  // BW-128: return %[[RESULT]] : vector<2x4xf32>
-
-  // BW-0: %[[V:.*]] = arith.constant dense<1.000000e+00> : vector<2x4xf32>
-  // BW-0: %[[RESULT:.*]] = scf.while (%[[ARG0:.*]] = %[[V]]) : (vector<2x4xf32>) -> vector<2x4xf32> {
-  // BW-0:   %[[C0:.*]] = arith.constant 0 : i32
-  // BW-0:   %[[COND:.*]] = arith.cmpi slt, %[[C0]], %[[C0]] : i32
-  // BW-0:   scf.condition(%[[COND]]) %[[ARG0]] : vector<2x4xf32>
-  // BW-0: } do {
-  // BW-0: ^bb0(%[[ARG1:.*]]: vector<2x4xf32>):
-  // BW-0:   %[[ADD:.*]] = arith.addf %[[ARG1]], %[[ARG1]] : vector<2x4xf32>
-  // BW-0:   scf.yield %[[ADD]] : vector<2x4xf32>
-  // BW-0: }
-  // BW-0: return %[[RESULT]] : vector<2x4xf32>
-  %v = arith.constant dense<1.0> : vector<2x4xf32>
-  %result = scf.while (%arg0 = %v) : (vector<2x4xf32>) -> vector<2x4xf32> {
-    %c0 = arith.constant 0 : i32
-    %cond = arith.cmpi slt, %c0, %c0 : i32
-    scf.condition(%cond) %arg0 : vector<2x4xf32>
-  } do {
-  ^bb0(%arg1: vector<2x4xf32>):
-    %add = arith.addf %arg1, %arg1 : vector<2x4xf32>
-    scf.yield %add : vector<2x4xf32>
-  }
-  return %result : vector<2x4xf32>
-}
-
 // -----
 // ALL-LABEL: test_vector_splat
 // ALL-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 40b0a2321a2b2..aea116cffc3a8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -852,7 +852,7 @@ struct TestVectorLinearize final
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<vector::VectorDialect, memref::MemRefDialect,
-                    arith::ArithDialect, scf::SCFDialect>();
+                    arith::ArithDialect>();
   }
 
   Option<unsigned> targetVectorBitwidth{

>From 03789ec19d9c802f99a23a81e093c5571c9f71fd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 18 Apr 2025 17:15:49 +0000
Subject: [PATCH 4/6] Modify create_mask pattern

---
 .../Vector/Transforms/VectorLinearize.cpp     | 23 +++++++++++++++++--
 mlir/test/Dialect/Vector/linearize.mlir       | 20 ++++++++++++----
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 06ba40da3b0b0..7028285c0a91d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -816,8 +816,27 @@ struct LinearizeVectorCreateMask final
     if (!dstTy)
       return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
 
-    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
-        createMaskOp, dstTy, adaptor.getOperands().back());
+    // Compare the first operand with 0. If it's less than or equal to 0,
+    // create a zero mask, else strip the first operand and create a mask
+    // using the second operand.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero =
+        rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+    auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+        createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+        zero);
+    auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+        createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+    // Use a select operation to choose between the masks.
+    auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+        createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+    auto result = rewriter.create<mlir::arith::SelectOp>(
+        createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+    rewriter.replaceOp(createMaskOp, result.getResult());
     return success();
   }
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2ea4751393ebf..f7a767dbdc272 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -515,10 +515,22 @@ func.func @test_create_mask() -> vector<1x16xi1> {
   // BW-128: %[[C0:.*]] = arith.constant 0 : index
   // DEFAULT: %[[C20:.*]] = arith.constant 20 : index
   // BW-128: %[[C20:.*]] = arith.constant 20 : index
-  // DEFAULT: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
-  // BW-128: %[[MASK:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
-  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
-  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[MASK]] : vector<16xi1> to vector<1x16xi1>
+  // DEFAULT: %[[C0_0:.*]] = arith.constant 0 : index
+  // BW-128: %[[C0_0:.*]] = arith.constant 0 : index
+  // DEFAULT: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // BW-128: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // DEFAULT: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // BW-128: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // DEFAULT: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // BW-128: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // DEFAULT: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // BW-128: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // DEFAULT: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // BW-128: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>
+  // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // DEFAULT: return %[[CAST]] : vector<1x16xi1>
+  // BW-128: return %[[CAST]] : vector<1x16xi1>
 
   // BW-0: %[[C0:.*]] = arith.constant 0 : index
   // BW-0: %[[C20:.*]] = arith.constant 20 : index

>From 231371c66b5f5a0ab109003ba85bffdb9d962aae Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 21 Apr 2025 21:50:40 +0000
Subject: [PATCH 5/6] Address comments

---
 .../Vector/Transforms/VectorLinearize.cpp     | 49 +++++++++----------
 1 file changed, 23 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7028285c0a91d..3b3153b787bb9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@
 
 using namespace mlir;
 
+constexpr unsigned defaultTargetVectorBitWidth =
+    std::numeric_limits<unsigned>::max();
+
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   // For BW-0, all operations are legal
   if (targetBitWidth == 0)
@@ -86,7 +89,7 @@ struct LinearizeConstantLike final
 
   LinearizeConstantLike(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -140,7 +143,7 @@ struct LinearizeVectorizable final
 public:
   LinearizeVectorizable(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpTraitConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -179,7 +182,7 @@ struct LinearizeVectorExtractStridedSlice final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtractStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -295,7 +298,7 @@ struct LinearizeVectorInsertStridedSlice final
   using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
   LinearizeVectorInsertStridedSlice(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -317,11 +320,6 @@ struct LinearizeVectorInsertStridedSlice final
           insertOp,
           "InsertStridedSliceOp linearization only supports 2D source.");
 
-    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          insertOp,
-          "InsertStridedSliceOp linerization only supports static shapes.");
-
     if (srcTy.isScalable() || dstTy.isScalable())
       return rewriter.notifyMatchFailure(insertOp,
                                          "scalable vectors are not supported.");
@@ -372,7 +370,7 @@ struct LinearizeVectorShuffle final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorShuffle(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -445,7 +443,7 @@ struct LinearizeVectorExtract final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorExtract(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -513,7 +511,7 @@ struct LinearizeVectorInsert final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorInsert(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -593,7 +591,7 @@ struct LinearizeVectorBitCast final
   using OpConversionPattern::OpConversionPattern;
   LinearizeVectorBitCast(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -618,6 +616,7 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
 /// This pattern converts the LoadOp to a series of LoadOp & InsertOp
 /// that works on a linearized vector.
 /// Following,
@@ -625,20 +624,19 @@ struct LinearizeVectorBitCast final
 /// is converted to :
 ///   %result = arith.constant dense<0.0> : vector<4x4xf32>
 ///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
-///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
-///   vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
-///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
-///   vector<4x4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
 ///   ...
 /// This unrolls the 2D vector load into multiple 1D vector loads and inserts
 /// them into the result vector. The pattern currently supports only 2D vectors
+// clang-format on
 struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
 
-  LinearizeVectorLoad(
-      const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
-      PatternBenefit benefit = 1)
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
+                      PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
 
@@ -702,7 +700,7 @@ struct LinearizeVectorStore final
 
   LinearizeVectorStore(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -758,7 +756,7 @@ struct LinearizeVectorSplat final
 
   LinearizeVectorSplat(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -794,7 +792,7 @@ struct LinearizeVectorCreateMask final
 
   LinearizeVectorCreateMask(
       const TypeConverter &typeConverter, MLIRContext *context,
-      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      unsigned targetVectBitWidth = defaultTargetVectorBitWidth,
       PatternBenefit benefit = 1)
       : OpConversionPattern(typeConverter, context, benefit),
         targetVectorBitWidth(targetVectBitWidth) {}
@@ -907,8 +905,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
         if (isLessThanTargetBitWidth(op, targetBitWidth)) {
           auto srcTy = op.getSourceVectorType();
           auto dstTy = op.getDestVectorType();
-          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
-              srcTy.hasStaticShape() && dstTy.hasStaticShape())
+          if (!op.hasNonUnitStrides() && srcTy.getRank() == 2)
             return false;
         }
         return true;

>From e3788909066b3690104b0570feb0f05bb0140526 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 23 Apr 2025 16:46:47 +0000
Subject: [PATCH 6/6] Fix formatting

---
 .../Dialect/Vector/Transforms/VectorLinearize.cpp  | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 3b3153b787bb9..ede77c6d0fa12 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -280,6 +280,7 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+// clang-format off
 /// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
 /// source vector using ExtractStridedSliceOp and inserting them into the
 /// destination vector using InsertStridedSliceOp.
@@ -288,11 +289,14 @@ struct LinearizeVectorExtractStridedSlice final
 ///   vector<4x4xf32>
 /// is converted to :
 ///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
-///   : vector<4xf32> from vector<8xf32> %1 = vector.insert_strided_slice %0, %d
-///   {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32> %2 =
-///   vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} :
-///   vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1
-///   {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   : vector<4xf32> from vector<8xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]}
+///   : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]}
+///   : vector<4xf32> from vector<8xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]}
+///   : vector<4xf32> into vector<16xf32>
+// clang-format on
 struct LinearizeVectorInsertStridedSlice final
     : public OpConversionPattern<vector::InsertStridedSliceOp> {
   using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;



More information about the Mlir-commits mailing list