[Mlir-commits] [mlir] a63853e - [mlir] Broadcast scalars when vectorising tensor.extract

Andrzej Warzynski llvmlistbot at llvm.org
Thu Jan 12 08:34:45 PST 2023


Author: Andrzej Warzynski
Date: 2023-01-12T16:34:11Z
New Revision: a63853e6acc594d33b2fd96036e804d9e4a53f76

URL: https://github.com/llvm/llvm-project/commit/a63853e6acc594d33b2fd96036e804d9e4a53f76
DIFF: https://github.com/llvm/llvm-project/commit/a63853e6acc594d33b2fd96036e804d9e4a53f76.diff

LOG: [mlir] Broadcast scalars when vectorising tensor.extract

When vectorizing tensor.extract embedded within linalg.generic, the
default option is to rewrite it as vector.gather. When doing so, we need
to make sure that the corresponding indices are vectorized accordingly.
However, the Linalg vectorizer will not vectorize constant indices such as
%c0 and %c1 in the following example. This is fixed by broadcasting them.

```
  func.func @example(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
    %c0 = arith.constant 1 : index
    %c1 = arith.constant 2 : index
    %1 = linalg.generic {
      (...)
    } outs(...) {
    ^bb0(...):
      %2 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
      linalg.yield %2 : f32
    } -> tensor<1x1x3xf32>
    return %1 : tensor<1x1x3xf32>
  }
```

This patch makes sure that in the case above (and other similar cases),
the vectorizer broadcasts %c0 and %c1.

Differential Revision: https://reviews.llvm.org/D140781

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a17663b101912..58d746088ad91 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -626,23 +626,19 @@ calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
-    auto dimSizeBcast = b.create<vector::BroadcastOp>(
-        loc, indexVecType,
+    auto dimSize = broadcastIfNeeded(
+        b,
         b.create<arith::ConstantIndexOp>(
             loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)));
-    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
-
-    auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]);
-    if (i == numIndices - 1) {
-      // We only need an additional broadcast for the trailing index. All other
-      // indices have already been broadcast by `vectorizeLinalgIndex` to match
-      // the output size.
-      originalIndexBcast = b.create<vector::BroadcastOp>(
-          loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
-    }
+            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        indexVecType.getShape());
+
+    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+
+    auto extractOpIndex = broadcastIfNeeded(
+        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 9de78468f3e6e..2c7d34066a4bc 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1494,10 +1494,83 @@ transform.sequence failures(propagate) {
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %3 = linalg.index 2 : index
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK:    %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:    %[[C0:.*]] = arith.constant 0 : index
+// Magic "5" below comes from (1 * 3 + 2) (1: index into dim 0, 3: size of dim 1, 2: index into dim 1)
+// CHECK:    %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex>
+// CHECK:    %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK:    vector.transfer_write %[[GATHER]]
+// CHECK:  }
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = linalg.index 2 : index
+    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<1x1x3xf32>
+  return %1 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK:   %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
+// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK:   vector.transfer_write %[[GATHER]]
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1510,7 +1583,7 @@ func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_index_from_tensor
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
 // CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
 // CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>


        


More information about the Mlir-commits mailing list