[Mlir-commits] [mlir] [mlir][ArmSME] Lower extract from 2D scalable create_mask to psel (PR #96066)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jun 20 02:10:46 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/96066

>From 336757e607101a6a04682b9b3cff8e8511b88fc6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 18 Jun 2024 16:56:19 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Lower extract from 2D scalable create_mask
 to psel

Example:
```mlir
%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index]
                                  : vector<[8]xi1> from vector<[4]x[8]xi1>
```
Becomes:
```mlir
%mask_rows = vector.create_mask %a : vector<[4]xi1>
%mask_cols = vector.create_mask %b : vector<[8]xi1>
%slice = arm_sve.psel %mask_cols, %mask_rows[%index]
                                   : vector<[8]xi1>, vector<[4]xi1>
```

Note: Although psel is defined in the ArmSVE dialect, it requires SME (or
SVE 2.1), so VectorToArmSME is currently the most logical place for this
lowering.
---
 mlir/include/mlir/Conversion/Passes.td        |  2 +-
 .../Conversion/VectorToArmSME/CMakeLists.txt  |  1 +
 .../VectorToArmSME/VectorToArmSME.cpp         | 78 ++++++++++++++++++-
 .../VectorToArmSME/VectorToArmSMEPass.cpp     |  1 +
 .../VectorToArmSME/unsupported.mlir           | 51 ++++++++++++
 .../VectorToArmSME/vector-to-arm-sme.mlir     | 32 ++++++++
 6 files changed, 161 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index db67d6a5ff128..9ab5faf9559a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
     Pass that converts vector dialect operations into equivalent ArmSME dialect
     operations.
   }];
-  let dependentDialects = ["arm_sme::ArmSMEDialect"];
+  let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index b062f65e914e8..6a81a09776d37 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
+  MLIRArmSVEDialect
   MLIRLLVMCommonConversion
   )
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 56ae46a6098ee..0e8575531d9b0 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
@@ -549,6 +550,77 @@ struct VectorExtractToArmSMELowering
   }
 };
 
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+///                                   : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+///                                    : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct VectorExtractFromMaskToPselLowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.getNumIndices() != 1)
+      return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+    auto resultType = extractOp.getResult().getType();
+    auto resultVectorType = dyn_cast<VectorType>(resultType);
+    if (!resultVectorType)
+      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
+
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
+
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
+      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
+
+    auto isSVEPredicateSize = [](int64_t size) {
+      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
+    };
+
+    auto rowsBaseSize = maskType.getDimSize(0);
+    auto colsBaseSize = maskType.getDimSize(1);
+    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "mask dimensions not SVE predicate-sized");
+
+    auto loc = extractOp.getLoc();
+    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
+    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
+
+    // Create the two 1-D masks at the location of the 2-D create_mask (which is
+    // usually outside a loop). This prevents the need for later hoisting.
+    rewriter.setInsertionPoint(createMaskOp);
+    auto rowMask = rewriter.create<vector::CreateMaskOp>(
+        loc, rowMaskType, createMaskOp.getOperand(0));
+    auto colMask = rewriter.create<vector::CreateMaskOp>(
+        loc, colMaskType, createMaskOp.getOperand(1));
+
+    rewriter.setInsertionPoint(extractOp);
+    auto position =
+        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
+    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
+                                                 position[0]);
+    return success();
+  }
+};
+
 /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
 /// `arm_sme.move_tile_slice_to_vector`.
 ///
@@ -728,7 +800,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
            TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
            TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
            VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
-           VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
-           VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
-          &ctx);
+           VectorExtractToArmSMELowering, VectorExtractFromMaskToPselLowering,
+           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering,
+           FoldTransferWriteOfExtractTileSlice>(&ctx);
 }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
index 2601f31be11a3..cc00bf4ca190a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 8ed52cde784ce..ff7b4bcb5f65a 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
+
+// -----
+
+/// Not SVE predicate-sized.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_0
+func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
+{
+  // CHECK-NOT: arm_sve.psel
+  %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
+  %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
+  return %slice : vector<[32]xi1>
+}
+
+// -----
+
+/// Source not 2-D scalable mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_1
+func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+  // CHECK-NOT: arm_sve.psel
+  %mask = vector.create_mask %a, %b : vector<4x[8]xi1>
+  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Source not vector.create_mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_2
+func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
+{
+  // CHECK-NOT: arm_sve.psel
+  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Not psel-like extract.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_3
+func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
+{
+  // CHECK-NOT: arm_sve.psel
+  %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+  %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
+  return %el : i1
+}
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 8aeffb066de90..ff21c70b2aa55 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -1320,3 +1320,35 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
   %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }
+
+// -----
+
+// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
+// CHECK-SAME:                                       %[[A:[a-z0-9]+]]:  index,
+// CHECK-SAME:                                       %[[B:[a-z0-9]+]]: index,
+// CHECK-SAME:                                       %[[INDEX:[a-z0-9]+]]: index)
+func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+  // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
+  // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
+  // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
+  %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_mask_to_psel(
+// CHECK-SAME:                               %[[A:[a-z0-9]+]]:  index,
+// CHECK-SAME:                               %[[B:[a-z0-9]+]]: index)
+func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
+{
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
+  // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
+  // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
+  %mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
+  %slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
+  return %slice : vector<[2]xi1>
+}

>From 855775589c7a18019f4d1601b58cf63b67284947 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 19 Jun 2024 16:02:29 +0000
Subject: [PATCH 2/3] Fixups

---
 .../VectorToArmSME/VectorToArmSME.cpp         | 157 +++++++++---------
 1 file changed, 78 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0e8575531d9b0..ee52b9ef6a6f6 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -550,77 +550,6 @@ struct VectorExtractToArmSMELowering
   }
 };
 
-/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
-/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
-/// SVE 2.1), so this is currently the most logical place for this lowering.
-///
-/// Example:
-/// ```mlir
-/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
-/// %slice = vector.extract %mask[%index]
-///                                   : vector<[8]xi1> from vector<[4]x[8]xi1>
-/// ```
-/// Becomes:
-/// ```
-/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
-/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
-/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
-///                                    : vector<[8]xi1>, vector<[4]xi1>
-/// ```
-struct VectorExtractFromMaskToPselLowering
-    : public OpRewritePattern<vector::ExtractOp> {
-  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    if (extractOp.getNumIndices() != 1)
-      return rewriter.notifyMatchFailure(extractOp, "not single extract index");
-
-    auto resultType = extractOp.getResult().getType();
-    auto resultVectorType = dyn_cast<VectorType>(resultType);
-    if (!resultVectorType)
-      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
-
-    auto createMaskOp =
-        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
-    if (!createMaskOp)
-      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
-
-    auto maskType = createMaskOp.getVectorType();
-    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
-      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
-
-    auto isSVEPredicateSize = [](int64_t size) {
-      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
-    };
-
-    auto rowsBaseSize = maskType.getDimSize(0);
-    auto colsBaseSize = maskType.getDimSize(1);
-    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
-      return rewriter.notifyMatchFailure(
-          createMaskOp, "mask dimensions not SVE predicate-sized");
-
-    auto loc = extractOp.getLoc();
-    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
-    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
-
-    // Create the two 1-D masks at the location of the 2-D create_mask (which is
-    // usually outside a loop). This prevents the need for later hoisting.
-    rewriter.setInsertionPoint(createMaskOp);
-    auto rowMask = rewriter.create<vector::CreateMaskOp>(
-        loc, rowMaskType, createMaskOp.getOperand(0));
-    auto colMask = rewriter.create<vector::CreateMaskOp>(
-        loc, colMaskType, createMaskOp.getOperand(1));
-
-    rewriter.setInsertionPoint(extractOp);
-    auto position =
-        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
-                                                 position[0]);
-    return success();
-  }
-};
-
 /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
 /// `arm_sme.move_tile_slice_to_vector`.
 ///
@@ -791,16 +720,86 @@ struct FoldTransferWriteOfExtractTileSlice
   }
 };
 
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+///            : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+///            : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct ExtractFromCreateMaskToPselLowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.getNumIndices() != 1)
+      return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+    auto resultType = extractOp.getResult().getType();
+    auto resultVectorType = dyn_cast<VectorType>(resultType);
+    if (!resultVectorType)
+      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
+
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
+
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
+      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
+
+    auto isSVEPredicateSize = [](int64_t size) {
+      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
+    };
+
+    auto rowsBaseSize = maskType.getDimSize(0);
+    auto colsBaseSize = maskType.getDimSize(1);
+    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "mask dimensions not SVE predicate-sized");
+
+    auto loc = extractOp.getLoc();
+    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
+    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
+
+    // Create the two 1-D masks at the location of the 2-D create_mask (which is
+    // usually outside a loop). This prevents the need for later hoisting.
+    rewriter.setInsertionPoint(createMaskOp);
+    auto rowMask = rewriter.create<vector::CreateMaskOp>(
+        loc, rowMaskType, createMaskOp.getOperand(0));
+    auto colMask = rewriter.create<vector::CreateMaskOp>(
+        loc, colMaskType, createMaskOp.getOperand(1));
+
+    rewriter.setInsertionPoint(extractOp);
+    auto position =
+        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
+    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
+                                                 position[0]);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns
-      .add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
-           TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
-           TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
-           VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
-           VectorExtractToArmSMELowering, VectorExtractFromMaskToPselLowering,
-           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering,
-           FoldTransferWriteOfExtractTileSlice>(&ctx);
+  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
+               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
+               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
+               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
+               VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
+               ExtractFromCreateMaskToPselLowering>(&ctx);
 }

>From 586382e31767c496edca22944005d44b4032ce60 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 09:09:50 +0000
Subject: [PATCH 3/3] Fixups

---
 .../VectorToArmSME/vector-to-arm-sme.mlir          | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index ff21c70b2aa55..068fd0d04f1bc 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -1124,7 +1124,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect
 }
 
 //===----------------------------------------------------------------------===//
-// vector.extract
+// vector.extract --> arm_sme.move_tile_slice_to_vector
 //===----------------------------------------------------------------------===//
 
 // -----
@@ -1321,12 +1321,14 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
   return %el : f64
 }
 
+//===----------------------------------------------------------------------===//
+// vector.extract --> arm_sve.psel
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
-// CHECK-SAME:                                       %[[A:[a-z0-9]+]]:  index,
-// CHECK-SAME:                                       %[[B:[a-z0-9]+]]: index,
-// CHECK-SAME:                                       %[[INDEX:[a-z0-9]+]]: index)
+// CHECK-SAME:    %[[A:.*]]:  index, %[[B:.*]]: index, %[[INDEX:.*]]: index)
 func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
 {
   // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
@@ -1340,8 +1342,8 @@ func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: ind
 // -----
 
 // CHECK-LABEL: @vector_extract_mask_to_psel(
-// CHECK-SAME:                               %[[A:[a-z0-9]+]]:  index,
-// CHECK-SAME:                               %[[B:[a-z0-9]+]]: index)
+// CHECK-SAME:                               %[[A:.*]]: index,
+// CHECK-SAME:                               %[[B:.*]]: index)
 func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
 {
   // CHECK: %[[C1:.*]] = arith.constant 1 : index



More information about the Mlir-commits mailing list