[Mlir-commits] [mlir] Fix unsupported transpose ops for scalable vectors (PR #86163)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 21 11:06:02 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Crefeda Rodrigues (cfRod)
<details>
<summary>Changes</summary>
Addresses the review comment at https://github.com/llvm/llvm-project/pull/85632#discussion_r1530727576
---
Full diff: https://github.com/llvm/llvm-project/pull/86163.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+15-4)
- (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+17)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fcfb6edaf..570a5222862b72 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
SmallVector<int64_t> newShape(addedRank, 1);
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
- VectorType newVecType =
- VectorType::get(newShape, originalVecType.getElementType());
+
+ SmallVector<bool> newScalableDims(addedRank, false);
+ newScalableDims.append(originalVecType.getScalableDims().begin(),
+ originalVecType.getScalableDims().end());
+ VectorType newVecType = VectorType::get(
+ newShape, originalVecType.getElementType(), newScalableDims);
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}
@@ -201,12 +205,19 @@ struct TransferWritePermutationLowering
// Generate new transfer_write operation.
Value newVec = rewriter.create<vector::TransposeOp>(
op.getLoc(), op.getVector(), indices);
+
+ auto vectorType = cast<VectorType>(newVec.getType());
+
+ if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+ rewriter.eraseOp(newVec.getDefiningOp());
+ return failure();
+ }
+
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
op.getMask(), newInBoundsAttr);
-
return success();
}
};
@@ -269,7 +280,7 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.size());
// Mask: add unit dims at the end of the shape.
Value newMask;
- if (op.getMask())
+ if (op.getMask() && !op.getVectorType().isScalable())
newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
missingInnerDim.size());
exprs.append(map.getResults().begin(), map.getResults().end());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 13e07f59a72a77..83a7f21daf683f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -41,6 +41,23 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}
+// CHECK-LABEL: func.func @permutation_with_mask_transfer_write_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
+// CHECK: vector.transfer_write %[[BCAST]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true, true], permutation_map = #map} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+// CHECK: return
+// CHECK: }
+ func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
+} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
+
+ return
+ }
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
``````````
</details>
https://github.com/llvm/llvm-project/pull/86163
More information about the Mlir-commits
mailing list