[Mlir-commits] [mlir] c8557c7 - [MLIR][Vector] Enable masked vectorization of contraction ops
Diego Caballero
llvmlistbot at llvm.org
Fri Apr 21 12:26:39 PDT 2023
Author: Diego Caballero
Date: 2023-04-21T19:19:01Z
New Revision: c8557c7c3e98dddc0abb28913e06c8dae60a6c01
URL: https://github.com/llvm/llvm-project/commit/c8557c7c3e98dddc0abb28913e06c8dae60a6c01
DIFF: https://github.com/llvm/llvm-project/commit/c8557c7c3e98dddc0abb28913e06c8dae60a6c01.diff
LOG: [MLIR][Vector] Enable masked vectorization of contraction ops
This patch enables the vectorization of contraction ops using vector
masking. Support for vectorizing contractions is already there so this
is just adding contraction ops to the list of supported ops in
`vectorizeDynamicLinalgOpPrecondition` and adding a test.
Reviewed By: hanchung, awarzynski
Differential Revision: https://reviews.llvm.org/D148865
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c7776f49b31b5..55beb2f87cd78 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1303,7 +1303,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
// TODO: Masking only supports dynamic generic ops for now.
- if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp>(op))
+ if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp,
+ linalg::ContractionOpInterface>(op.getOperation()))
return failure();
LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 910c815a7902d..f2f6fedabd492 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2832,7 +2832,7 @@ transform.sequence failures(propagate) {
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
- %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
+ %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
-> tensor<2x4xf32>
{
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
@@ -2841,9 +2841,9 @@ func.func @test_masked_vectorize_pad(
// CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
- // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
- // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
- // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
+ // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+ // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
+ // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
// CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
// CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
@@ -2857,7 +2857,47 @@ func.func @test_masked_vectorize_pad(
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
: (!pdl.operation) -> !pdl.operation
transform.structured.masked_vectorize %0 vector_sizes [2, 4]
}
+
+// -----
+
+func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
+// CHECK: %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
+// CHECK: %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
+// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
+// CHECK: return
+// CHECK: }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4]
+}
More information about the Mlir-commits
mailing list