[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 20 03:05:18 PDT 2023
================
@@ -82,5 +88,69 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
return
}
+func.func @test_masked_outerproduct_no_accumulator_4x4xf32() {
+ %c0 = arith.constant 0 : index
+ %ones = arith.constant dense<1> : vector<[4]xi32>
+
+ %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+ %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+ %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+ %lhsDim = arith.constant 3 : index
+ %rhsDim = arith.constant 2 : index
+ %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+ %tile = vector.mask %mask {
+ vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+ // Print the tile. Due to masking the result will be the top 3x2xf32 section.
+ //
+ // WITH-MASK: TILE BEGIN
+ // WITH-MASK-NEXT: ( 1, 2, 0, 0
+ // WITH-MASK-NEXT: ( 2, 4, 0, 0
+ // WITH-MASK-NEXT: ( 3, 6, 0, 0
+ // WITH-MASK-NEXT: ( 0, 0, 0, 0
+ // WITH-MASK: TILE END
+ func.call @printTileBegin() : () -> ()
+ vector.print %tile : vector<[4]x[4]xf32>
+ func.call @printTileEnd() : () -> ()
+
+ return
+}
+
+func.func @test_masked_outerproduct_with_accumulator_4x4xf32() {
+ %c0 = arith.constant 0 : index
+ %ones = arith.constant dense<1> : vector<[4]xi32>
+ %f10 = arith.constant 10.0 : f32
+
+ %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+ %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+ %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+ %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+ %lhsDim = arith.constant 2 : index
+ %rhsDim = arith.constant 3 : index
+ %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+ %tile = vector.mask %mask {
+ vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+ // Print the tile. Due to masking the result will be the top 2x3xf32 section.
----------------
banach-space wrote:
Wow, brilliant!
https://github.com/llvm/llvm-project/pull/69604
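
For readers following the first test in the hunk above, here is a minimal Python sketch of what the masked outer product without an accumulator appears to compute, based only on the WITH-MASK check lines. It assumes exactly 4 lanes per vector (the real type is scalable, so a wider streaming vector length would give more lanes), and `masked_outer_product` is a hypothetical helper written for illustration, not part of MLIR or the PR.

```python
def masked_outer_product(lhs, rhs, active_rows, active_cols):
    """Outer product where only the top-left active_rows x active_cols
    lanes are computed; every masked-off lane is left as 0.0.

    Models the no-accumulator case only.
    """
    return [
        [
            lhs[i] * rhs[j] if i < active_rows and j < active_cols else 0.0
            for j in range(len(rhs))
        ]
        for i in range(len(lhs))
    ]


# stepvector (0, 1, 2, 3) plus ones, converted to float; assumes 4 lanes.
vector = [float(i + 1) for i in range(4)]

# Mirrors vector.create_mask %lhsDim, %rhsDim with lhsDim = 3, rhsDim = 2.
tile = masked_outer_product(vector, vector, active_rows=3, active_cols=2)

for row in tile:
    print(row)
```

Under the 4-lane assumption the printed rows line up with the check lines in the test: three active rows, two active columns, zeros elsewhere.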