[Mlir-commits] [mlir] Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer (PR #86163)
Crefeda Rodrigues
llvmlistbot at llvm.org
Fri Mar 22 07:15:30 PDT 2024
================
@@ -201,12 +205,19 @@ struct TransferWritePermutationLowering
// Generate new transfer_write operation.
Value newVec = rewriter.create<vector::TransposeOp>(
op.getLoc(), op.getVector(), indices);
+
+ auto vectorType = cast<VectorType>(newVec.getType());
+
+ if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+ rewriter.eraseOp(newVec.getDefiningOp());
+ return failure();
----------------
cfRod wrote:
Previously, `TransferWriteNonPermutationLowering` was called first: it added the two `vector.broadcast` ops and then a `vector.transpose` for the mask. `TransferWritePermutationLowering` was then called and added the `vector.transpose` for the input vector:
```
* Pattern (anonymous namespace)::TransferWriteNonPermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWriteNonPermutationLowering"
** Insert : 'vector.broadcast'(0xac685f872c10)
** Insert : 'vector.broadcast'(0xac685f879e40)
** Insert : 'vector.transpose'(0xac685f879ed0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::vector::detail::TransferWriteOpGenericAdaptorBase::Properties)
** Insert : 'vector.transfer_write'(0xac685f7e1d00)
** Replace : 'vector.transfer_write'(0xac685f7b1eb0)
** Erase : 'vector.transfer_write'(0xac685f7b1eb0)
"(anonymous namespace)::TransferWriteNonPermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
return
}
} -> success : pattern matched
//===-------------------------------------------===//
//===-------------------------------------------===//
Processing operation : 'vector.transfer_write'(0xac685f7e1d00) {
"vector.transfer_write"(%1, %arg1, %0, %0, %0, %0, %0, %0, %0, %3) <{in_bounds = [true, true, true, true, true, true], operandSegmentSizes = array<i32: 1, 1, 7, 1>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>}> : (vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>, index, index, index, index, index, index, index, vector<4x[8]x1x1x1x1xi1>) -> ()
ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::VectorType>::Impl<Empty>)
* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
** Insert : 'vector.transpose'(0xac685f87b6b0)
** Insert : 'vector.transfer_write'(0xac685f7b1eb0)
** Replace : 'vector.transfer_write'(0xac685f7e1d00)
** Erase : 'vector.transfer_write'(0xac685f7e1d00)
"(anonymous namespace)::TransferWritePermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
%3 = vector.transpose %0, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
vector.transfer_write %3, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
return
}
```
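After this patch, the second transpose (the one `TransferWritePermutationLowering` inserts for the input vector) is "erased" and the pattern returns failure.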
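
To make the shape bookkeeping in the log above easier to follow, here is a minimal standalone C++ sketch, not MLIR code. `VecShape`, `transposeShape`, `hasScalableDim` and `scalableDimNotTrailing` are hypothetical names introduced only for illustration. The last helper models what the condition in the diff appears intended to test (a scalable vector whose trailing dimension is not scalable); that reading is an assumption, not something the patch states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: static sizes plus a per-dimension
// "scalable" flag, e.g. vector<4x[8]xi16> -> sizes {4, 8}, scalable {false, true}.
struct VecShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Result dimension i takes source dimension perm[i], which matches the
// vector.transpose lines in the IR dump above.
VecShape transposeShape(const VecShape &src,
                        const std::vector<std::size_t> &perm) {
  assert(src.sizes.size() == src.scalable.size() && "malformed shape");
  assert(perm.size() == src.sizes.size() && "permutation rank mismatch");
  VecShape out;
  for (std::size_t p : perm) {
    out.sizes.push_back(src.sizes[p]);
    out.scalable.push_back(src.scalable[p]);
  }
  return out;
}

// True if any dimension is scalable.
bool hasScalableDim(const VecShape &v) {
  for (bool s : v.scalable)
    if (s)
      return true;
  return false;
}

// Assumed intent of the check in the diff: the type is scalable, but its
// trailing dimension is not the scalable one.
bool scalableDimNotTrailing(const VecShape &v) {
  return hasScalableDim(v) && !v.scalable.back();
}
```

Applying the permutation `[4, 5, 0, 1, 2, 3]` to `1x1x1x1x4x[8]` gives `4x[8]x1x1x1x1`, so the scalable dimension moves from the trailing position to position 1 and `scalableDimNotTrailing` holds for the transposed shape but not for the source shape.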
https://github.com/llvm/llvm-project/pull/86163