[Mlir-commits] [mlir] [mlir][vector] Add tests for scalable vectors (PR #67806)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 29 07:08:16 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

<details>
<summary>Changes</summary>

Adds tests for scalable vectors in:
  * vector-contract-to-outerproduct-transforms.mlir
Every existing test is duplicated, with fixed-width vectors replaced by
scalable vectors. One test required a fix in
  * LowerVectorContract.cpp.

This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Fixes #<!-- -->67804


---

Patch is 38.84 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67806.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+1-1) 
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+452-43) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..6e63d52d22a1f6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
       return v;
     Type promotedType = dstElementType;
     if (vecType)
-      promotedType = VectorType::get(vecType.getShape(), promotedType);
+      promotedType = vecType.clone(promotedType);
     if (isa<FloatType>(dstElementType))
       return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..3746897bcd864f6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -34,16 +34,16 @@
 // CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
 // CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
 // CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
 // CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
                                     %arg1: vector<3xf32>,
@@ -54,22 +54,46 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
+
+// CHECK-LABEL:   func.func @masked_extract_contract2_scalable(
+// CHECK-SAME:      %{{.*}}: vector<[2]x[3]xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[3]xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x[3]xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x[3]xi1> to vector<[3]x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK:         %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK:         %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,
@@ -80,10 +104,36 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
   return %0 : vector<3x7xf32>
 }
 
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable(
+// CHECK-SAME:    %{{.*}}: vector<[3]x[5]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<[5]x[7]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<[3]x[7]xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+// CHECK:         %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<[3]x[7]x[5]xi1> to vector<[5]x[3]x[7]xi1>
+// CHECK:         %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+
+func.func @masked_extract_contract4_scalable(%arg0: vector<[3]x[5]xf32>,
+                                    %arg1: vector<[5]x[7]xf32>,
+                                    %arg2: vector<[3]x[7]xf32>,
+                                    %m : vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<[3]x[5]xf32>, vector<[5]x[7]xf32> into vector<[3]x[7]xf32> } : vector<[3]x[7]x[5]xi1> -> vector<[3]x[7]xf32>
+  return %0 : vector<[3]x[7]xf32>
+}
+
 // CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<4x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 // CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
 //
@@ -116,6 +166,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[4]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[4]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME:  : vector<[2]x[4]xf32> to vector<[4]x[2]xf32>
+//
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: return %[[c3]] : vector<[2]x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<[2]x[4]xf32>,
+                          %arg1: vector<[4]x[3]xf32>,
+                          %arg2: vector<[2]x[3]xf32>) -> vector<[2]x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<[2]x[4]xf32>, vector<[4]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 // CHECK-LABEL: func @matmul_0
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -133,6 +219,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
+-> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
 // CHECK-LABEL: func @matmul_0_mixed
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -152,6 +255,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf16>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf16>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf16> from vector<[1]x[2]xf16>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<[1]x[3]xf16>
+//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<[2]xf16> to vector<[2]xf32>
+//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<[2]x[1]xf16>, %arg1: vector<[1]x[3]xf16>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf16>, vector<[1]x[3]xf16> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_1 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -163,9 +285,9 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
 }
 
 // CHECK-LABEL: func @matmul_1
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
@@ -180,6 +302,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_2 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -191,9 +331,9 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_2
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
@@ -206,6 +346,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<[1]x[2]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_3 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -217,9 +373,9 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_3
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
@@ -233,6 +389,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-L...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67806


More information about the Mlir-commits mailing list