[Mlir-commits] [mlir] 6b66f21 - [mlir] [VectorOps] Canonicalization of 1-D memory operations

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 13 17:16:39 PDT 2020


Author: aartbik
Date: 2020-08-13T17:15:35-07:00
New Revision: 6b66f21446b982ec698830d0ea8469cee0b208ac

URL: https://github.com/llvm/llvm-project/commit/6b66f21446b982ec698830d0ea8469cee0b208ac
DIFF: https://github.com/llvm/llvm-project/commit/6b66f21446b982ec698830d0ea8469cee0b208ac.diff

LOG: [mlir] [VectorOps] Canonicalization of 1-D memory operations

Masked loading/storing in various forms can be optimized
into simpler memory operations when the mask is all true
or all false. Note that the backend does similar optimizations,
but doing this early may expose more opportunities for further
optimizations. This further prepares for progressively lowering
transfer reads and writes into 1-D memory operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D85769

Added: 
    mlir/test/Dialect/Vector/vector-mem-transforms.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 4f98fd97df48..3dc01b3c0914 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1198,6 +1198,7 @@ def Vector_MaskedLoadOp :
   }];
   let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_MaskedStoreOp :
@@ -1244,6 +1245,7 @@ def Vector_MaskedStoreOp :
   }];
   let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
     "type($mask) `,` type($value) `into` type($base)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_GatherOp :
@@ -1303,6 +1305,7 @@ def Vector_GatherOp :
     }
   }];
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ScatterOp :
@@ -1358,6 +1361,7 @@ def Vector_ScatterOp :
   }];
   let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
     "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ExpandLoadOp :
@@ -1411,6 +1415,7 @@ def Vector_ExpandLoadOp :
   }];
   let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_CompressStoreOp :
@@ -1460,6 +1465,7 @@ def Vector_CompressStoreOp :
   }];
   let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
     "type($base) `,` type($mask) `,` type($value)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ShapeCastOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1c0a5ceb8d86..d69fe96f6a2a 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -30,6 +30,66 @@
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Statically known contents of a 1-D mask: every lane set, every lane
+/// cleared, or neither (mixed lanes, or contents that cannot be inspected).
+enum class MaskFormat { AllTrue = 0, AllFalse = 1, Unknown = 2 };
+
+/// Classifies the statically known contents of a 1-D mask value. The method
+/// looks "under the hood" of a constant with a dense integer attribute and of
+/// a vector.constant_mask, since clients may call it at various stages of
+/// progressive lowering. Mixed masks, or masks produced by any other
+/// operation, are reported as MaskFormat::Unknown.
+static MaskFormat get1DMaskFormat(Value mask) {
+  if (auto constOp = mask.getDefiningOp<ConstantOp>()) {
+    auto dense = constOp.value().dyn_cast<DenseIntElementsAttr>();
+    if (!dense)
+      return MaskFormat::Unknown;
+    // Remember which bit values occur; seeing both means a mixed mask, which
+    // is reported as soon as it is detected.
+    bool sawSet = false;
+    bool sawClear = false;
+    for (const llvm::APInt &bit : dense) {
+      if (bit.getBoolValue())
+        sawSet = true;
+      else
+        sawClear = true;
+      if (sawSet && sawClear)
+        return MaskFormat::Unknown;
+    }
+    if (sawSet)
+      return MaskFormat::AllTrue;
+    if (sawClear)
+      return MaskFormat::AllFalse;
+    return MaskFormat::Unknown;
+  }
+  if (auto maskOp = mask.getDefiningOp<ConstantMaskOp>()) {
+    // The single mask index sets every lane when it reaches the dimension
+    // size, and no lane when it is zero or negative.
+    ArrayAttr dimSizes = maskOp.mask_dim_sizes();
+    assert(dimSizes.size() == 1);
+    int64_t index = dimSizes[0].cast<IntegerAttr>().getInt();
+    int64_t size = maskOp.getType().cast<VectorType>().getDimSize(0);
+    if (index >= size)
+      return MaskFormat::AllTrue;
+    if (index <= 0)
+      return MaskFormat::AllFalse;
+  }
+  return MaskFormat::Unknown;
+}
+
+/// Helper method to cast a 1-D memref<10xf32> "base" into a
+/// memref<vector<10xf32>> in the output parameter "newBase",
+/// using the 'element' vector type "vt". Returns true on success; returns
+/// false, without creating any IR, when the cast cannot be expressed.
+///
+/// The cast reinterprets the whole buffer as one vector
+/// (memref<16xf32> -> memref<vector<16xf32>>). It therefore only applies
+/// when the memref:
+///   - is statically shaped,
+///   - has no layout map, and
+///   - holds exactly the elements of "vt".
+/// The result stays in the memory space of "base".
+static bool castedToMemRef(Location loc, Value base, MemRefType mt,
+                           VectorType vt, PatternRewriter &rewriter,
+                           Value &newBase) {
+  // The vector.type_cast operation does not accept unknown memref<?xf32>.
+  // TODO: generalize the cast and accept this case too
+  if (!mt.hasStaticShape())
+    return false;
+  // A strided/offset layout means the accessed elements are not a plain
+  // contiguous vector; be conservative and keep the masked operation.
+  if (!mt.getAffineMaps().empty())
+    return false;
+  // E.g. memref<32xf32> accessed with vector<16xf32> would otherwise produce
+  // a cast whose shape disagrees with the source buffer.
+  if (mt.getShape() != vt.getShape())
+    return false;
+  // MemRefType::get defaults to memory space 0; preserve the source's space.
+  newBase = rewriter.create<TypeCastOp>(
+      loc, MemRefType::get({}, vt, {}, mt.getMemorySpace()), base);
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // VectorDialect
 //===----------------------------------------------------------------------===//
@@ -1869,6 +1929,35 @@ static LogicalResult verify(MaskedLoadOp op) {
   return success();
 }
 
+namespace {
+/// Folds a vector.maskedload whose mask is statically known:
+///   - An all-true mask becomes a plain load through a vector.type_cast of
+///     the base, when that cast can be formed.
+///   - An all-false mask is replaced by the pass-through value.
+class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
+public:
+  using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(MaskedLoadOp load,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(load.mask())) {
+    case MaskFormat::AllTrue:
+      if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
+                          load.getResultVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      // No lane is loaded, so the result is exactly the pass-through vector.
+      rewriter.replaceOp(load, load.pass_thru());
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+    // Every enumerator returns above; without this, control would reach the
+    // end of a non-void function (-Wreturn-type).
+    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
+  }
+};
+} // namespace
+
+void MaskedLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MaskedLoadFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -1885,6 +1974,35 @@ static LogicalResult verify(MaskedStoreOp op) {
   return success();
 }
 
+namespace {
+/// Folds a vector.maskedstore whose mask is statically known:
+///   - An all-true mask becomes a plain store through a vector.type_cast of
+///     the base, when that cast can be formed.
+///   - An all-false mask erases the store.
+class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
+public:
+  using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(MaskedStoreOp store,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(store.mask())) {
+    case MaskFormat::AllTrue:
+      if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
+                          store.getValueVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      // No lane is written, so the store has no effect.
+      rewriter.eraseOp(store);
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+    // Every enumerator returns above; without this, control would reach the
+    // end of a non-void function (-Wreturn-type).
+    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
+  }
+};
+} // namespace
+
+void MaskedStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MaskedStoreFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//
@@ -1909,6 +2027,30 @@ static LogicalResult verify(GatherOp op) {
   return success();
 }
 
+namespace {
+/// Folds a vector.gather with an all-false mask into its pass-through value.
+/// An all-true gather is left alone because it has no unmasked equivalent.
+class GatherFolder final : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern<GatherOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp gather,
+                                PatternRewriter &rewriter) const override {
+    switch (get1DMaskFormat(gather.mask())) {
+    case MaskFormat::AllTrue:
+      return failure(); // no unmasked equivalent
+    case MaskFormat::AllFalse:
+      rewriter.replaceOp(gather, gather.pass_thru());
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+    // Every enumerator returns above; without this, control would reach the
+    // end of a non-void function (-Wreturn-type).
+    llvm_unreachable("Unexpected 1DMaskFormat on Gather");
+  }
+};
+} // namespace
+
+void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<GatherFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -1928,6 +2070,30 @@ static LogicalResult verify(ScatterOp op) {
   return success();
 }
 
+namespace {
+/// Erases a vector.scatter with an all-false mask, since it writes nothing.
+/// An all-true scatter is left alone because it has no unmasked equivalent.
+class ScatterFolder final : public OpRewritePattern<ScatterOp> {
+public:
+  using OpRewritePattern<ScatterOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ScatterOp scatter,
+                                PatternRewriter &rewriter) const override {
+    switch (get1DMaskFormat(scatter.mask())) {
+    case MaskFormat::AllTrue:
+      return failure(); // no unmasked equivalent
+    case MaskFormat::AllFalse:
+      rewriter.eraseOp(scatter);
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+    // Every enumerator returns above; without this, control would reach the
+    // end of a non-void function (-Wreturn-type).
+    llvm_unreachable("Unexpected 1DMaskFormat on Scatter");
+  }
+};
+} // namespace
+
+void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<ScatterFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandLoadOp
 //===----------------------------------------------------------------------===//
@@ -1947,6 +2113,36 @@ static LogicalResult verify(ExpandLoadOp op) {
   return success();
 }
 
+namespace {
+/// Folds a vector.expandload whose mask is statically known:
+///   - An all-true mask reads consecutive elements into every lane, i.e. a
+///     plain load through a vector.type_cast of the base, when that cast can
+///     be formed.
+///   - An all-false mask is replaced by the pass-through value.
+class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
+public:
+  using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpandLoadOp expand,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(expand.mask())) {
+    case MaskFormat::AllTrue:
+      if (!castedToMemRef(expand.getLoc(), expand.base(),
+                          expand.getMemRefType(), expand.getResultVectorType(),
+                          rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      rewriter.replaceOp(expand, expand.pass_thru());
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+    // Every enumerator returns above; without this, control would reach the
+    // end of a non-void function (-Wreturn-type).
+    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoad");
+  }
+};
+} // namespace
+
+void ExpandLoadOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExpandLoadFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CompressStoreOp
 //===----------------------------------------------------------------------===//
@@ -1963,6 +2159,36 @@ static LogicalResult verify(CompressStoreOp op) {
   return success();
 }
 
+namespace {
+/// Folds a vector.compressstore whose mask is statically known:
+///   - An all-true mask writes every lane consecutively, i.e. a plain store
+///     through a vector.type_cast of the base, when that cast can be formed.
+///   - An all-false mask erases the store.
+class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
+public:
+  using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CompressStoreOp compress,
+                                PatternRewriter &rewriter) const override {
+    Value newBase;
+    switch (get1DMaskFormat(compress.mask())) {
+    case MaskFormat::AllTrue:
+      if (!castedToMemRef(compress.getLoc(), compress.base(),
+                          compress.getMemRefType(),
+                          compress.getValueVectorType(), rewriter, newBase))
+        return failure();
+      rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
+      return success();
+    case MaskFormat::AllFalse:
+      rewriter.eraseOp(compress);
+      return success();
+    case MaskFormat::Unknown:
+      return failure();
+    }
+    // Every enumerator returns above; without this, control would reach the
+    // end of a non-void function (-Wreturn-type).
+    llvm_unreachable("Unexpected 1DMaskFormat on CompressStore");
+  }
+};
+} // namespace
+
+void CompressStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CompressStoreFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
@@ -2390,7 +2616,9 @@ void CreateMaskOp::getCanonicalizationPatterns(
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder,
+  // Besides the existing mask, strided-slice and transpose folders, register
+  // the folders that simplify 1-D masked memory operations whose mask is
+  // statically all true or all false.
+  patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
+                  GatherFolder, ScatterFolder, ExpandLoadFolder,
+                  CompressStoreFolder, StridedSliceConstantMaskFolder,
                   TransposeFolder>(context);
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
new file mode 100644
index 000000000000..7d79d8b1ed72
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -0,0 +1,177 @@
+// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+
+//
+// TODO: optimize this one too! The all-true mask is left unfolded here only
+// because the base is a dynamically shaped memref<?xf32>, which the
+// vector.type_cast used for the rewrite cannot handle yet.
+//
+// CHECK-LABEL: func @maskedload0(
+// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[M:.*]] = vector.constant_mask
+// CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-NEXT: return %[[T]] : vector<16xf32>
+
+func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-true mask on a statically shaped memref: expected to become a
+// vector.type_cast followed by an unmasked load.
+// CHECK-LABEL: func @maskedload1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
+// CHECK-NEXT: return %[[T1]] : vector<16xf32>
+
+func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-false mask: expected to fold to the pass-through argument.
+// CHECK-LABEL: func @maskedload2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: return %[[A1]] : vector<16xf32>
+
+func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.maskedload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-true mask on a statically shaped memref: expected to become a
+// vector.type_cast followed by an unmasked store.
+// CHECK-LABEL: func @maskedstore1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
+// CHECK-NEXT: return
+
+func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.maskedstore %base, %mask, %value
+    : vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// All-false mask: the store is expected to disappear entirely.
+// CHECK-LABEL: func @maskedstore2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: return
+
+func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  vector.maskedstore %base, %mask, %value
+    : vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// All-true mask: a gather has no unmasked equivalent, so it is expected to
+// stay as is.
+// CHECK-LABEL: func @gather1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// CHECK-NEXT: return %[[T1]] : vector<16xf32>
+
+func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.gather %base, %indices, %mask, %pass_thru
+    : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-false mask: expected to fold to the pass-through argument.
+// CHECK-LABEL: func @gather2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+// CHECK-NEXT: return %[[A2]] : vector<16xf32>
+
+func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.gather %base, %indices, %mask, %pass_thru
+    : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-true mask: a scatter has no unmasked equivalent, so it is expected to
+// stay as is.
+// CHECK-LABEL: func @scatter1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+// CHECK-NEXT: return
+
+func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.scatter %base, %indices, %mask, %value
+    : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// All-false mask: the scatter is expected to disappear entirely.
+// CHECK-LABEL: func @scatter2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
+// CHECK-NEXT: return
+
+func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  vector.scatter %base, %indices, %mask, %value
+    : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+  return
+}
+
+// All-true mask: an expanding load fills every lane from consecutive
+// elements, so it is expected to become a vector.type_cast plus a load.
+// CHECK-LABEL: func @expand1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
+// CHECK-NEXT: return %[[T1]] : vector<16xf32>
+
+func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.expandload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-false mask: expected to fold to the pass-through argument.
+// CHECK-LABEL: func @expand2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: return %[[A1]] : vector<16xf32>
+
+func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  %ld = vector.expandload %base, %mask, %pass_thru
+    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+// All-true mask: a compressing store writes every lane consecutively, so it
+// is expected to become a vector.type_cast plus a store.
+// CHECK-LABEL: func @compress1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
+// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
+// CHECK-NEXT: return
+
+func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.compressstore %base, %mask, %value  : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// All-false mask: the store is expected to disappear entirely.
+// CHECK-LABEL: func @compress2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
+// CHECK-NEXT: return
+
+func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %mask = vector.constant_mask [0] : vector<16xi1>
+  vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  return
+}


        


More information about the Mlir-commits mailing list