[Mlir-commits] [mlir] [MLIR][Linalg] Scalable Vectorization of Reduction on the Trailing Dimension (PR #97788)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Jul 21 09:09:01 PDT 2024


================
@@ -189,3 +189,85 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @vectorize_dynamic_reduction_scalable_1d(%arg0: tensor<?xf32>,
+                                          %arg1: tensor<f32>) -> tensor<f32> {
+
+  %0 = linalg.reduce ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) dimensions = [0]
+  (%in: f32, %init: f32) {
+    %0 = arith.addf %in, %init : f32
+    linalg.yield %0 : f32
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:  func.func @vectorize_dynamic_reduction_scalable_1d(
+// CHECK-SAME:     %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:          %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK:          %[[VAL_1:.*]] = tensor.dim %[[ARG_0]], %[[VAL_0]] : tensor<?xf32>
+// CHECK:          %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:          %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:          %[[VAL_4:.*]] = vector.create_mask %[[VAL_1]] : vector<[4]xi1>
+// CHECK:          %[[VAL_5:.*]] = vector.mask %[[VAL_4]] { vector.transfer_read %[[ARG_0]][%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK:          %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:          %[[VAL_7:.*]] = vector.transfer_read %[[ARG_1]][], %[[VAL_6]] : tensor<f32>, vector<f32>
+// CHECK:          %[[VAL_8:.*]] = vector.extractelement %[[VAL_7]][] : vector<f32>
+// CHECK:          %[[VAL_9:.*]] = vector.mask %[[VAL_4]] { vector.multi_reduction <add>, %[[VAL_5]], %[[VAL_8]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : f32 to vector<f32>
+// CHECK:          %[[VAL_11:.*]] = vector.transfer_write %[[VAL_10]], %[[ARG_1]][] : vector<f32>, tensor<f32>
+// CHECK:          return %[[VAL_11]] : tensor<f32>
+// CHECK:        }
----------------
banach-space wrote:

These two `CHECK` lines (the `return` and the closing `}`) can be skipped; see:
* https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices
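
A sketch of how the end of the check block might look with the trailing `return` and closing-brace checks dropped (assuming those are the two lines meant; everything else is unchanged from the patch):

```mlir
// The checks now end at the last op of interest; the trailing
// `return` and `}` are not matched, since they add no coverage
// for the vectorization being tested.
// CHECK:          %[[VAL_9:.*]] = vector.mask %[[VAL_4]] { vector.multi_reduction <add>, %[[VAL_5]], %[[VAL_8]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
// CHECK:          %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : f32 to vector<f32>
// CHECK:          %[[VAL_11:.*]] = vector.transfer_write %[[VAL_10]], %[[ARG_1]][] : vector<f32>, tensor<f32>
```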

https://github.com/llvm/llvm-project/pull/97788

