[Mlir-commits] [mlir] [mlir][Vector] Make `vector.contract` work with scalable vectors (PR #65724)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Sep 8 01:34:03 PDT 2023
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/65724:
This is just a small fix that makes sure that `vector.contract` works with scalable vectors.
Rather than duplicating all the roundtrip tests for vector.contract, I'm treating scalable vectors as an edge case and just adding a couple of tests to verify that this works.
From 21542ceb35b42350ecd52a97719ab7d5f8ffff2b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Sep 2023 08:13:24 +0000
Subject: [PATCH] [mlir][Vector] Make `vector.contract` work with scalable
vectors
This is just a small fix that makes sure that `vector.contract` works
with scalable vectors.
Rather than duplicating all the roundtrip tests for vector.contract, I'm
treating scalable vectors as an edge case and just adding a couple to
verify that this works.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 ++-
mlir/test/Dialect/Vector/ops.mlir | 29 ++++++++++++++++++++++++
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2aaf1cb7e5878e4..6473c92a91aa64b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape(
return e.cast<AffineConstantExpr>().getValue();
}));
auto expected =
- VectorType::get(expectedShape, resVectorType.getElementType());
+ VectorType::get(expectedShape, resVectorType.getElementType(),
+ resVectorType.getScalableDims());
if (resVectorType != expected || accVectorType != expected)
return op.emitOpError(
"invalid accumulator/result vector shape, expected: ")
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2154304965a5d04..f00bc6e97b350ea 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -
return %0 : f32
}
+// CHECK-LABEL: @contraction_to_scalar_scalable
+func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+ %f0 = arith.constant 0.0: f32
+ // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32
+ %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
+ : vector<[10]xf32>, vector<[10]xf32> into f32
+ // CHECK: return %[[X]] : f32
+ return %0 : f32
+}
+
// CHECK-LABEL: @contraction_extra_attrs
func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
@@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
return
}
+#contraction_matmul_accesses = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+#contraction_matmul_trait = {
+ indexing_maps = #contraction_matmul_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL: @contraction_matmul_scalable
+func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> {
+ // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+ %res = vector.contract #contraction_matmul_trait %A, %B, %C
+ : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+ // CHECK: return %[[X]] : vector<8x[32]xf32>
+ return %res : vector<8x[32]xf32>
+}
+
// CHECK-LABEL: @create_vector_mask
func.func @create_vector_mask() {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
More information about the Mlir-commits
mailing list