[Mlir-commits] [mlir] f9070b2 - [mlir][vector] Enable CastAwayElementwiseLeadingOneDim for scalable vec

Andrzej Warzynski llvmlistbot at llvm.org
Tue Aug 22 04:42:21 PDT 2023


Author: Andrzej Warzynski
Date: 2023-08-22T11:40:46Z
New Revision: f9070b2dfbec1a337213a8c8901cb98cd0c09bef

URL: https://github.com/llvm/llvm-project/commit/f9070b2dfbec1a337213a8c8901cb98cd0c09bef
DIFF: https://github.com/llvm/llvm-project/commit/f9070b2dfbec1a337213a8c8901cb98cd0c09bef.diff

LOG: [mlir][vector] Enable CastAwayElementwiseLeadingOneDim for scalable vec

This patch effectively enables the CastAwayElementwiseLeadingOneDim
rewrite pattern for scalable vectors. To this end,
`ExtractOp::inferReturnTypes` is updated so that scalable dimensions are
correctly recognised.

The change to ExtractOp will likely also make other conversion patterns
valid for scalable vectors, but this patch focuses on just one case.
Other conversion patterns will be enabled in forthcoming patches.

Depends on D157993

Differential Revision: https://reviews.llvm.org/D158335

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1ad22cdf9788c1..fbf81cf2b79e70 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1151,7 +1151,8 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
     auto n =
         std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
-        vectorType.getShape().drop_front(n), vectorType.getElementType()));
+        vectorType.getShape().drop_front(n), vectorType.getElementType(),
+        vectorType.getScalableDims().drop_front(n)));
   }
   return success();
 }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index b1cd5c4c0f6f1e..913c826dd91247 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -417,6 +417,18 @@ struct CastAwayContractionLeadingOneDim
   }
 };
 
+/// Looks at elementwise operations on vectors with at least one leading
+/// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>),
+/// and casts away the leading one dimensions (_plural_) and then broadcasts
+/// the results.
+///
+/// Example before:
+///     %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
+/// Example after:
+///    %2 = arith.mulf %0, %1 : vector<4x1xf32>
+///    %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
+///
+/// Supports scalable vectors.
 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
 public:
   CastAwayElementwiseLeadingOneDim(MLIRContext *context,

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 0ee006e3df632e..59f31cf7246528 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -276,6 +276,30 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3
   return %0: vector<1x1x4xf32>
 }
 
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_scalable(
+// CHECK-SAME:    %[[S:.*]]: f32,
+// CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
+func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x[4]xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
+  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
+  return %0: vector<1x1x[4]xf32>
+}
+
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
+// CHECK-SAME:    %[[S:.*]]: f32,
+// CHECK-SAME:    %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
+func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<1x[1]x4xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
+// CHECK:           return %[[BCAST]] : vector<1x[1]x4xf32>
+  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
+  return %0: vector<1x[1]x4xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
 //  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
@@ -285,6 +309,16 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector
   return %0: vector<1x1x4xf32>
 }
 
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank1_scalable(
+// CHECK-SAME:    %[[S:.*]]: vector<[4]xf32>,
+// CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
+func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
+  %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
+  return %0: vector<1x1x[4]xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -295,6 +329,17 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
   return %0: vector<1x1x4xf32>
 }
 
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_scalable(
+// CHECK-SAME:    %[[S:.*]]: vector<1x[4]xf32>,
+// CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
+func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
+  %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
+  return %0: vector<1x1x[4]xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
 //       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -307,6 +352,19 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>,
   return %0: vector<1x2x1x4xf32>
 }
 
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
+// CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
+// CHECK-SAME:      %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
+// CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
+// CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<1x2x1x[4]xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
+// CHECK:           return %[[BCAST]] : vector<1x2x1x[4]xf32>
+func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
+  %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
+  return %0: vector<1x2x1x[4]xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
@@ -317,6 +375,17 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
   return %0: vector<8x1x4xf32>
 }
 
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
+// CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
+// CHECK-SAME:      %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
+// CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<1x[4]xf32>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
+// CHECK:           return %[[INSERT]] : vector<8x1x[4]xf32>
+func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
+  %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
+  return %0: vector<8x1x[4]xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
 //       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1>
@@ -328,3 +397,16 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v
   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
   return %0: vector<1x1x8x1x8xi1>
 }
+
+// CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
+// CHECK-SAME:      %[[S:.*]]: vector<1x[8]xi1>,
+// CHECK-SAME:      %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
+// CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<1x[8]xi1>
+// CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x[8]xi1>
+// CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
+// CHECK:           return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
+func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
+  %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
+  return %0: vector<1x1x8x1x[8]xi1>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 4d65405fb110e1..dfc564ca6fe483 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -34,6 +34,22 @@ func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) ->
   return %1 : vector<2x[4]x1xf32>
 }
 
+// CHECK-LABEL:   func.func @cast_away_leading_one_dim(
+// CHECK:           %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32>
+// CHECK:           vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32>
+func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
+  %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
+  return %1: vector<1x4x1xf32>
+}
+
+// CHECK-LABEL:   func.func @cast_away_leading_one_dim_scalable(
+// CHECK:           %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32>
+// CHECK:           vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
+func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
+  %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
+  return %1: vector<1x[4]x1xf32>
+}
+
 // CHECK-LABEL: func @add4x4
 //      CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>


        


More information about the Mlir-commits mailing list