[Mlir-commits] [mlir] e8dcf5f - [mlir] [VectorOps] Add expand/compress operations to Vector dialect

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 4 12:00:56 PDT 2020


Author: aartbik
Date: 2020-08-04T12:00:42-07:00
New Revision: e8dcf5f87dc20b3f08005ac767ff934e36bf2a5b

URL: https://github.com/llvm/llvm-project/commit/e8dcf5f87dc20b3f08005ac767ff934e36bf2a5b
DIFF: https://github.com/llvm/llvm-project/commit/e8dcf5f87dc20b3f08005ac767ff934e36bf2a5b.diff

LOG: [mlir] [VectorOps] Add expand/compress operations to Vector dialect

Introduces the expand and compress operations to the Vector dialect
(important memory operations for sparse computations), together
with a first reference implementation that lowers to the LLVM IR
dialect to enable running on CPU (and other targets that support
the corresponding LLVM IR intrinsics).

Reviewed By: reidtatge

Differential Revision: https://reviews.llvm.org/D84888

Added: 
    mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Target/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 4b1a6efe002f..768d8db121df 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1042,6 +1042,16 @@ def LLVM_masked_scatter
     "type($value) `,` type($mask) `into` type($ptrs)";
 }
 
+/// Create a call to Masked Expand Load intrinsic (ptr, mask, passthru).
+def LLVM_masked_expandload
+    : LLVM_IntrOp<"masked.expandload", [0], [], [], 1>,
+      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+/// Create a call to Masked Compress Store intrinsic (value, ptr, mask).
+def LLVM_masked_compressstore
+    : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0>,
+      Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
 //
 // Atomic operations.
 //

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index b49cc4a62a50..89a2b1226e1e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1158,7 +1158,7 @@ def Vector_GatherOp :
                Variadic<VectorOfRank<[1]>>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
 
-  let summary = "gathers elements from memory into a vector as defined by an index vector";
+  let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
 
   let description = [{
     The gather operation gathers elements from memory into a 1-D vector as
@@ -1186,7 +1186,6 @@ def Vector_GatherOp :
     %g = vector.gather %base, %indices, %mask, %pass_thru
         : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
     ```
-
   }];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1217,7 +1216,7 @@ def Vector_ScatterOp :
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$value)> {
 
-  let summary = "scatters elements from a vector into memory as defined by an index vector";
+  let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
 
   let description = [{
     The scatter operation scatters elements from a 1-D vector into memory as
@@ -1265,6 +1264,108 @@ def Vector_ScatterOp :
     "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
 }
 
+def Vector_ExpandLoadOp :
+  Vector_Op<"expandload">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
+
+  let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
+
+  let description = [{
+    The expand load reads elements from memory into a 1-D vector as defined
+    by a base and a 1-D mask vector. When the mask is set, the next element
+    is read from memory. Otherwise, the corresponding element is taken from
+    a 1-D pass-through vector. Informally, the semantics are:
+    ```
+    index = base
+    result[0] := mask[0] ? MEM[index++] : pass_thru[0]
+    result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+    etc.
+    ```
+    Note that the index is only incremented for lanes whose mask bit is set.
+
+    The expand load can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for an expand. The semantics of the operation closely
+    correspond to those of the `llvm.masked.expandload`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
+
+    Example:
+
+    ```mlir
+    %0 = vector.expandload %base, %mask, %pass_thru
+       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getPassThruVectorType() {
+      return pass_thru().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+    "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+}
+
+def Vector_CompressStoreOp :
+  Vector_Op<"compressstore">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$value)> {
+
+  let summary = "writes elements selectively from a vector as defined by a mask";
+
+  let description = [{
+    The compress store operation writes elements from a 1-D vector into memory
+    as defined by a base and a 1-D mask vector. When the mask is set, the
+    corresponding vector element is written to the next memory location.
+    Otherwise, the element is skipped. Informally, the semantics are:
+    ```
+    index = base
+    if (mask[0]) MEM[index++] = value[0]
+    if (mask[1]) MEM[index++] = value[1]
+    etc.
+    ```
+    Note that the index is only incremented for lanes whose mask bit is set.
+
+    The compress store can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a compress. The semantics of the operation closely
+    correspond to those of the `llvm.masked.compressstore`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
+
+    Example:
+
+    ```mlir
+    vector.compressstore %base, %mask, %value
+      : memref<?xf32>, vector<8xi1>, vector<8xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getValueVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+    "type($base) `,` type($mask) `,` type($value)";
+}
+
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [NoSideEffect]>,
     Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
new file mode 100644
index 000000000000..6310d6ee8790
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @compress16(%base: memref<?xf32>,
+                 %mask: vector<16xi1>, %value: vector<16xf32>) {
+  vector.compressstore %base, %mask, %value
+    : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+func @printmem16(%A: memref<?xf32>) {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<16xf32>
+  %mem = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%m_iter = %m) -> (vector<16xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %m_new : vector<16xf32>
+  }
+  vector.print %mem : vector<16xf32>
+  return
+}
+
+func @entry() {
+  // Set up memory: A is zero-filled and %value[i] = i.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  %z = constant 0.0: f32
+  %v = vector.broadcast %z : f32 to vector<16xf32>
+  %value = scf.for %i = %c0 to %c16 step %c1
+    iter_args(%v_iter = %v) -> (vector<16xf32>) {
+    store %z, %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+    scf.yield %v_new : vector<16xf32>
+  }
+
+  // Set up masks: some1 {0-3}, some2 {1-3,7,11,13,15}, some3 {1,3,7,11,13,15}.
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some1 = vector.constant_mask [4] : vector<16xi1>
+  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+  //
+  // Compressing store tests.
+  //
+
+  call @compress16(%A, %none, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+  call @compress16(%A, %all, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some3, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 1, 3, 7, 11, 13, 15, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some2, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 1, 2, 3, 7, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  call @compress16(%A, %some1, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  return
+}

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
new file mode 100644
index 000000000000..74118fc1125b
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @expand16(%base: memref<?xf32>,
+               %mask: vector<16xi1>,
+               %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %e = vector.expandload %base, %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %e : vector<16xf32>
+}
+
+func @entry() {
+  // Set up memory: A[i] = i for i in [0, 16).
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c16 = constant 16: index
+  %A = alloc(%c16) : memref<?xf32>
+  scf.for %i = %c0 to %c16 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up pass-thru vector (all lanes -7); %v = 7.7 is used for %alt_pass.
+  %u = constant -7.0: f32
+  %v = constant 7.7: f32
+  %pass = vector.broadcast %u : f32 to vector<16xf32>
+
+  // Set up masks: some1 {0-3}, some2 {1-3,7,11,13,15}, some3 {1,3,7,11,13,15}.
+  %f = constant 0: i1
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<16xi1>
+  %all = vector.constant_mask [16] : vector<16xi1>
+  %some1 = vector.constant_mask [4] : vector<16xi1>
+  %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+  %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+  %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+  %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+  %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+  %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+  //
+  // Expanding load tests.
+  //
+
+  %e1 = call @expand16(%A, %none, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e1 : vector<16xf32>
+  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %e2 = call @expand16(%A, %all, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e2 : vector<16xf32>
+  // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+  %e3 = call @expand16(%A, %some1, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e3 : vector<16xf32>
+  // CHECK-NEXT: ( 0, 1, 2, 3, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %e4 = call @expand16(%A, %some2, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e4 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, 1, 2, -7, -7, -7, 3, -7, -7, -7, 4, -7, 5, -7, 6 )
+
+  %e5 = call @expand16(%A, %some3, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e5 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, -7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, -7, 5 )
+
+  %4 = vector.insert %v, %pass[1] : f32 into vector<16xf32>
+  %5 = vector.insert %v, %4[2] : f32 into vector<16xf32>
+  %alt_pass = vector.insert %v, %5[14] : f32 into vector<16xf32>
+  %e6 = call @expand16(%A, %some3, %alt_pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e6 : vector<16xf32>
+  // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
+
+  return
+}

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
index 6dd0cf169552..54171e744605 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -11,34 +11,20 @@ func @scatter8(%base: memref<?xf32>,
   return
 }
 
-func @printmem(%A: memref<?xf32>) {
-  %f = constant 0.0: f32
-  %0 = vector.broadcast %f : f32 to vector<8xf32>
-  %1 = constant 0: index
-  %2 = load %A[%1] : memref<?xf32>
-  %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
-  %4 = constant 1: index
-  %5 = load %A[%4] : memref<?xf32>
-  %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
-  %7 = constant 2: index
-  %8 = load %A[%7] : memref<?xf32>
-  %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
-  %10 = constant 3: index
-  %11 = load %A[%10] : memref<?xf32>
-  %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
-  %13 = constant 4: index
-  %14 = load %A[%13] : memref<?xf32>
-  %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
-  %16 = constant 5: index
-  %17 = load %A[%16] : memref<?xf32>
-  %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
-  %19 = constant 6: index
-  %20 = load %A[%19] : memref<?xf32>
-  %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
-  %22 = constant 7: index
-  %23 = load %A[%22] : memref<?xf32>
-  %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
-  vector.print %24 : vector<8xf32>
+func @printmem8(%A: memref<?xf32>) {  // Prints %A[0..7] as one vector<8xf32>.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c8 = constant 8: index
+  %z = constant 0.0: f32
+  %m = vector.broadcast %z : f32 to vector<8xf32>
+  %mem = scf.for %i = %c0 to %c8 step %c1
+    iter_args(%m_iter = %m) -> (vector<8xf32>) {
+    %c = load %A[%i] : memref<?xf32>
+    %i32 = index_cast %i : index to i32
+    %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
+    scf.yield %m_new : vector<8xf32>
+  }
+  vector.print %mem : vector<8xf32>
   return
 }
 
@@ -104,31 +90,27 @@ func @entry() {
   vector.print %idx : vector<8xi32>
   // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
 
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
 
   call @scatter8(%A, %idx, %none, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
 
   call @scatter8(%A, %idx, %some, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
 
   call @scatter8(%A, %idx, %more, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
 
   call @scatter8(%A, %idx, %all, %val)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
-  call @printmem(%A) : (memref<?xf32>) -> ()
+  call @printmem8(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
 
   return

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3dbfaf88a443..23373f5c7edf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
   return success();
 }
 
-// Helper that returns vector of pointers given a base and an index vector.
-LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                             LLVMTypeConverter &typeConverter, Location loc,
-                             Value memref, Value indices, MemRefType memRefType,
-                             VectorType vType, Type iType, Value &ptrs) {
+// Helper that returns the base address of a memref.
+LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
+                      Value memref, MemRefType memRefType, Value &base) {
   // Inspect stride and offset structure.
   //
   // TODO: flat memory only for now, generalize
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
       offset != 0 || memRefType.getMemorySpace() != 0)
     return failure();
+  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
+  return success();
+}
+
+// Helper that returns an element pointer to a flat memref's aligned base.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value memref, MemRefType memRefType, Value &ptr) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base); // index-free GEP
+  return success();
+}
 
-  // Create a vector of pointers from base and indices.
-  MemRefDescriptor memRefDescriptor(memref);
-  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-  int64_t size = vType.getDimSize(0);
-  auto pType = memRefDescriptor.getElementType();
-  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+// Helper that returns a vector of element pointers (base + indices[i]) for a
+// flat memref; fails for memref layouts that getBase rejects.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
 }
@@ -305,9 +321,8 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
     VectorType vType = gather.getResultVectorType();
     Type iType = gather.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), gather.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              gather.getMemRefType(), vType, iType, ptrs)))
       return failure();
 
     // Replace with the gather intrinsic.
@@ -344,9 +359,8 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
     VectorType vType = scatter.getValueVectorType();
     Type iType = scatter.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), scatter.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              scatter.getMemRefType(), vType, iType, ptrs)))
       return failure();
 
     // Replace with the scatter intrinsic.
@@ -357,6 +371,60 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
   }
 };
 
+/// Lowers vector.expandload to llvm.intr.masked.expandload.
+class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorExpandLoadOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto expand = cast<vector::ExpandLoadOp>(op);
+    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+
+    Value ptr; // Element pointer to the memref's aligned base.
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
+                          ptr)))
+      return failure();
+
+    auto vType = expand.getResultVectorType();
+    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
+        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+        adaptor.pass_thru());
+    return success();
+  }
+};
+
+/// Lowers vector.compressstore to llvm.intr.masked.compressstore.
+class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorCompressStoreOpConversion(MLIRContext *context,
+                                           LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
+                             context, typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto compress = cast<vector::CompressStoreOp>(op);
+    auto adaptor = vector::CompressStoreOpAdaptor(operands);
+
+    Value ptr; // Element pointer to the memref's aligned base.
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
+                          compress.getMemRefType(), ptr)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
+        op, adaptor.value(), ptr, adaptor.mask());
+    return success();
+  }
+};
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
               VectorTransferConversion<TransferWriteOp>,
               VectorTypeCastOpConversion,
               VectorGatherOpConversion,
-              VectorScatterOpConversion>(ctx, converter);
+              VectorScatterOpConversion,
+              VectorExpandLoadOpConversion,
+              VectorCompressStoreOpConversion>(ctx, converter);
   // clang-format on
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index c788d4ccb4a0..9e64ff9af80a 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1898,6 +1898,41 @@ static LogicalResult verify(ScatterOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExpandLoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ExpandLoadOp op) {
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType passVType = op.getPassThruVectorType();
+  VectorType resVType = op.getResultVectorType();
+
+  if (resVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and result element type should match");
+
+  if (resVType.getDimSize(0) != maskVType.getDimSize(0)) // one bit per lane
+    return op.emitOpError("expected result dim to match mask dim");
+  if (resVType != passVType) // whole type: length and element type
+    return op.emitOpError("expected pass_thru of same type as result type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CompressStoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(CompressStoreOp op) {
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType valueVType = op.getValueVectorType();
+
+  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and value element type should match");
+
+  if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) // one bit per lane
+    return op.emitOpError("expected value dim to match mask dim");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2e5aae886c38..be70c08bc948 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -989,3 +989,23 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
 // CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<float>, !llvm.vec<3 x i32>) -> !llvm.vec<3 x ptr<float>>
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<3 x float>, !llvm.vec<3 x i1> into !llvm.vec<3 x ptr<float>>
 // CHECK: llvm.return
+
+func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> { // lowered via index-free GEP of the base + intrinsic
+  %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
+  return %0 : vector<11xf32>
+}
+
+// CHECK-LABEL: func @expand_load_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
+// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<float>, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float>
+// CHECK: llvm.return %[[E]] : !llvm.vec<11 x float>
+
+func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) { // lowered via index-free GEP of the base + intrinsic
+  vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
+  return
+}
+
+// CHECK-LABEL: func @compress_store_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x float>, !llvm.ptr<float>, !llvm.vec<11 x i1>) -> ()
+// CHECK: llvm.return

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea354f51645a..651fe27cd36c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1240,3 +1240,38 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
   // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
   vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
 }
+
+// -----
+
+func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+  // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
+  // expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
+  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
+}
+
+// -----
+
+func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
+  vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
+  vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 0bf4ed8f84c7..d4d1abe8e646 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -379,3 +379,12 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
   vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
   return
 }
+
+// CHECK-LABEL: @expand_and_compress
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) { // expandload result feeds compressstore
+  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}

diff  --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index fc286599ee95..6bf9b9768dd3 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -237,8 +237,8 @@ llvm.func @matrix_intrinsics(%A: !llvm.vec<64 x float>, %B: !llvm.vec<48 x float
   llvm.return
 }
 
-// CHECK-LABEL: @masked_intrinsics
-llvm.func @masked_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
+// CHECK-LABEL: @masked_load_store_intrinsics
+llvm.func @masked_load_store_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
   // CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
   %a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} :
     (!llvm.ptr<vec<7 x float>>, !llvm.vec<7 x i1>) -> !llvm.vec<7 x float>
@@ -265,6 +265,17 @@ llvm.func @masked_gather_scatter_intrinsics(%M: !llvm.vec<7 x ptr<float>>, %mask
   llvm.return
 }
 
+// CHECK-LABEL: @masked_expand_compress_intrinsics
+llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr<float>, %mask: !llvm.vec<7 x i1>, %passthru: !llvm.vec<7 x float>) {
+  // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(float* %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)  // operands: ptr, mask, passthru
+    : (!llvm.ptr<float>, !llvm.vec<7 x i1>, !llvm.vec<7 x float>) -> (!llvm.vec<7 x float>)
+  // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, float* %{{.*}}, <7 x i1> %{{.*}})
+  "llvm.intr.masked.compressstore"(%0, %ptr, %mask)  // operands: value, ptr, mask
+    : (!llvm.vec<7 x float>, !llvm.ptr<float>, !llvm.vec<7 x i1>) -> ()
+  llvm.return
+}
+
 // CHECK-LABEL: @memcpy_test
 llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>, %arg3: !llvm.ptr<i8>) {
   // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})


        


More information about the Mlir-commits mailing list