[Mlir-commits] [mlir] [mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support (PR #97297)
llvmlistbot at llvm.org
Mon Jul 1 07:17:30 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: xiaohui1.xu (BRUCE11111)
<details>
<summary>Changes</summary>
---
Patch is 26.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97297.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-2)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+334)
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+192)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..e0fd5f1b14070 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3388,8 +3388,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
- target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+ tensor::BitcastOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
+ tensor::ConcatOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..7a4db82749fd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
return success();
}
+/// Vectorize a `tensor.expand_shape` or `tensor.collapse_shape` into these 3 ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor.
+/// vector::ShapeCastOp - Reshapes the data based on the target shape.
+/// vector::TransferWriteOp - Writes the result vector back to the destination
+/// tensor.
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+ Operation *inputOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(inputOp);
+ auto src = inputOp->getOperand(0);
+ auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+ auto result = inputOp->getResults()[0];
+ auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ Location loc = inputOp->getLoc();
+
+ llvm::SmallVector<int64_t> srcVectorizedShape;
+ llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+ auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+ ArrayRef<int64_t> &inputShape) {
+ bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+ int64_t cur = 1, resultIdx = 0;
+ for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+ cur *= ss;
+ if (!isResultShapeBigger) {
+ // collapse
+ srcVectorizedShape.emplace_back(ss);
+ if (cur == retShape[resultIdx]) {
+ if (shapeScales.count(resultIdx)) {
+ srcVectorizedShape.back() *= shapeScales[resultIdx];
+ }
+ cur = 1;
+ resultIdx++;
+ }
+ } else {
+ // expand
+ if (cur == retShape[resultIdx]) {
+ srcVectorizedShape.emplace_back(cur);
+ if (shapeScales.count(srcIdx)) {
+ srcVectorizedShape.back() *= shapeScales[srcIdx];
+ }
+ cur = 1;
+ resultIdx++;
+ }
+ }
+ }
+ };
+ if (!inputVectorSizes.empty()) {
+ for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+ if (vs != resultShape[idx])
+ shapeScales[idx] = vs / resultShape[idx];
+ }
+
+ bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+ if (!isResultShapeBigger) {
+ getVectorizeShape(resultShape, srcShape);
+ } else {
+ getVectorizeShape(srcShape, resultShape);
+ }
+ } else {
+ srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+ }
+ // read
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(srcType.getElementType()));
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, src,
+ inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+ padValue, false);
+
+ auto shapeCastType =
+ VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+ resultType.getElementType());
+ vector::ShapeCastOp shapeCastOp =
+ rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+ // write
+ SmallVector<OpFoldResult> destSizes;
+ for (auto size : resultShape) {
+ destSizes.emplace_back(rewriter.getIndexAttr(size));
+ }
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, shapeCastOp->getResults()[0], destSizes,
+ inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
+/// Vectorize a `tensor.bitcast` into these 3 ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor.
+/// vector::BitCastOp - Bitcasts the data to the result element type.
+/// vector::TransferWriteOp - Writes the result vector back to the destination
+/// tensor.
+static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
+ tensor::BitcastOp bitCastOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(bitCastOp);
+
+ auto sourceType = bitCastOp.getSource().getType();
+ auto resultType = bitCastOp.getResult().getType();
+ auto resultShape = resultType.getShape();
+ if (inputVectorSizes.empty()) {
+ inputVectorSizes = resultShape;
+ }
+ Location loc = bitCastOp->getLoc();
+
+ // read
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(sourceType.getElementType()));
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);
+
+ // bitcast
+ auto resultVectorType =
+ VectorType::get(inputVectorSizes, resultType.getElementType());
+ vector::BitCastOp vectorbitCastOp =
+ rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);
+
+ // write
+ llvm::SmallVector<OpFoldResult> destSizes;
+ for (auto size : resultShape)
+ destSizes.emplace_back(rewriter.getIndexAttr(size));
+ auto write =
+ createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0),
+ destSizes, inputVectorSizes, false);
+ newResults.push_back(write->getResults()[0]);
+ return success();
+}
+
+/// Vectorize a `tensor.concat` into:
+/// tensor::EmptyOp - Creates the destination tensor.
+/// vector::TransferReadOp - Reads each input tensor into a vector.
+/// vector::TransferWriteOp - Writes each vector into the destination tensor
+/// at the accumulated offset along the concatenation dimension.
+static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
+ tensor::ConcatOp concatOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(concatOp);
+
+ Location loc = concatOp.getLoc();
+ FailureOr<Value> dest =
+ tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+ if (failed(dest))
+ return failure();
+
+ auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+ if (!empty)
+ return failure();
+
+ // Compute the partial sums for the slice offsets.
+ auto dim = concatOp.getDim();
+ Value dimValue =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
+
+ int64_t rank = concatOp.getResultType().getRank();
+ auto srcType =
+ mlir::dyn_cast<RankedTensorType>(concatOp->getResultTypes()[0]);
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(srcType.getElementType()));
+
+ // Construct the chain of insert_slice ops into the destination.
+ Value result = *dest;
+ Value previous_offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {
+
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ SmallVector<int64_t> readMaskShape;
+ auto inputType = mlir::dyn_cast<RankedTensorType>(input.getType());
+ auto sourceShape = inputType.getShape();
+
+ readMaskShape.append(sourceShape.begin(), sourceShape.end());
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, input, sourceShape, padValue, false);
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(rank, zero);
+ indices[dim] = previous_offset;
+ result = rewriter
+ .create<vector::TransferWriteOp>(
+ loc, readResult, result, indices,
+ rewriter.getMultiDimIdentityMap(rank))
+ ->getResults()[0];
+ if (idx != concatOp.getNumOperands() - 1) {
+ auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
+ previous_offset =
+ rewriter.create<arith::AddIOp>(loc, dimOp, previous_offset);
+ }
+ }
+
+ newResults.push_back(result);
+ return success();
+}
+
// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}
+static LogicalResult
+lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto resultType = expandOp->getResultTypes()[0];
+ auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
+ // check reassociation
+ llvm::SmallVector<int64_t> associateIndices;
+ for (auto &attr : expandOp.getReassociation()) {
+ for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {
+ associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(indice).getInt());
+ }
+ }
+
+ if (llvm::any_of(associateIndices,
+ [](int64_t x) { return x == ShapedType::kDynamic; })) {
+ LDBG("Reassociation must be static: " << expandOp << "\n");
+ return failure();
+ }
+ // check input and output shape
+ if (!resultShape.hasStaticShape() ||
+ !expandOp.getSrcType().hasStaticShape()) {
+ LDBG("Input and output shape must be static: " << expandOp << "\n");
+ return failure();
+ }
+ if (!inputVectorSizes.empty() &&
+ failed(vector::isValidMaskedInputVector(resultShape.getShape(),
+ inputVectorSizes)))
+ return failure();
+
+ return success();
+}
+
+static LogicalResult
+lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto resultType = bitCastOp->getResultTypes()[0];
+ auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+ auto srcType = bitCastOp.getSource().getType();
+ auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);
+
+ bool isStaticInputOutput =
+ resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+ if (!isStaticInputOutput) {
+ LDBG("Input and output shape must be static: " << bitCastOp << "\n");
+ return failure();
+ }
+
+ if (!inputVectorSizes.empty() &&
+ failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+ inputVectorSizes)))
+ return failure();
+ return success();
+}
+
+static LogicalResult
+lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto resultType = collapseOp->getResultTypes()[0];
+ auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+ auto srcShapeType = collapseOp.getSrcType();
+
+ bool isStaticInputOutput =
+ resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+ if (!isStaticInputOutput) {
+ LDBG("Input and output shape must be static: " << collapseOp << "\n");
+ return failure();
+ }
+
+ if (!inputVectorSizes.empty() &&
+ failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+ inputVectorSizes)))
+ return failure();
+ return success();
+}
+
+static LogicalResult
+lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ if (!inputVectorSizes.empty()) {
+ LDBG("Concat operation do not support specify inputVectorSizes: "
+ << concatOp << "\n");
+ }
+ for (auto x : concatOp->getOperands()) {
+ auto type = mlir::dyn_cast<ShapedType>(x.getType());
+ if (!type) {
+ LDBG("Operation type error: " << concatOp << "\n");
+ return failure();
+ }
+ if (!type.hasStaticShape()) {
+ LDBG("Type must be static: " << concatOp << "\n");
+ return failure();
+ }
+ }
+ auto dim = concatOp.getDim();
+ if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
+ LDBG("Invalid dim: " << concatOp << "\n");
+ return failure();
+ }
+
+ return success();
+}
+
/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
+ .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+ return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
+ })
+ .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+ return lowerCollapseShapeOpPrecondition(collapseShapeOp,
+ inputVectorSizes);
+ })
+ .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+ return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
+ })
+ .Case<tensor::ConcatOp>([&](auto concatOp) {
+ return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
})
+ .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+ return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
+ results);
+ })
+ .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+ return lowerTensorReshape(rewriter, collapseShapeOp,
+ inputVectorSizes, results);
+ })
+ .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+ return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
+ results);
+ })
+ .Case<tensor::ConcatOp>([&](auto concatOp) {
+ return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
+ results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..114815b4e3de8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1055,3 +1055,195 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
transform.yield
}
}
+
+ // -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape
+func.func @test_vectorize_collapseshape(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x32xf32> to vector<64x1024xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ // CHECK: %[[C512:.*]] = arith.constant 512 : index
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+ // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+ // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+ %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+ return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+ transform.yield
+ }
+}
+
+ // -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape_no_vector_size
+func.func @test_vectorize_collapseshape_no_vector_size(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true, true, true]} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x16xf32> to vector<64x512xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+ // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+ %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+ return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+ // -----
+
+ // CHECK-LABEL: func @test_vectorize_expandshape
+func.func @test_vectorize_expandshape(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ // CHECK: %[[C512:.*]] = arith.constant 512 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x1024xf32> to vector<8x8x32x32xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<8x8x32x16xf32>
+ // CHECK: %[[C01:.*]]= arith.constant 0 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x3...
[truncated]
``````````
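As a quick illustration of the `tensor.concat` lowering above (its tests fall in the truncated part of the diff): the pattern creates an empty destination and then, for each input, emits a `transfer_read` followed by a `transfer_write` at the running offset along the concat dimension. Below is a rough hand-written sketch of the expected IR for a static two-input concat; the shapes, SSA names, and `in_bounds`/masking details are illustrative assumptions, not output copied from the patch:

```mlir
// Hypothetical input (illustrative shapes and names):
//   %r = tensor.concat dim(0) %a, %b
//          : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%empty = tensor.empty() : tensor<5x4xf32>
// The first input is written at offset 0 along the concat dim.
%v0 = vector.transfer_read %a[%c0, %c0], %cst {in_bounds = [true, true]}
    : tensor<2x4xf32>, vector<2x4xf32>
%w0 = vector.transfer_write %v0, %empty[%c0, %c0] {in_bounds = [true, true]}
    : vector<2x4xf32>, tensor<5x4xf32>
// The offset advances by the size of the previous input along the concat dim.
%c2 = arith.constant 2 : index
%v1 = vector.transfer_read %b[%c0, %c0], %cst {in_bounds = [true, true]}
    : tensor<3x4xf32>, vector<3x4xf32>
%w1 = vector.transfer_write %v1, %w0[%c2, %c0] {in_bounds = [true, true]}
    : vector<3x4xf32>, tensor<5x4xf32>
```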
</details>
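Likewise, the `tensor.bitcast` pattern is a read / `vector.bitcast` / write sequence. A minimal hand-written sketch under assumed shapes and names (again illustrative, not taken from the patch's tests):

```mlir
// Hypothetical input: %r = tensor.bitcast %src : tensor<4x8xi32> to tensor<4x8xf32>
%pad = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
    : tensor<4x8xi32>, vector<4x8xi32>
%cast = vector.bitcast %read : vector<4x8xi32> to vector<4x8xf32>
%empty = tensor.empty() : tensor<4x8xf32>
%write = vector.transfer_write %cast, %empty[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x8xf32>, tensor<4x8xf32>
```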
https://github.com/llvm/llvm-project/pull/97297