[Mlir-commits] [mlir] [mlir][linalg] Vectorize unpack op without masking (PR #89067)

Prashant Kumar llvmlistbot at llvm.org
Thu Apr 25 04:25:45 PDT 2024


https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/89067

>From 5bc4819be1470cab8625cb7dfcf29f2000cb99e6 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 17 Apr 2024 08:54:28 -0400
Subject: [PATCH] [mlir] Vectorize unpack op given no vector sizes

Enable vectorization of the tensor.unpack op when no input vector sizes are
provided. In that case, the vector sizes are derived from the result tensor shape.
---
 .../Linalg/Transforms/Vectorization.cpp       | 23 ++++++++++++++++---
 mlir/test/Dialect/Linalg/vectorization.mlir   | 23 +++++++++++++++++++
 2 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index e836f0dc63b4f9..b0f386beee2c99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1558,6 +1558,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape. In case the vector sizes aren't
+  // provided, we update the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = true;
+  if (inputVectorSizes.empty()) {
+    ArrayRef<int64_t> resultTensorShape = unpackOp.getDestType().getShape();
+    inputVectorSizes = resultTensorShape.take_front(unpackOp.getSourceRank());
+    useInBoundsInsteadOfMasking = false;
+  }
+
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
 
@@ -1612,7 +1622,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(),
-      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
+      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
+      doMasking);
 
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1753,8 +1764,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
-  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
-  if (!inputVectorSizes.empty() &&
+  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  bool satisfyEmptyCond = true;
+  if (inputVectorSizes.empty()) {
+    if (!unpackOp.getDestType().hasStaticShape() ||
+        !unpackOp.getSourceType().hasStaticShape())
+      satisfyEmptyCond = false;
+  }
+  if (!satisfyEmptyCond &&
       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 80a5a4c6702ac1..5a81853973906b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -985,3 +985,26 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+  // -----
+
+func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+   %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+   return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+ }



More information about the Mlir-commits mailing list