[Mlir-commits] [mlir] [mlir][vector] Determine vector sizes from the result shape in the ca… (PR #88249)
llvmlistbot at llvm.org
Wed Apr 10 03:01:06 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Prashant Kumar (pashu123)
Changes
…se of tensor pack
When the vector sizes are not passed as inputs to the vectorize transform operation, they are inferred from the static result shape in the case of the tensor.pack op.
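For intuition, here is a small standalone sketch (plain C++ arithmetic, not code from the patch) of how the inferred sizes line up with the new test at the end of the diff: the vector sizes are taken as the leading source-rank dimensions of the static dest shape, and the shape of the masked read is then obtained by scaling each tiled dimension by its inner tile size, as in the existing `inputShape` computation.

```cpp
// Standalone sketch using the shapes from the test below:
// dest tensor<32x4x1x16x2xf32>, source rank 3,
// inner_dims_pos = [2, 1], inner_tiles = [16, 2].
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<std::int64_t> destShape = {32, 4, 1, 16, 2};
  const std::size_t sourceRank = 3;

  // Inferred vector sizes: the leading source-rank dims of the static
  // dest shape -> [32, 4, 1].
  std::vector<std::int64_t> vectorSizes(destShape.begin(),
                                        destShape.begin() + sourceRank);

  // Masked-read shape: scale each tiled source dim by its inner tile size
  // -> [32, 4 * 2, 1 * 16] = [32, 8, 16], matching the vector<32x8x16xi1>
  // mask in the test.
  const std::vector<std::size_t> innerDimsPos = {2, 1};
  const std::vector<std::int64_t> innerTiles = {16, 2};
  std::vector<std::int64_t> readShape = vectorSizes;
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i)
    readShape[innerDimsPos[i]] *= innerTiles[i];

  for (std::int64_t d : readShape)
    std::printf("%lld ", static_cast<long long>(d)); // prints: 32 8 16
  std::printf("\n");
  return 0;
}
```

Taking only the first source-rank entries keeps the inferred sizes at the arity the masked `transfer_read` expects; the trailing inner-tile dims of the dest are reintroduced by the `shape_cast`/`transpose` steps visible in the CHECK lines below.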
---
Full diff: https://github.com/llvm/llvm-project/pull/88249.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-4)
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+34)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 25785653a71675..422fc0562f9003 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1525,6 +1525,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape.
+  if (inputVectorSizes.empty()) {
+    // Make sure that the result tensor shape is static.
+    if (ShapedType::isDynamicShape(resultTensorShape))
+      return failure();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1763,7 +1774,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
 /// 1. The numbers of elements in both arrays are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
 /// 3. All the values in `inputVectorSizes` are greater than or equal to
 ///    static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1892,19 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }
 
-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
 
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2d01d57304013c..f354ab9ea0b0a3 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -812,3 +812,37 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
``````````
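A design note on the precondition change above, in case the `matchPattern` switch looks incidental: the old check only recognized a padding value defined directly by an `arith.constant`, whereas `m_Constant` matches any producer with the ConstantLike trait and binds the constant attribute. A minimal sketch of the two styles (a hypothetical helper, assuming the usual MLIR headers are available):

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// Sketch, not code from the patch: two ways to test whether `padValue`
// is produced by a compile-time constant.
static bool isConstantPad(mlir::Value padValue) {
  // Old style: only matches values defined directly by arith.constant.
  if (padValue.getDefiningOp<mlir::arith::ConstantOp>())
    return true;

  // New style: matches any op with the ConstantLike trait and also hands
  // back the constant attribute for later use.
  mlir::Attribute cstAttr;
  return mlir::matchPattern(padValue, mlir::m_Constant(&cstAttr));
}
```

This generalization is also what lets the patch drop the removed TODO about using a matcher for the padding value.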
https://github.com/llvm/llvm-project/pull/88249