[Mlir-commits] [mlir] 678360f - [mlir][linalg] Add scalar broadcast load case to the vectoriser

Andrzej Warzynski llvmlistbot at llvm.org
Mon Jun 12 07:19:17 PDT 2023


Author: Andrzej Warzynski
Date: 2023-06-12T15:18:42+01:00
New Revision: 678360fd9d8517a978c5f46364a56412d70e15c7

URL: https://github.com/llvm/llvm-project/commit/678360fd9d8517a978c5f46364a56412d70e15c7
DIFF: https://github.com/llvm/llvm-project/commit/678360fd9d8517a978c5f46364a56412d70e15c7.diff

LOG: [mlir][linalg] Add scalar broadcast load case to the vectoriser

This patch extends the Linalg vectoriser so that scalar loads are
correctly identified as scalar broadcast loads rather than gather loads.
Below is an example of such a scalar load. Note that both indices are
loop invariant: `%2` comes from `linalg.index 0`, which ranges over a
unit dimension (so it is always 0), and `%c16` is a constant:
```
func.func @example(%arg0: tensor<80x16xf32>, %arg2: tensor<1x4xf32>) -> tensor<1x4xf32> {
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%arg2 : tensor<1x4xf32>) {
  ^bb0(%out: f32):
    %2 = linalg.index 0 : index
    %extracted = tensor.extract %arg0[%2, %c16] : tensor<80x16xf32>
    linalg.yield %extracted : f32
  } -> tensor<1x4xf32>
  return %1 : tensor<1x4xf32>
}
```
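
For contrast, if one of the indices were computed from data (e.g. loaded
from another tensor), the extract could not be treated as a scalar or
contiguous load and would still be vectorised as a gather. A hypothetical
contrasting body (not from this patch; `%indices : tensor<4xi32>` is an
illustrative extra input) would be:
```
  ^bb0(%out: f32):
    %2 = linalg.index 0 : index
    // The trailing index is data-dependent, so this remains a gather load.
    %idx_i32 = tensor.extract %indices[%2] : tensor<4xi32>
    %idx = arith.index_cast %idx_i32 : i32 to index
    %extracted = tensor.extract %arg0[%2, %idx] : tensor<80x16xf32>
    linalg.yield %extracted : f32
```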

This patch also makes sure that these scalar loads are indeed lowered to
a scalar `tensor.extract` followed by a `vector.broadcast`:
```
    %extracted = tensor.extract %arg0[%1, %c16] : tensor<80x16xf32>
    %2 = vector.broadcast %extracted : f32 to vector<1x4xf32>
```
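
Note that, like the existing contiguous-load handling, this path is only
exercised when n-D `tensor.extract` vectorisation is requested. As a rough
sketch (op spellings taken from the transform sequences in the updated test
file; they may differ in other MLIR revisions), it can be driven with:
```
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
  %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op
  // `vectorize_nd_extract` enables vectorisation of `tensor.extract` Ops
  // inside the generic, which is where the new scalar broadcast case applies.
  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
    : (!transform.any_op) -> !transform.any_op
}
```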

Differential Revision: https://reviews.llvm.org/D149678

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 90ab75cbcc910..8b70f4255224f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -729,11 +729,7 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
   return offset;
 }
 
-enum VectorMemoryAccessKind {
-  // TODO: ScalarBroadcast,
-  Contiguous,
-  Gather
-};
+enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
 /// Checks whether /p val can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
@@ -872,36 +868,57 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (inputShape.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  bool isContiguous = true;
+  bool leadingIdxsLoopInvariant = true;
 
-  // 3a. Analyze the leading indices of `extractOp`.
+  // 3. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
   // for a contiguous load, i.e. whether it's loop invariant.
   auto indices = extractOp.getIndices();
-  auto leadIndices = ValueRange(indices.drop_back(1));
+  auto leadIndices = indices.drop_back(1);
 
   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+  }
+
+  if (!leadingIdxsLoopInvariant) {
+    LDBG("Found gather load: " << extractOp);
+    return VectorMemoryAccessKind::Gather;
   }
 
-  // 3b. Analyze the trailing index for `extractOp`.
+  // 4. Analyze the trailing index for `extractOp`.
+  // At this point we know that the leading indices are loop invariant. This
+  // means that this is potentially a scalar or a contiguous load. We can decide
+  // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
-  // For contiguous loads, the trailing `extractOp` index should increment with
-  // every loop iteration. This effectively means that it must be based on the
-  // trailing loop index. This is what the following bool captures.
+
+  // 4a. Scalar broadcast load
+  // If the trailing index is loop invariant then this is a scalar load.
+  if (leadingIdxsLoopInvariant &&
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+    LDBG("Found scalar broadcast load: " << extractOp);
+
+    return VectorMemoryAccessKind::ScalarBroadcast;
+  }
+
+  // 4b. Contiguous loads
+  // The trailing `extractOp` index should increment with every loop iteration.
+  // This effectively means that it must be based on the trailing loop index.
+  // This is what the following bool captures.
   bool foundIndexOp = false;
-  isContiguous &=
+  bool isContiguousLoad =
       isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
-  isContiguous &= foundIndexOp;
+  isContiguousLoad &= foundIndexOp;
 
-  if (isContiguous) {
+  if (isContiguousLoad) {
     LDBG("Found contigous load: " << extractOp);
     return VectorMemoryAccessKind::Contiguous;
   }
 
+  // 5. Fallback case - gather load.
+  LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
 
@@ -948,16 +965,14 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         maskConstantOp, passThruConstantOp);
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
-    LDBG("Vectorised as gather load: " << extractOp);
+    LDBG("Vectorised as gather load: " << extractOp << "\n");
     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
   }
 
-  // 2. Handle contiguous access.
-  LDBG("Vectorised as contiguous load: " << extractOp);
-  SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+  // 2. Handle:
+  //  a. scalar loads + broadcast,
+  //  b. contiguous loads.
+  // Both cases use vector.transfer_read.
 
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
@@ -972,6 +987,10 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for scalar indices - just re-use it,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     auto idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -988,10 +1007,24 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
   auto dstRank = resultType.getRank();
+  auto srcRank = extractOp.getTensor().getType().getRank();
   SmallVector<bool> inBounds(dstRank, true);
 
-  // Create a permutation map for transfer_read Op.
-  auto srcRank = extractOp.getTensor().getType().getRank();
+  // 2a. Handle scalar broadcast access.
+  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
+    MLIRContext *ctx = rewriter.getContext();
+    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
+    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
+
+    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+        loc, resultType, extractOp.getTensor(), transferReadIdxs,
+        permutationMap, inBounds);
+
+    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
+    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+  }
+
+  // 2b. Handle contiguous access.
   auto permutationMap = AffineMap::getMinorIdentityMap(
       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
 
@@ -1012,6 +1045,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
       loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
       inBounds);
+
+  LDBG("Vectorised as contiguous load: " << extractOp);
   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 

diff  --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 12f4d962532df..5047545f0b2cb 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -51,17 +51,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
   return %2 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx
-// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
-// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK:    %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
-// CHECK:    %[[C0:.*]] = arith.constant 0 : index
-// Magic "5" below comes from (1 * 3 + 2) (1: index into dim 1, 2: index into dim 2)
-// CHECK:    %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex>
-// CHECK:    %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK:    vector.transfer_write %[[GATHER]]
-// CHECK:  }
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_constant_idx(
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<3x3xf32>,
+// CHECK-SAME:      %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]]{{\[}}%[[C1]], %[[C2]]] : tensor<3x3xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_write %[[BCAST]], %[[ARG_1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK:           return %[[VAL_7]] : tensor<1x1x3xf32>
+// CHECK:         }
 
 transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):
@@ -316,44 +316,43 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
   return %1 : tensor<1x1x4xf32>
 }
 
-// First `tensor.extract` is a loop invariant scalar load. This way, the
-// following `tensor.extract` Op becomes a contiguous load (all other Ops used
-// for address calculation also satisfy the required conditions).
-// TODO: Don't use vector.gather for the first tensor.extract.
-
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_tensor_extract(
 // CHECK-SAME:    %[[VAL_0:.*]]: tensor<1x20xi32>,
 // CHECK-SAME:    %[[VAL_1:.*]]: tensor<257x24xf32>,
-// CHECK-SAME:     -> tensor<1x1x4xf32> {
+// CHECK-SAME:       %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
-// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_14:.*]] = tensor.empty() : tensor<1x1x4xf32>
-// CHECK:           %[[VAL_15:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_18:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
 // CHECK:           %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_21:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_23:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_10]]] {{\[}}%[[VAL_17]]], %[[VAL_8]], %[[VAL_9]] : tensor<1x20xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
-// CHECK:           %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : vector<1x1x4xi32> to vector<1x1x4xindex>
-// CHECK:           %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_26:.*]] = arith.minsi %[[VAL_25]], %[[VAL_11]] : vector<1x1x4xindex>
-// CHECK:           %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_29:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_28]], %[[VAL_30]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32>
-// CHECK:           %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
-// CHECK:           return %[[VAL_33]] : tensor<1x1x4xf32>
+// CHECK:           %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
+// First `tensor.extract` from the generic Op - loop invariant scalar load.
+// CHECK:           %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
+// CHECK:           %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
+// CHECK:           %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
+// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
+// for address calculation also satisfy the required conditions).
+// CHECK:           %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
+// CHECK:           %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
+// CHECK:           return %[[VAL_34]] : tensor<1x1x4xf32>
 // CHECK:         }
 
 transform.sequence failures(propagate) {


        

