[Mlir-commits] [mlir] a3ae393 - [mlir][linalg] Refine `tensor.extract` vectorisation
Andrzej Warzynski
llvmlistbot at llvm.org
Fri Apr 21 00:48:08 PDT 2023
Author: Andrzej Warzynski
Date: 2023-04-21T08:47:55+01:00
New Revision: a3ae3931d4e0c227a2d88fa827dbe536adda586a
URL: https://github.com/llvm/llvm-project/commit/a3ae3931d4e0c227a2d88fa827dbe536adda586a
DIFF: https://github.com/llvm/llvm-project/commit/a3ae3931d4e0c227a2d88fa827dbe536adda586a.diff
LOG: [mlir][linalg] Refine `tensor.extract` vectorisation
This patch updates the vectorisation of the extract Op so that the
permutation map for the transfer_read Op is defined explicitly by the
vectoriser (as opposed to being constructed implicitly by the
transfer_read builder).
This change is needed for cases where the rank of the source tensor is
lower than the rank of the output vector generated by the vectoriser. In
such cases the transfer_read builder cannot infer a valid permutation map
on its own (a minor identity map requires the source rank to be at least
the vector rank), e.g.:
```mlir
%17 = vector.transfer_read %arg1[%14, %16], %cst_4 {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x1x4xf32>
```
In cases like this, the vectoriser will create the following permutation map:
```
(d0, d1) -> (0, d0, d1)
```
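With the map built explicitly and handed to the transfer_read builder, the
read from the motivating example above comes out along these lines (a
sketch reusing the value names from that example; note that, per the new
code, `inBounds` has one entry per result vector dim):
```mlir
%17 = vector.transfer_read %arg1[%14, %16], %cst_4
    {in_bounds = [true, true, true],
     permutation_map = affine_map<(d0, d1) -> (0, d0, d1)>}
    : tensor<257x24xf32>, vector<1x1x4xf32>
```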
In other cases the behaviour remains unchanged.
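For reference, when the source and result ranks match, the explicitly
constructed map is simply the minor identity that the builder would have
inferred anyway. A sketch with illustrative shapes and names:
```mlir
// srcRank == dstRank == 2: the explicit map is the minor identity
// (d0, d1) -> (d0, d1), which is elided when the op is printed.
%read = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
    : tensor<80x16xf32>, vector<1x3xf32>
```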
Fixes https://github.com/openxla/iree/issues/13036, which is also where
the test case was extracted from.
Differential Revision: https://reviews.llvm.org/D148537
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1bc6c42545ed9..c7776f49b31b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -953,6 +953,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
}
// 2. Handle contiguous access.
+ LDBG("Vectorised as contiguous load: " << extractOp);
SmallVector<Value> transferReadIdxs;
auto resTrailingDim = resultType.getShape().back();
auto zero = rewriter.create<arith::ConstantOp>(
@@ -986,12 +987,31 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
}
// `tensor.extract` is always in-bounds, hence the following holds.
- SmallVector<bool> inBounds(resultType.getRank(), true);
+ auto dstRank = resultType.getRank();
+ SmallVector<bool> inBounds(dstRank, true);
+
+ // Create a permutation map for transfer_read Op.
+ auto srcRank = extractOp.getTensor().getType().getRank();
+ auto permutationMap = AffineMap::getMinorIdentityMap(
+ srcRank, std::min(dstRank, srcRank), rewriter.getContext());
+
+ int32_t rankDiff = dstRank - srcRank;
+ // When dstRank > srcRank, broadcast the source tensor to the unit leading
+ // dims so that the ranks match. This is done by extending the map with 0s.
+ // For example, for dstRank = 3, srcRank = 2, the following map created
+ // above:
+ // (d0, d1) --> (d0, d1)
+ // is extended as:
+ // (d0, d1) --> (0, d0, d1)
+ while (rankDiff > 0) {
+ permutationMap = permutationMap.insertResult(
+ mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
+ rankDiff--;
+ }
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
-
- LDBG("Vectorised as contiguous load: " << extractOp);
+ loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
+ inBounds);
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d54a2f57617cb..910c815a7902d 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1834,6 +1834,76 @@ transform.sequence failures(propagate) {
// -----
+func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20xi32>, %input_2: tensor<257x24xf32>, %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<1x1x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c256 = arith.constant 256 : index
+ %output = tensor.empty() : tensor<1x1x4xf32>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%output : tensor<1x1x4xf32>) {
+ ^bb0(%out: f32):
+ %13 = linalg.index 0 : index
+ %14 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%arg0, %13, %arg2)
+ %15 = linalg.index 2 : index
+ %16 = linalg.index 1 : index
+ %17 = affine.apply affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 24 + d2 + d3)>(%arg1, %16, %15, %arg3)
+ %extracted_0 = tensor.extract %input_1[%c0, %14] : tensor<1x20xi32>
+ %18 = arith.index_cast %extracted_0 : i32 to index
+ %19 = arith.maxsi %18, %c0 : index
+ %20 = arith.minsi %19, %c256 : index
+ %extracted_1 = tensor.extract %input_2[%20, %17] : tensor<257x24xf32>
+ linalg.yield %extracted_1 : f32
+ } -> tensor<1x1x4xf32>
+ return %1 : tensor<1x1x4xf32>
+}
+
+// The first `tensor.extract` is a loop-invariant scalar load. This makes
+// the second `tensor.extract` Op a contiguous load (all other Ops used for
+// address calculation also satisfy the required conditions).
+// TODO: Don't use vector.gather for the first tensor.extract.
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>,
+// CHECK-SAME: -> tensor<1x1x4xf32> {
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1x1x4xf32>
+// CHECK: %[[VAL_15:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_23:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_10]]] {{\[}}%[[VAL_17]]], %[[VAL_8]], %[[VAL_9]] : tensor<1x20xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
+// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : vector<1x1x4xi32> to vector<1x1x4xindex>
+// CHECK: %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_25]], %[[VAL_11]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_28]], %[[VAL_30]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32>
+// CHECK: %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
+// CHECK: return %[[VAL_33]] : tensor<1x1x4xf32>
+// CHECK: }
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+ %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
+// -----
+
func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
%c79 = arith.constant 79 : index
%1 = linalg.generic {
@@ -1918,7 +1988,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract }
- }
+}
// -----