[Mlir-commits] [mlir] [mlir][linalg] Refine tensor.extract vectorization (PR #99299)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jul 17 03:08:48 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/99299

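Only classify a `tensor.extract` as a gather load based on a unit trailing source dimension (e.g. `tensor<30x1xi32>`) when the trailing dimension of the output vector is also 1. Reads from such "column" tensors into vectors with a non-unit trailing dimension (e.g. `vector<1x1x4xi32>`) are then analysed further and, as in the new `@vectorize_scalar_broadcast_column_tensor` test, can be vectorized as a scalar broadcast instead of a gather.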

>From 6966f4ecc7f879b6ca1c66c71e16d818ae26ec13 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 17 Jul 2024 11:08:12 +0100
Subject: [PATCH] [mlir][linalg] Refine tensor.extract vectorization

---
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 .../Linalg/vectorize-tensor-extract.mlir      | 57 +++++++++++++++++++
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 68ee915cca3f4..577b853658dc0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -949,7 +949,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // 2. Assume that it's a gather load when reading _from_ a tensor for which
+  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`, _into_ a vector
+  // for which the trailing dimension is also 1. Reading such a tensor into a
+  // vector with a non-unit trailing dim (e.g. `tensor<30x1xi32>` into
+  // `vector<1x1x4xi32>`) falls through to the checks below instead.
+  // TODO: Relax this condition.
+  // NOTE(review): if an earlier check already returns Gather whenever
+  // `targetShape.back() == 1`, this condition can never be true - confirm.
+  if (inputShape.getShape().back() == 1 && targetShape.back() == 1)
     return VectorMemoryAccessKind::Gather;

   bool leadingIdxsLoopInvariant = true;
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..49f3a053df460 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -595,3 +595,58 @@ module attributes {transform.with_named_sequence} {
      transform.yield
    }
 }
+
+// -----
+
+// Reading from a "column" tensor (trailing dim is 1) into a vector whose
+// trailing dim is 4. Both extraction indices are uniform across the vector
+// (`linalg.index 0` has extent 1 and the other index is a constant), so this
+// read should be vectorized as a scalar broadcast (vector.extractelement + a
+// broadcasting vector.transfer_read) rather than as a gather load.
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out_elem: i32):
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<30x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out : tensor<1x1x4xi32>
+}
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL:   func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME:                                                        %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]], %[[VAL_6]]], %[[VAL_7]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_9:.*]] = vector.step : vector<1xindex>
+// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:           %[[VAL_11:.*]] = vector.transpose %[[VAL_10]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:           %[[VAL_12:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK:           %[[VAL_13:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_16:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_17:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_15]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_18:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_19:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_17]], %[[VAL_1]]], %[[VAL_18]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<30x1xi32>, vector<1x1x4xi32>
+// CHECK:           %[[VAL_20:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_19]], %[[VAL_0]]{{\[}}%[[VAL_20]], %[[VAL_20]], %[[VAL_20]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+// CHECK:           return %[[VAL_21]] : tensor<1x1x4xi32>
+// CHECK:         }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4] { vectorize_nd_extract } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list