[Mlir-commits] [mlir] a3ae393 - [mlir][linalg] Refine `tensor.extract` vectorisation

Andrzej Warzynski llvmlistbot at llvm.org
Fri Apr 21 00:48:08 PDT 2023


Author: Andrzej Warzynski
Date: 2023-04-21T08:47:55+01:00
New Revision: a3ae3931d4e0c227a2d88fa827dbe536adda586a

URL: https://github.com/llvm/llvm-project/commit/a3ae3931d4e0c227a2d88fa827dbe536adda586a
DIFF: https://github.com/llvm/llvm-project/commit/a3ae3931d4e0c227a2d88fa827dbe536adda586a.diff

LOG: [mlir][linalg] Refine `tensor.extract` vectorisation

This patch updates the vectorisation of the `tensor.extract` Op so that the
permutation map for the resulting transfer_read Op is defined explicitly by
the vectoriser (as opposed to being constructed implicitly by the
transfer_read builder).

This change is needed for cases where the rank of the source tensor is
lower than the rank of the output vector generated by the vectoriser:
```mlir
    %17 = vector.transfer_read %arg1[%14, %16], %cst_4 {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x1x4xf32>
```
In cases like this, the vectoriser will create the following permutation map:
```
  (d0, d1) -> (0, d0, d1)
```

In other cases the behaviour remains unchanged.

Fixes https://github.com/openxla/iree/issues/13036. That's also where
the test case was extracted from.

Differential Revision: https://reviews.llvm.org/D148537

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1bc6c42545ed9..c7776f49b31b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -953,6 +953,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   }
 
   // 2. Handle contiguous access.
+  LDBG("Vectorised as contiguous load: " << extractOp);
   SmallVector<Value> transferReadIdxs;
   auto resTrailingDim = resultType.getShape().back();
   auto zero = rewriter.create<arith::ConstantOp>(
@@ -986,12 +987,31 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
-  SmallVector<bool> inBounds(resultType.getRank(), true);
+  auto dstRank = resultType.getRank();
+  SmallVector<bool> inBounds(dstRank, true);
+
+  // Create a permutation map for transfer_read Op.
+  auto srcRank = extractOp.getTensor().getType().getRank();
+  auto permutationMap = AffineMap::getMinorIdentityMap(
+      srcRank, std::min(dstRank, srcRank), rewriter.getContext());
+
+  int32_t rankDiff = dstRank - srcRank;
+  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
+  // dims so that the ranks match. This is done by extending the map with 0s.
+  // For example, for dstRank = 3, srcRank = 2, the following map created
+  // above:
+  //    (d0, d1) --> (d0, d1)
+  // is extended as:
+  //    (d0, d1) --> (0, d0, d1)
+  while (rankDiff > 0) {
+    permutationMap = permutationMap.insertResult(
+        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
+    rankDiff--;
+  }
 
   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
-
-  LDBG("Vectorised as contiguous load: " << extractOp);
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
+      inBounds);
   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d54a2f57617cb..910c815a7902d 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1834,6 +1834,76 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20xi32>, %input_2: tensor<257x24xf32>, %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<1x1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c256 = arith.constant 256 : index
+  %output = tensor.empty() : tensor<1x1x4xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%output : tensor<1x1x4xf32>) {
+  ^bb0(%out: f32):
+    %13 = linalg.index 0 : index
+    %14 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%arg0, %13, %arg2)
+    %15 = linalg.index 2 : index
+    %16 = linalg.index 1 : index
+    %17 = affine.apply affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 24 + d2 + d3)>(%arg1, %16, %15, %arg3)
+    %extracted_0 = tensor.extract %input_1[%c0, %14] : tensor<1x20xi32>
+    %18 = arith.index_cast %extracted_0 : i32 to index
+    %19 = arith.maxsi %18, %c0 : index
+    %20 = arith.minsi %19, %c256 : index
+    %extracted_1 = tensor.extract %input_2[%20, %17] : tensor<257x24xf32>
+    linalg.yield %extracted_1 : f32
+  } -> tensor<1x1x4xf32>
+  return %1 : tensor<1x1x4xf32>
+}
+
+// First `tensor.extract` is a loop invariant scalar load. This way, the
+// following `tensor.extract` Op becomes a contiguous load (all other Ops used
+// for address calculation also satisfy the required conditions).
+// TODO: Don't use vector.gather for the first tensor.extract.
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_tensor_extract(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<1x20xi32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<257x24xf32>,
+// CHECK-SAME:     -> tensor<1x1x4xf32> {
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
+// CHECK-DAG:       %[[VAL_12:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_14:.*]] = tensor.empty() : tensor<1x1x4xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_18:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_23:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_10]]] {{\[}}%[[VAL_17]]], %[[VAL_8]], %[[VAL_9]] : tensor<1x20xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
+// CHECK:           %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : vector<1x1x4xi32> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_26:.*]] = arith.minsi %[[VAL_25]], %[[VAL_11]] : vector<1x1x4xindex>
+// CHECK:           %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_29:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_28]], %[[VAL_30]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32>
+// CHECK:           %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
+// CHECK:           return %[[VAL_33]] : tensor<1x1x4xf32>
+// CHECK:         }
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
+// -----
+
 func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
@@ -1918,7 +1988,7 @@ transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
    transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract }
- }
+}
 
 // -----
 


        


More information about the Mlir-commits mailing list