[Mlir-commits] [mlir] [mlir][vector] Add more patterns to Vector Linearize transformation (PR #136193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 17 13:26:08 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>

This PR adds linearization patterns for vector.load, vector.store, vector.create_mask, vector.splat, vector.insert_strided_slice, and RegionBranchOps. These patterns are needed because SPIR-V only supports 1D vectors.

---

Patch is 40.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136193.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+404-3) 
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+335) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a009aa03aaf64..6de5d0c5a101e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -27,6 +28,10 @@
 using namespace mlir;
 
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  // For BW-0, all operations are legal
+  if (targetBitWidth == 0) {
+    return false;
+  }
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
     VectorType vecType = dyn_cast<VectorType>(resType);
@@ -273,6 +278,77 @@ struct LinearizeVectorExtractStridedSlice final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern linearizes the InsertStridedSliceOp by extracting rows from the
+/// source vector using ExtractStridedSliceOp and inserting them into the
+/// destination vector using InsertStridedSliceOp.
+/// Following,
+///   vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// is converted to :
+///   %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
+///   %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
+///   %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+struct LinearizeVectorInsertStridedSlice final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<
+      vector::InsertStridedSliceOp>::OpConversionPattern;
+      LinearizeVectorInsertStridedSlice(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto srcTy = op.getSourceVectorType();
+    auto dstTy = op.getDestVectorType();
+
+    if (op.hasNonUnitStrides()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports unit strides.");
+    }
+
+    if (srcTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linearization only supports 2D source.");
+    }
+
+    if (!srcTy.hasStaticShape() || !dstTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "InsertStridedSliceOp linerization only supports static shapes.");
+    }
+
+    auto dstShape = dstTy.getShape();
+    auto dstStrides = dstShape.drop_front().vec();
+    dstStrides.push_back(1);
+    int64_t linearizedOffset = 0;
+    for (auto [off, stride] : llvm::zip_equal(op.getOffsets(), dstStrides)) {
+      linearizedOffset += getConstantIntValue(off).value() * stride;
+    }
+
+    // extracts a row from source, and insert it into the destination
+    auto srcShape = srcTy.getShape();
+    Value dstValue = adaptor.getDest();
+    for (auto i = 0; i < srcShape[0]; i++) {
+      auto srcOffset = i * srcShape[1];
+      auto value = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, adaptor.getValueToStore(), srcOffset, srcShape[1], 1);
+
+      auto dstOffset = linearizedOffset + i * dstShape.back();
+      dstValue = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, value, dstValue, dstOffset, 1);
+    }
+
+    rewriter.replaceOp(op, dstValue);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
 /// This pattern converts the ShuffleOp that works on nD (n > 1)
 /// vectors to a ShuffleOp that works on linearized vectors.
 /// Following,
@@ -369,6 +445,11 @@ struct LinearizeVectorExtract final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Skip if result is not a vector type
+    if (!isa<VectorType>(extractOp.getType()))
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalar extract is not supported.");
+
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     if (!dstTy)
       return rewriter.notifyMatchFailure(extractOp,
@@ -531,12 +612,312 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern converts the LoadOp to a series of LoadOp & InsertOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.load %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%indices] : vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
+/// them into the result vector. The pattern currently supports only 2D vectors
+struct LinearizeVectorLoad final
+    : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+    auto unrollCount = shape[0];
+    auto vecSize = shape[1];
+    auto newVecType =
+        VectorType::get({vecSize}, vecType.getElementType());
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    Value xBaseIndex = indices[0];
+
+    // Construct the 2D vector.
+    Value resultVec = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(vecType));
+    // Emit unrolled loads for each 1D vector slice.
+    for (auto i = 0; i < unrollCount; i++) {
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex =
+            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      auto vec = rewriter.create<vector::LoadOp>(
+          loc, newVecType, adaptor.getBase(), indices);
+      resultVec =
+          rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
+/// that works on a linearized vector.
+/// Following,
+///   vector.store %source, %base[%indices] : vector<4x4xf32>
+/// is converted to :
+///   %slice_0 = vector.extract %source[0] : vector<4xf32>
+///   vector.store %slice_0, %base[%indices] : vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32>
+///   vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
+///   ...
+/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
+/// slices from the source vector and storing them into the destination.
+/// The pattern currently supports only 2D vectors
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2) {
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    }
+
+    auto unrollCount = shape[0];
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    Value xBaseIndex = indices[0];
+
+    auto vec = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    for (auto i = 0; i < unrollCount; i++) {
+      auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
+      Value xIndex = xBaseIndex;
+      if (i) {
+        auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        xIndex =
+            rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+      }
+      indices[0] = xIndex;
+      rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
+                                             indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// It ensures that the operation is compatible with the target vector
+/// bit width and replaces the original operation with a new SplatOp
+/// that operates on the converted type.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(
+        splatOp, adaptor.getInput(), dstTy);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. It ensures that the operation is compatible with the
+/// target vector bit width and replaces the original operation with a new
+/// CreateMaskOp that operates on the converted type. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+///   vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %dims : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        createMaskOp, dstTy, adaptor.getOperands().back());
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
+    : public OpInterfaceConversionPattern<RegionBranchOpInterface> {
+  using OpInterfaceConversionPattern<
+      RegionBranchOpInterface>::OpInterfaceConversionPattern;
+
+  LinearizeRegionBranchOp(
+    const TypeConverter &typeConverter, MLIRContext *context,
+    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1)
+    : OpInterfaceConversionPattern(typeConverter, context, benefit),
+      targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(RegionBranchOpInterface op,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto converter = getTypeConverter();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.startOpModification(op);
+
+    llvm::SmallVector<Type> convertedTypes;
+    for (Type ty : op->getResultTypes()) {
+      convertedTypes.push_back(converter->convertType(ty));
+    }
+
+    if (convertedTypes == op->getResultTypes() &&
+        op->getOperands() == operands) {
+      return failure();
+    }
+
+    op->setOperands(operands);
+
+    // Convert region types (block arguments and yields)
+    for (Region &region : op->getRegions()) {
+      if (failed(rewriter.convertRegionTypes(&region, *converter))) {
+        return failure();
+      }
+
+      // Process yields within each region
+      for (Block &block : region) {
+        if (auto *terminator = block.getTerminator()) {
+          for (OpOperand &yieldOperand : terminator->getOpOperands()) {
+            Value value = yieldOperand.get();
+            Type type = value.getType();
+            if (!converter->isLegal(type)) {
+              Type newTy = converter->convertType(type);
+              rewriter.setInsertionPoint(terminator);
+              Value newValue =
+                  rewriter.create<vector::ShapeCastOp>(loc, newTy, value);
+              yieldOperand.set(newValue);
+            }
+          }
+        }
+      }
+    }
+
+    // Update result types
+    rewriter.setInsertionPointAfter(op);
+    llvm::SmallVector<Value> newResults;
+    for (Value result : op->getResults()) {
+      Type oldTy = result.getType();
+      if (!converter->isLegal(oldTy)) {
+        Type newTy = converter->convertType(oldTy);
+        result.setType(newTy);
+        Operation *castOp =
+            rewriter.create<vector::ShapeCastOp>(loc, oldTy, result);
+        result.replaceAllUsesExcept(castOp->getResult(0), castOp);
+        newResults.push_back(castOp->getResult(0));
+      } else {
+        newResults.push_back(result);
+      }
+    }
+
+    rewriter.finalizeOpModification(op);
+    return success();
+  }
+  private:
+    unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth) {
 
+  typeConverter.addConversion([](Type type) -> Type { return type; });
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     if (!isLinearizableVector(type))
       return type;
@@ -555,9 +936,12 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   };
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
+  target.addLegalOp<mlir::vector::ShapeCastOp>();
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
-        if ((isa<vector::BitCastOp>(op) ||
+        if ((isa<vector::BitCastOp, vector::LoadOp,
+                 vector::StoreOp, vector::CreateMaskOp,
+                 RegionBranchOpInterface, vector::SplatOp>(op) ||
              op->hasTrait<OpTrait::ConstantLike>() ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -568,7 +952,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
       });
 
   patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast>(typeConverter, patterns.getContext(),
+               LinearizeVectorBitCast, LinearizeVectorLoad,
+               LinearizeVectorStore, LinearizeVectorSplat,
+               LinearizeVectorCreateMask, LinearizeRegionBranchOp
+               >(typeConverter, patterns.getContext(),
                                        targetBitWidth);
 }
 
@@ -583,7 +970,21 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
                               .getRank() == 1)
                    : true;
       });
+
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+    [=](vector::InsertStridedSliceOp op) -> bool {
+      if(isLessThanTargetBitWidth(op, targetBitWidth)) {
+        auto srcTy = op.getSourceVectorType();
+        auto dstTy = op.getDestVectorType();
+        if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+            srcTy.hasStaticShape() && dstTy.hasStaticShape())
+          return false;
+      }
+      return true;
+    });
+
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9052c6440e6ac..e47e7c4a84d68 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -399,3 +399,338 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
   %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
   return %1 : vector<[4]x4xf16>
 }
+
+// -----
+// ALL-LABEL: test_vector_load
+// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>)
+func.func @test_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
+  // DEFAULT: %[[C1:.*]] = arith.constant 1 : index
+  // BW-128: %[[C1:.*]] = arith.constant 1 : index
+  // DEFAULT: %[[C2:.*]] = arith.constant 2 : index
+  // BW-128: %[[C2:.*]] = arith.constant 2 : index
+...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/136193


More information about the Mlir-commits mailing list