[Mlir-commits] [mlir] [mlir][Vector] Make `vector.contract` work with scalable vectors (PR #65724)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Sep 8 01:34:03 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/65724:

This is just a small fix that makes sure that `vector.contract` works with scalable vectors.

Rather than duplicating all the roundtrip tests for `vector.contract`, I'm treating scalable vectors as an edge case and just adding a couple of tests to verify that this works.

From 21542ceb35b42350ecd52a97719ab7d5f8ffff2b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 08:13:24 +0000
Subject: [PATCH] [mlir][Vector] Make `vector.contract` work with scalable
 vectors

This is just a small fix that makes sure that `vector.contract` works
with scalable vectors.

Rather than duplicating all the roundtrip tests for vector.contract, I'm
treating scalable vectors as an edge case and just adding a couple to
verify that this works.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp |  3 ++-
 mlir/test/Dialect/Vector/ops.mlir        | 29 ++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2aaf1cb7e5878e4..6473c92a91aa64b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape(
           return e.cast<AffineConstantExpr>().getValue();
         }));
     auto expected =
-        VectorType::get(expectedShape, resVectorType.getElementType());
+        VectorType::get(expectedShape, resVectorType.getElementType(),
+                        resVectorType.getScalableDims());
     if (resVectorType != expected || accVectorType != expected)
       return op.emitOpError(
                  "invalid accumulator/result vector shape, expected: ")
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2154304965a5d04..f00bc6e97b350ea 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -
   return %0 : f32
 }
 
+// CHECK-LABEL: @contraction_to_scalar_scalable
+func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 {
+  // CHECK:      %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+  %f0 = arith.constant 0.0: f32
+  // CHECK:      %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32
+  %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
+    : vector<[10]xf32>, vector<[10]xf32> into f32
+  // CHECK:      return %[[X]] : f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: @contraction_extra_attrs
 func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
   // CHECK:      %[[C0:.*]] = arith.constant 0.000000e+00 : f32
@@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
   return
 }
 
+#contraction_matmul_accesses = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d2, d1)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+#contraction_matmul_trait = {
+  indexing_maps = #contraction_matmul_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL: @contraction_matmul_scalable
+func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> {
+  // CHECK:   %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+  %res = vector.contract #contraction_matmul_trait %A, %B, %C
+    : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+  // CHECK:   return %[[X]] : vector<8x[32]xf32>
+  return %res : vector<8x[32]xf32>
+}
+
 // CHECK-LABEL: @create_vector_mask
 func.func @create_vector_mask() {
   // CHECK:      %[[C2:.*]] = arith.constant 2 : index



More information about the Mlir-commits mailing list