[Mlir-commits] [mlir] [MLIR][Vector] Add Lowering for vector.step (PR #113655)

Manupa Karunaratne llvmlistbot at llvm.org
Thu Oct 31 09:30:10 PDT 2024


================
@@ -144,43 +144,40 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?xf32>) -> tensor<1x1x?xf32> {
+#map = affine_map<(d0) -> (d0)>
+func.func @vectorize_linalg_index(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
   %0 = linalg.generic {
     indexing_maps = [#map],
-    iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg1 : tensor<1x1x?xf32>) {
+    iterator_types = ["parallel"]
+  } outs(%arg1 : tensor<?xf32>) {
   ^bb0(%in: f32):
     %1 = linalg.index 0 : index
-    %2 = linalg.index 1 : index
-    %3 = linalg.index 2 : index
-    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x?xf32>
-    linalg.yield %4 : f32
-  } -> tensor<1x1x?xf32>
-  return %0 : tensor<1x1x?xf32>
+    %2 = tensor.extract %arg0[%1] : tensor<?xf32>
+    linalg.yield %2 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
 }
 
 // CHECK-LABEL: @vectorize_linalg_index
-// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG:          %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:          %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:          %[[C2:.*]] = arith.constant 2 : index
-// CHECK:        %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK:        %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
-// CHECK:       %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK:            %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK:             %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
-// CHECK:           return %[[OUT]] : tensor<1x1x?xf32>
+// CHECK-SAME:   %[[SRC:.*]]: tensor<?xf32>, %[[DST:.*]]: tensor<?xf32>
+// CHECK-DAG:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[DST_DIM0:.*]] = tensor.dim %[[DST]], %[[C0]] : tensor<?xf32>
+// CHECK:        %[[MASK:.*]] = vector.create_mask %[[DST_DIM0]] : vector<[4]xi1>
+// CHECK-DAG:    %[[STEP:.+]] = vector.step : vector<[4]xindex>
+// CHECK-DAG:    %[[STEP_ELEMENT:.+]] = vector.extractelement %[[STEP]][%c0_i32 : i32] : vector<[4]xindex>
+
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[STEP_ELEMENT]]], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
+// CHECK: return %[[OUT]] : tensor<?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [1, 1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] {vectorize_nd_extract} : !transform.any_op
 
     %func = transform.structured.match ops{["func.func"]} in %arg1
       : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
-      transform.apply_patterns.canonicalization
       transform.apply_patterns.linalg.tiling_canonicalization
     } : !transform.any_op
----------------
manupak wrote:

(Prior to this PR, `canonicalization` already folds the `vector.step`, so I don't see why we have to remove that pattern here and now.)

https://github.com/llvm/llvm-project/pull/113655


More information about the Mlir-commits mailing list