[Mlir-commits] [mlir] Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer (PR #86163)

Crefeda Rodrigues llvmlistbot at llvm.org
Fri Mar 22 07:15:30 PDT 2024


================
@@ -201,12 +205,19 @@ struct TransferWritePermutationLowering
     // Generate new transfer_write operation.
     Value newVec = rewriter.create<vector::TransposeOp>(
         op.getLoc(), op.getVector(), indices);
+
+    auto vectorType = cast<VectorType>(newVec.getType());
+
+    if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
+      rewriter.eraseOp(newVec.getDefiningOp());
+      return failure();
----------------
cfRod wrote:

Previously, TransferWriteNonPermutationLowering was applied first: it added the two vector.broadcast ops and a vector.transpose for the mask. TransferWritePermutationLowering was then applied and added the vector.transpose for the input vector:
  ```
* Pattern (anonymous namespace)::TransferWriteNonPermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWriteNonPermutationLowering"
    ** Insert  : 'vector.broadcast'(0xac685f872c10)
    ** Insert  : 'vector.broadcast'(0xac685f879e40)
    ** Insert  : 'vector.transpose'(0xac685f879ed0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::vector::detail::TransferWriteOpGenericAdaptorBase::Properties)
    ** Insert  : 'vector.transfer_write'(0xac685f7e1d00)
    ** Replace : 'vector.transfer_write'(0xac685f7b1eb0)
    ** Erase   : 'vector.transfer_write'(0xac685f7b1eb0)
"(anonymous namespace)::TransferWriteNonPermutationLowering" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
  %1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
  %2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
  return
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'vector.transfer_write'(0xac685f7e1d00) {
  "vector.transfer_write"(%1, %arg1, %0, %0, %0, %0, %0, %0, %0, %3) <{in_bounds = [true, true, true, true, true, true], operandSegmentSizes = array<i32: 1, 1, 7, 1>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>}> : (vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>, index, index, index, index, index, index, index, vector<4x[8]x1x1x1x1xi1>) -> ()

ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::VectorType>::Impl<Empty>)

  * Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
    ** Insert  : 'vector.transpose'(0xac685f87b6b0)
    ** Insert  : 'vector.transfer_write'(0xac685f7b1eb0)
    ** Replace : 'vector.transfer_write'(0xac685f7e1d00)
    ** Erase   : 'vector.transfer_write'(0xac685f7e1d00)
"(anonymous namespace)::TransferWritePermutationLowering" result 1
  } -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
  %c0 = arith.constant 0 : index
  %0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
  %1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
  %2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
  %3 = vector.transpose %0, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
  vector.transfer_write %3, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
  return
}
```
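After this patch, TransferWritePermutationLowering still creates the second transpose (the one on the input vector, `%3` above), but then erases it again and returns failure().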
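
For reference, the guard in the hunk above can be modelled outside MLIR as a small predicate. This is only a sketch: `ScalableDims`, `isScalable`, and `isUnsupportedTransposeResult` are hypothetical stand-ins for `VectorType::getScalableDims()` and `VectorType::isScalable()`. It tests the trailing dimension with `back()`, on the assumption that this is what `*vectorType.getScalableDims().end()` is meant to check (`end()` points one past the last element, so it cannot be dereferenced safely).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for the per-dimension scalability flags of a vector
// type, e.g. vector<4x[8]x1x1xi16> -> {false, true, false, false}.
using ScalableDims = std::vector<bool>;

// True if at least one dimension is scalable.
bool isScalable(const ScalableDims &dims) {
  return std::find(dims.begin(), dims.end(), true) != dims.end();
}

// True if the type is scalable but its trailing dimension is fixed-size.
// Assumption: this is the shape the guard in the patch wants to reject.
// The short-circuit on isScalable() keeps back() from running on an
// empty vector.
bool isUnsupportedTransposeResult(const ScalableDims &dims) {
  return isScalable(dims) && !dims.back();
}
```

With the flags of `vector<4x[8]x1x1x1x1xi16>`, i.e. `{false, true, false, false, false, false}`, the predicate returns true, which corresponds to the case where the pattern bails out. For the untransposed `vector<1x1x1x1x4x[8]xi16>` it returns false.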

https://github.com/llvm/llvm-project/pull/86163

