[Mlir-commits] [mlir] [mlir][vector] Restrict narrow-type-emulation patterns (PR #115612)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Nov 10 04:23:24 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/115612

>From 711a5483acee422fda12cf2b6f36fe8d4bb4de3e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 9 Nov 2024 16:22:57 +0000
Subject: [PATCH 1/2] [mlir][vector] Restrict narrow-type-emulation patterns
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

All patterns under `populateVectorNarrowTypeEmulationPatterns` assume 1-D
vector loads/stores (as opposed to n-D ones). This is evident from, e.g.,
`ConvertVectorTransferRead`:

```cpp
auto newRead = rewriter.create<vector::TransferReadOp>(
    loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
    getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
    newPadding);

auto bitCast = rewriter.create<vector::BitCastOp>(
    loc, VectorType::get(numElements * scale, oldElementType), newRead);
```
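
For context, here is a rough sketch of the 1-D case these patterns are written
for, assuming i8 is emulated via i32 (i.e. `memref-load-bitwidth=32`); the
function name, shapes and the exact rewritten IR below are illustrative only:

```mlir
// A rank-1 load of 4 x i8.
func.func @load_1d(%idx: index) -> vector<4xi8> {
    %0 = memref.alloc() : memref<12xi8>
    %1 = vector.load %0[%idx] : memref<12xi8>, vector<4xi8>
    return %1 : vector<4xi8>
}

// After emulation, roughly: the memref is rewritten in terms of i32, the load
// reads a rank-1 i32 vector, and a rank-1 bitcast recovers the i8 elements
// (the linearized-index computation is elided).
//
//   %alloc = memref.alloc() : memref<3xi32>
//   %load  = vector.load %alloc[...] : memref<3xi32>, vector<1xi32>
//   %cast  = vector.bitcast %load : vector<1xi32> to vector<4xi8>
```

With an n-D vector, the patterns still build rank-1 types, which clashes with
the rank of the original operands and leads to the verifier error shown below.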

Attempts to use these patterns in more generic cases fail, as shown
below:

```mlir
func.func @vector_maskedload_2d_i8_negative(
  %idx1: index,
  %idx2: index,
  %num_elems: index,
  %passthru: vector<2x4xi8>) -> vector<2x4xi8> {

    %0 = memref.alloc() : memref<3x4xi8>
    %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1>
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
    return %1 : vector<2x4xi8>

}
```

For example, emulating i8 via i32 (i.e. `memref-load-bitwidth=32`) produces:
```bash
error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
         ^
```

Rather than reworking these patterns (which would require considerably more
effort), I've restricted them to 1-D vectors and extended
"TestEmulateNarrowTypePass" with an option to disable the MemRef type
converter. This makes it possible to add negative tests (otherwise, the type
converter emits an error that we cannot easily test for). While not ideal,
this workaround should be acceptable for a test pass.
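
For reference, the new negative tests are run with the new option enabled,
matching the RUN line in the new test file (`%s` is the lit substitution for
the test file):

```bash
mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" \
  --split-input-file %s | FileCheck %s
```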
---
 .../Transforms/VectorEmulateNarrowType.cpp    |  20 ++++
 .../emulate-narrow-type-unsupported.mlir      | 112 ++++++++++++++++++
 .../Dialect/MemRef/TestEmulateNarrowType.cpp  |  11 +-
 3 files changed, 142 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..591194f43018ad 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -217,6 +217,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getValueToStore().getType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -283,6 +287,10 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getValueToStore().getType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -372,6 +380,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -473,6 +485,10 @@ struct ConvertVectorMaskedLoad final
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -624,6 +640,10 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..30ce13e8169c47
--- /dev/null
+++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,112 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
+
+// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D
+// instead of 1-D vectors. That's currently not supported.
+
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
+    return %1 : vector<2x4xi8>
+}
+
+// No support for loading 2D vectors - expect no conversions
+//  CHECK-LABEL: func @vector_load_2d_i8_negative
+//        CHECK:   memref.alloc() : memref<3x4xi8>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
+func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
+    %c0 = arith.constant 0 : i4
+    %0 = memref.alloc() : memref<3x8xi4>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
+      memref<3x8xi4>, vector<2x8xi4>
+    return %1 : vector<2x8xi4>
+}
+//  CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
+//        CHECK:  memref.alloc() : memref<3x8xi4>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
+    %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
+    return %1 : vector<2x4xi8>
+}
+
+//  CHECK-LABEL: func @vector_maskedload_2d_i8_negative
+//        CHECK:  memref.alloc() : memref<3x4xi8>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
+    %0 = memref.alloc() : memref<8x8x16xi4>
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+    %cst_2 = arith.constant dense<0> : vector<16xi4>
+    %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
+    %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+    %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+    %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+    %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+    return %63 : vector<8x8x16xi4>
+}
+
+//  CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
+//        CHECK: memref.alloc() : memref<8x8x16xi4>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xi8>
+    vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
+    return
+}
+
+//  CHECK-LABEL: func @vector_store_2d_i8_negative
+//        CHECK: memref.alloc() : memref<4x8xi8>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+  %0 = memref.alloc() : memref<3x8xi8>
+  %mask = vector.create_mask %arg2 : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  return
+}
+
+//  CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
+//        CHECK: memref.alloc() : memref<3x8xi8>
+//    CHECK-NOT: i32
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index eeb26d1876c1db..7401e470ed4f2c 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -78,7 +78,11 @@ struct TestEmulateNarrowTypePass
           IntegerType::get(ty.getContext(), arithComputeBitwidth));
     });
 
-    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+    // With the type converter enabled, we are effectively unable to write
+    // negative tests. This is a workaround specifically for negative tests.
+    if (!disableMemrefTypeConversion)
+      memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+
     ConversionTarget target(*ctx);
     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
@@ -109,6 +113,11 @@ struct TestEmulateNarrowTypePass
   Option<unsigned> arithComputeBitwidth{
       *this, "arith-compute-bitwidth",
       llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+
+  Option<bool> disableMemrefTypeConversion{
+      *this, "skip-memref-type-conversion",
+      llvm::cl::desc("disable memref type conversion (to test failures)"),
+      llvm::cl::init(false)};
 };
 } // namespace
 

>From dfb853798d3eae6c903ababd0956570832aa24b5 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 10 Nov 2024 11:38:25 +0000
Subject: [PATCH 2/2] fixup! [mlir][vector] Restrict narrow-type-emulation
 patterns

Update tests that were still loading/storing 1-D vectors
---
 .../Transforms/VectorEmulateNarrowType.cpp      |  5 +++++
 .../Vector/emulate-narrow-type-unsupported.mlir | 17 ++++++++---------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 591194f43018ad..c404bdd049cc7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -217,6 +217,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
     if (op.getValueToStore().getType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -287,6 +288,7 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
     if (op.getValueToStore().getType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -380,6 +382,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
     if (op.getVectorType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -485,6 +488,7 @@ struct ConvertVectorMaskedLoad final
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
     if (op.getVectorType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -640,6 +644,7 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
     if (op.getVectorType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
index 30ce13e8169c47..a5a6fc4acfe10f 100644
--- a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
+++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
@@ -65,12 +65,11 @@ func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x
     %c16 = arith.constant 16 : index
     %c8 = arith.constant 8 : index
     %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
-    %cst_2 = arith.constant dense<0> : vector<16xi4>
+    %cst_2 = arith.constant dense<0> : vector<8x16xi4>
     %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
     %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
-    %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
-    %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
-    %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+    %50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4>
+    %63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4>
     return %63 : vector<8x8x16xi4>
 }
 
@@ -84,9 +83,9 @@ func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
-func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg2: index) {
     %0 = memref.alloc() : memref<4x8xi8>
-    vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
+    vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<2x8xi8>
     return
 }
 
@@ -100,10 +99,10 @@ func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2:
 /// vector.maskedstore
 ///----------------------------------------------------------------------------------------
 
-func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
-  %mask = vector.create_mask %arg2 : vector<8xi1>
-  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  %mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8>
   return
 }
 


