[Mlir-commits] [mlir] d09bef8 - [MLIR] Vectorize tensor.extract on 1-d tensor

Diego Caballero llvmlistbot at llvm.org
Mon Oct 17 17:06:37 PDT 2022

Author: Che-Yu Wu
Date: 2022-10-18T00:06:02Z
New Revision: d09bef82c0709c5755ce33da20532481b7da2245

URL: https://github.com/llvm/llvm-project/commit/d09bef82c0709c5755ce33da20532481b7da2245
DIFF: https://github.com/llvm/llvm-project/commit/d09bef82c0709c5755ce33da20532481b7da2245.diff

LOG: [MLIR] Vectorize tensor.extract on 1-d tensor

This patch implements the vectorization of tensor.extract for the
basic 1-d lookup case. It only vectorizes the tensor.extract to a
vector.gather when the op extracts a value from a 1-d tensor.

Related discussion: https://github.com/iree-org/iree/issues/9198

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D133786




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2b70155b24887..a435c10833676 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -232,6 +232,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
   return Value();
+// Custom vectorization precondition function type. This is intended to be used
+// with CustomVectorizationHook. Returns success if the corresponding custom
+// hook can vectorize the op.
+using CustomVectorizationPrecondition =
+    std::function<LogicalResult(Operation *)>;
 // Custom vectorization function type. Produce a vector form of Operation*
 // assuming all its vectorized operands are already in the BlockAndValueMapping.
 // Return nullptr if the Operation cannot be vectorized.
@@ -300,6 +306,69 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+/// Helper function to check if the tensor.extract can be vectorized by the
+/// custom hook vectorizeTensorExtract.
+static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
+  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
+  if (!extractOp)
+    return failure();
+  // Currently only supports extraction with a 1-D index.
+  if (extractOp.getIndices().size() != 1)
+    return failure();
+  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
+    return failure();
+  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
+        return !VectorType::isValidElementType(type);
+      })) {
+    return failure();
+  }
+  return success();
+/// Helper function to vectorize the tensor.extract operations. Returns
+/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// should map the produced operations. This function is meant to be used as a
+/// CustomVectorizationHook.
+static VectorizationResult
+vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
+                       const BlockAndValueMapping &bvm) {
+  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
+  if (!extractOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  auto loc = extractOp.getLoc();
+  // Currently only supports extraction with a 1-D index. Checked in
+  // tensorExtractVectorizationPrecondition.
+  assert(extractOp.getIndices().size() == 1);
+  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
+  // Compute the static loop sizes of the extract op.
+  auto targetShape = linalgOp.computeStaticLoopSizes();
+  SmallVector<Value> gatherIndices;
+  gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+  auto maskConstantOp = b.create<arith::ConstantOp>(
+      loc,
+      DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()),
+                                /*value=*/true));
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
+  auto passThruConstantOp =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
+  auto gatherOp = b.create<vector::GatherOp>(
+      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      maskConstantOp, passThruConstantOp);
+  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 /// Emit reduction operations if the shapes of the value to reduce is different
 /// that the result shape.
 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
@@ -515,6 +584,14 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
+  // 4c. Register CustomVectorizationHook for extractOp.
+  CustomVectorizationHook vectorizeExtract =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    return vectorizeTensorExtract(b, op, linalgOp, bvm);
+  };
+  hooks.push_back(vectorizeExtract);
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
     VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
@@ -552,9 +629,20 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
-static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(
+    linalg::LinalgOp op,
+    ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
   // All types in the body should be a supported element type for VectorType.
   for (Operation &innerOp : op->getRegion(0).front()) {
+    // Check if any custom hook can vectorize the inner op.
+    if (llvm::any_of(
+            customPreconditions,
+            [&](const CustomVectorizationPrecondition &customPrecondition) {
+              return succeeded(customPrecondition(&innerOp));
+            })) {
+      continue;
+    }
     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
           return !VectorType::isValidElementType(type);
         })) {
@@ -566,16 +654,8 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
       return failure();
-  if (isElementwise(op)) {
-    // Some operations in the body cannot be vectorized.
-    for (Operation &payloadOp : *op.getBlock()) {
-      if (isa<tensor::ExtractOp>(payloadOp)) {
-        LDBG("precondition failed: `tensor.extract` not vectorizable");
-        return failure();
-      }
-    }
+  if (isElementwise(op))
     return success();
-  }
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
@@ -601,7 +681,13 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
     LDBG("precondition failed: dynamic shape");
     return failure();
-  return vectorizeStaticLinalgOpPrecondition(linalgOp);
+  SmallVector<CustomVectorizationPrecondition> customPreconditions;
+  // Register CustomVectorizationPrecondition for extractOp.
+  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
+  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 119c3db644c55..aba2d5f5cd49f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -161,8 +161,8 @@ bool hasOnlyScalarElementwiseOp(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   for (Operation &op : r.front()) {
-    if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
-              linalg::IndexOp>(op) ||
+    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
+              linalg::YieldOp, linalg::IndexOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index f3735be9a3a5d..8a83c3137e6a0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1457,3 +1457,68 @@ transform.sequence failures(propagate) {
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
+// -----
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @vectorize_1d_tensor_extract(%arg0: tensor<3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x7x2xf32>, %arg3: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%arg1, %arg2 : tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg3 : tensor<4x7x3x2xf32>) {
+  ^bb0(%arg4: i32, %arg5: f32, %arg6: f32):
+    %3 = arith.index_cast %arg4 : i32 to index
+    %7 = tensor.extract %arg0[%3] : tensor<3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<4x7x3x2xf32>
+  return %2 : tensor<4x7x3x2xf32>
+// CHECK-LABEL: func.func @vectorize_1d_tensor_extract
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<3xf32>
+// CHECK-SAME:    %[[ARG1:.*]]: tensor<4x3xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]]
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST]]
+// CHECK: %[[INDICES:.*]] = vector.transpose %[[BROADCAST]]
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASSTHRU]]
+// CHECK: vector.transfer_write %[[GATHER]]
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+// -----
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%arg1, %arg2, %arg3 : tensor<4x3xi32>, tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg4 : tensor<4x7x3x2xf32>) {
+  ^bb0(%arg5: i32, %arg6: i32, %arg7: f32, %arg8: f32):
+    %3 = arith.index_cast %arg5 : i32 to index
+    %4 = arith.index_cast %arg6 : i32 to index
+    %7 = tensor.extract %arg0[%3, %4] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<4x7x3x2xf32>
+  return %2 : tensor<4x7x3x2xf32>
+// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
+// CHECK: tensor.extract
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1


More information about the Mlir-commits mailing list