[Mlir-commits] [mlir] [mlir][vector] Add patterns for vector masked load/store (PR #74834)

Jakub Kuderski llvmlistbot at llvm.org
Wed Dec 13 08:37:49 PST 2023


================
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  } {
+// CHECK-LABEL:  @vector_maskedload
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+//       CHECK:    %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+//       CHECK:    %[[C0_i32:.*]] = arith.constant 0 : i32
+//       CHECK:    %[[C1_i32:.*]] = arith.constant 1 : i32
+//       CHECK:    %[[C4_i32:.*]] = arith.constant 4 : i32
+//       CHECK:    %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:    %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:    %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:    %[[S1:.*]] = scf.for %[[ARG1:.*]] = %[[C0_i32]] to %[[C4_i32]] step %[[C1_i32]] iter_args(%[[ARG2:.*]] = %[[CST]]) -> (vector<4xf32>)  : i32 {
+//       CHECK:      %[[S2:.*]] = vector.extractelement %[[S0]][%[[ARG1]] : i32] : vector<4xi1>
+//       CHECK:      %[[S3:.*]] = scf.if %[[S2]] -> (vector<4xf32>) {
+//       CHECK:        %[[S4:.*]] = arith.index_cast %[[ARG1]] : i32 to index
+//       CHECK:        %[[S5:.*]] = arith.addi %[[S4]], %[[C4]] : index
+//       CHECK:        %[[S6:.*]] = memref.load %[[ARG0]][%[[C0]], %[[S5]]] : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>
+//       CHECK:        %[[S7:.*]] = vector.insertelement %[[S6]], %[[ARG2]][%[[S4]] : index] : vector<4xf32>
+//       CHECK:        scf.yield %[[S7]] : vector<4xf32>
+//       CHECK:      } else {
+//       CHECK:        scf.yield %[[ARG2]] : vector<4xf32>
+//       CHECK:      }
+//       CHECK:      scf.yield %[[S3]] : vector<4xf32>
+//       CHECK:    }
+//       CHECK:    return %[[S1]] : vector<4xf32>
+//       CHECK:  }
+func.func @vector_maskedload(%arg0 : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK-LABEL:  @vector_maskedstore
+//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32, #spirv.storage_class<StorageBuffer>>, %[[ARG1:.*]]: vector<4xf32>) {
+//       CHECK:  %[[C0_I32:.*]] = arith.constant 0 : i32
+//       CHECK:  %[[C1_I32:.*]] = arith.constant 1 : i32
+//       CHECK:  %[[C4_I32:.*]] = arith.constant 4 : i32
+//       CHECK:  %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:  %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:  %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
+//       CHECK:  scf.for %[[ARG2:.*]] = %[[C0_I32]] to %[[C4_I32]] step %[[C1_I32]]  : i32 {
----------------
kuhar wrote:

For example, here we can statically tell that the loop has 4 iterations, and the mask type is `vector<4xi1>`.

https://github.com/llvm/llvm-project/pull/74834


More information about the Mlir-commits mailing list