[Mlir-commits] [mlir] 1726b65 - [MLIR][Vector] Refactor tests for contract -> OP transforms (4/N) (#73807)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 1 00:45:18 PST 2023
Author: Andrzej Warzyński
Date: 2023-12-01T08:45:13Z
New Revision: 1726b65e4c273d55dd54838a742b03caff4abcdd
URL: https://github.com/llvm/llvm-project/commit/1726b65e4c273d55dd54838a742b03caff4abcdd
DIFF: https://github.com/llvm/llvm-project/commit/1726b65e4c273d55dd54838a742b03caff4abcdd.diff
LOG: [MLIR][Vector] Refactor tests for contract -> OP transforms (4/N) (#73807)
This patch refactors tests for:
vector.contract -> vector.outerproduct
for matvec operations (b += Ax). Summary of changes:
* add 2 missing cases (masked + scalable) when the operation kind is
`maxf`.
This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect.
Implements #72834.
Added:
Modified:
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index e84a43feaff39dc..8fed1f8fb341547 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -186,6 +186,52 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
return %0 : vector<2xf32>
}
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m_max(
+// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<2xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+func.func @masked_matvec_mk_k_m_max(%A: vector<2x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<2xf32>,
+ %m: vector<2x3xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
+// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+ %x: vector<3xf32>,
+ %b: vector<[2]xf32>,
+ %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
+ : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
// ============================================================================
// Matvec 2 (plain + masked + scalable)
// ============================================================================
More information about the Mlir-commits
mailing list