[Mlir-commits] [mlir] Add support for static unpack op vectorization without providing inpu… (PR #89067)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 17 05:57:00 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Prashant Kumar (pashu123)
<details>
<summary>Changes</summary>
…without providing input vector sizes
If the vector sizes are not provided for the vectorization of a tensor.unpack op, they are determined by the result (destination) shape. This requires both the input and output shapes to be static.
---
Full diff: https://github.com/llvm/llvm-project/pull/89067.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+20-3)
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+23)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df61381432921b..92d2d129ff749c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1597,6 +1597,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ // If the input vector sizes are not provided, then the vector sizes are
+ // determined by the result tensor shape. In case the vector sizes aren't
+ // provided, we update the inBounds attribute instead of masking.
+ bool doMasking = true;
+ if (inputVectorSizes.empty()) {
+ ArrayRef<int64_t> resultTensorShape = unpackOp.getDestType().getShape();
+ inputVectorSizes = resultTensorShape.take_front(unpackOp.getSourceRank());
+ doMasking = false;
+ }
+
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
@@ -1651,7 +1661,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
// to shape of source, then a mask is necessary.
Value readResult = createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(),
- ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
+ ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
+ doMasking);
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1827,8 +1838,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
return failure();
}
- llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
- if (!inputVectorSizes.empty() &&
+ ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+ bool satisfyEmptyCond = true;
+ if (inputVectorSizes.empty()) {
+ if (!unpackOp.getDestType().hasStaticShape() ||
+ !unpackOp.getSourceType().hasStaticShape())
+ satisfyEmptyCond = false;
+ }
+ if (!satisfyEmptyCond &&
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 80a5a4c6702ac1..5a81853973906b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -985,3 +985,26 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+ // -----
+
+func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C00:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+ }
``````````
</details>
https://github.com/llvm/llvm-project/pull/89067
More information about the Mlir-commits
mailing list