[Mlir-commits] [mlir] [mlir][vector] Add more patterns to Vector Linearize transformation (PR #136193)

Ivan Butygin llvmlistbot at llvm.org
Thu Apr 17 13:47:26 PDT 2025


================
@@ -531,12 +615,312 @@ struct LinearizeVectorBitCast final
   unsigned targetVectorBitWidth;
 };
 
+/// This pattern unrolls a 2-D vector.load into a series of 1-D vector.load
+/// ops whose results are inserted, row by row, into a 2-D result vector.
+/// Following,
+///   %v = vector.load %base[..., %i, %j] : memref<...>, vector<4x4xf32>
+/// is converted to:
+///   %result = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[..., %i, %j] : memref<...>, vector<4xf32>
+///   %result_0 = vector.insert %slice_0, %result [0]
+///       : vector<4xf32> into vector<4x4xf32>
+///   %slice_1 = vector.load %base[..., %i + 1, %j] : memref<...>, vector<4xf32>
+///   %result_1 = vector.insert %slice_1, %result_0 [1]
+///       : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// vector.load maps the vector's dimensions onto the trailing dimensions of
+/// the base, so the row of each unrolled slice is selected by the
+/// second-to-last index, which is the one incremented per row.
+/// Only fixed-size (non-scalable) 2-D vectors are supported.
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    // A scalable leading dim cannot be unrolled, and rebuilding the row type
+    // from the static shape alone would silently drop a scalable trailing dim.
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "Scalable vectors are not supported.");
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    // Rows of the vector correspond to the second-to-last base dimension.
+    if (indices.size() < 2)
+      return rewriter.notifyMatchFailure(
+          loc, "Base must have at least 2 dimensions.");
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    const int64_t unrollCount = shape[0];
+    auto rowType = VectorType::get({shape[1]}, vecType.getElementType());
+
+    // Start from an all-zero 2-D vector and insert each loaded row into it.
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
+    for (int64_t i = 0; i < unrollCount; ++i) {
+      Value rowIndex = rowBaseIndex;
+      if (i != 0) {
+        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, offset);
+      }
+      indices[rowDim] = rowIndex;
+      Value row = rewriter.create<vector::LoadOp>(loc, rowType,
+                                                  adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, row, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  // NOTE(review): stored for parity with the other patterns but not consulted
+  // by this pattern.
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern unrolls a 2-D vector.store into a series of 1-D
+/// vector.extract + vector.store pairs, one per row of the stored vector.
+/// Following,
+///   vector.store %source, %base[..., %i, %j] : memref<...>, vector<4x4xf32>
+/// is converted to:
+///   %slice_0 = vector.extract %source[0] : vector<4xf32> from vector<4x4xf32>
+///   vector.store %slice_0, %base[..., %i, %j] : memref<...>, vector<4xf32>
+///   %slice_1 = vector.extract %source[1] : vector<4xf32> from vector<4x4xf32>
+///   vector.store %slice_1, %base[..., %i + 1, %j] : memref<...>, vector<4xf32>
+///   ...
+/// vector.store maps the vector's dimensions onto the trailing dimensions of
+/// the base, so the row of each slice is selected by the second-to-last
+/// index, which is the one incremented per row.
+/// Only fixed-size (non-scalable) 2-D vectors are supported.
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    // A scalable leading dim cannot be unrolled into a fixed number of stores.
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "Scalable vectors are not supported.");
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    // Rows of the vector correspond to the second-to-last base dimension.
+    if (indices.size() < 2)
+      return rewriter.notifyMatchFailure(
+          loc, "Base must have at least 2 dimensions.");
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    // The adaptor's value has already been type-converted (presumably to the
+    // linearized 1-D form); cast it back to 2-D so rows can be extracted.
+    Value source = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    const int64_t unrollCount = shape[0];
+    for (int64_t i = 0; i < unrollCount; ++i) {
+      Value row = rewriter.create<vector::ExtractOp>(loc, source, i);
+      Value rowIndex = rowBaseIndex;
+      if (i != 0) {
+        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, offset);
+      }
+      indices[rowDim] = rowIndex;
+      rewriter.create<vector::StoreOp>(loc, row, adaptor.getBase(), indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  // NOTE(review): stored for parity with the other patterns but not consulted
+  // by this pattern.
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the SplatOp to work on a linearized vector.
+/// Following,
+///   vector.splat %value : vector<4x4xf32>
+/// is converted to:
+///   %out_1d = vector.splat %value : vector<16xf32>
+///   %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// The original op is replaced by a new SplatOp whose result type is the one
+/// produced by this pattern's type converter; the pattern fails to match if
+/// that conversion fails. (The shape_cast back to the original type is not
+/// created here; presumably it comes from the type converter's
+/// materializations — confirm.)
+/// NOTE(review): `targetVectorBitWidth` is stored but not checked by this
+/// pattern; any bit-width restriction must be enforced elsewhere (presumably
+/// by the conversion target's legality rules).
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    // Splatting a scalar is shape-agnostic, so the same input simply fills
+    // the converted vector type.
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the CreateMaskOp to work on a linearized vector.
+/// It replaces the original operation with a new CreateMaskOp whose result
+/// type is produced by this pattern's type converter. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+///   vector.create_mask %d0, %d1 : vector<1x4xi1>
+/// is converted to:
+///   %c0 = arith.constant 0 : index
+///   %pos = arith.cmpi sgt, %d0, %c0 : index
+///   %size = arith.select %pos, %d1, %c0 : index
+///   %out_1d = vector.create_mask %size : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+/// The select is required because the single row of a 1xN mask is enabled
+/// only when %d0 is positive; otherwise the entire mask is false.
+/// NOTE(review): `targetVectorBitWidth` is stored but not checked here.
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = createMaskOp.getLoc();
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // Fold the outer bound into the inner one: when the outer bound is <= 0
+    // the original mask is all-false, so the linearized size must be 0.
+    Value outerBound = adaptor.getOperands().front();
+    Value innerBound = adaptor.getOperands().back();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value outerIsPositive = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sgt, outerBound, zero);
+    Value maskSize = rewriter.create<arith::SelectOp>(loc, outerIsPositive,
+                                                      innerBound, zero);
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(createMaskOp, dstTy,
+                                                      maskSize);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
----------------
Hardcode84 wrote:

Instead of adding this pattern, can you test `populateSCFStructuralTypeConversionsAndLegality`?

https://github.com/llvm/llvm-project/pull/136193


More information about the Mlir-commits mailing list