[Mlir-commits] [mlir] 7ec8fd4 - [mlir][Vector] Make `vector.contract` work with scalable vectors (#65724)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 01:14:29 PDT 2023


Author: Andrzej Warzyński
Date: 2023-09-11T09:14:25+01:00
New Revision: 7ec8fd4cc749545fd74959082b5f93fdef802f6e

URL: https://github.com/llvm/llvm-project/commit/7ec8fd4cc749545fd74959082b5f93fdef802f6e
DIFF: https://github.com/llvm/llvm-project/commit/7ec8fd4cc749545fd74959082b5f93fdef802f6e.diff

LOG: [mlir][Vector] Make `vector.contract` work with scalable vectors (#65724)

This is just a small fix that makes sure that `vector.contract` works
with scalable vectors.

Rather than duplicating all the roundtrip tests for vector.contract, I'm
treating scalable vectors as an edge case and just adding a couple to
verify that this works.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2aaf1cb7e5878e4..6473c92a91aa64b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape(
           return e.cast<AffineConstantExpr>().getValue();
         }));
     auto expected =
-        VectorType::get(expectedShape, resVectorType.getElementType());
+        VectorType::get(expectedShape, resVectorType.getElementType(),
+                        resVectorType.getScalableDims());
     if (resVectorType != expected || accVectorType != expected)
       return op.emitOpError(
                  "invalid accumulator/result vector shape, expected: ")

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2154304965a5d04..f00bc6e97b350ea 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -
   return %0 : f32
 }
 
+// CHECK-LABEL: @contraction_to_scalar_scalable
+func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 {
+  // CHECK:      %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+  %f0 = arith.constant 0.0: f32
+  // CHECK:      %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32
+  %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
+    : vector<[10]xf32>, vector<[10]xf32> into f32
+  // CHECK:      return %[[X]] : f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: @contraction_extra_attrs
 func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
   // CHECK:      %[[C0:.*]] = arith.constant 0.000000e+00 : f32
@@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
   return
 }
 
+#contraction_matmul_accesses = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d2, d1)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+#contraction_matmul_trait = {
+  indexing_maps = #contraction_matmul_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+// CHECK-LABEL: @contraction_matmul_scalable
+func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> {
+  // CHECK:   %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+  %res = vector.contract #contraction_matmul_trait %A, %B, %C
+    : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
+  // CHECK:   return %[[X]] : vector<8x[32]xf32>
+  return %res : vector<8x[32]xf32>
+}
+
 // CHECK-LABEL: @create_vector_mask
 func.func @create_vector_mask() {
   // CHECK:      %[[C2:.*]] = arith.constant 2 : index


        


More information about the Mlir-commits mailing list