[Mlir-commits] [mlir] [mlir][ArmSME] Add support for lowering masked tile_load ops (PR #70915)

Cullen Rhodes llvmlistbot at llvm.org
Fri Nov 3 04:30:53 PDT 2023


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/70915

>From 39d75e2557a7639f383726bd4b4893e47c0ca333 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 09:15:52 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Add support for lowering masked tile_load
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.

There are two lowerings, one for pad of constant zero and another for
non-zero pad. For the following example:

  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                          vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:
-----------------------------------------------------

  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1 {
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx], %num_cols, %tile, %slice_idx :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
  }

The tile is zeroed to satisfy the padding and only active rows are
loaded.

The latter (non-zero pad) is lowered as follows:
------------------------------------------------

  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>
  }

The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (will be lowered to SVE, not SME) loads each slice
for active rows, with padding specified as a passthru. For non-active
rows the slice is the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 254 +++++++++++++++++-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  56 ++++
 .../CPU/ArmSME/test-transfer-read-2d.mlir     | 212 +++++++++++++++
 3 files changed, 519 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 80da6ffda1ed2ea..b46be34e7fff359 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,9 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
     if (tileLoadOp.getMask())
-      // TODO: add masked patterns.
       return rewriter.notifyMatchFailure(
-          tileLoadOp, "op has mask, needs masked pattern(s)");
+          tileLoadOp, "op has mask, apply masked patterns");
 
     OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();
@@ -142,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   }
 };
 
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 0 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %tile = arm_sme.zero : vector<[4]x[4]xi32>
+///  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+///  scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+///    %tile_update = arm_sme.load_tile_slice
+///      %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+///      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+///  }
+///  ```
+///
+/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (!constPadOp || constPadOp.getValue() !=
+                           rewriter.getZeroAttr(tileType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Initialize tile with zero to satisfy padding. Inactive cols will be
+    // zeroed anyway since the loads use zeroing predication. For inactive rows
+    // however, no load will occur so these need to be zeroed.
+    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+    // Create a loop to load the active tile slices from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = numRows;
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile.
+    SmallVector<Value> memrefIndices;
+    auto tileSliceIndex = forOp.getInductionVar();
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     upperBound, memrefIndices, loc, rewriter);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+        tileSliceIndex, tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 1 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %tile_id = arm_sme.get_tile_id : i32
+///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///    %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+///    %slice = scf.if %row_is_active -> vector<[4]xi32> {
+///      %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
+///        : memref<?x?xi32>, vector<[4]xi1>,
+///          vector<[4]xi32> into vector<[4]xi32>
+///      scf.yield %slice : vector<[4]xi32>
+///    } else {
+///      scf.yield %pad_1d : vector<[4]xi32>
+///    }
+///    // Insert slice into tile
+///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+///      : vector<[4]xi32> into vector<[4]x[4]xi32>
+///  }
+///  ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (constPadOp &&
+        constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Create 'arm_sme.get_tile' op.
+    auto tileId = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+    // use as input tile to 'arm_sme.move_vector_to_tile_slice' ops.
+    auto tile =
+        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    auto rowIsActive = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+    SmallVector<Value> memrefIndices;
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     numTileSlices, memrefIndices, loc, rewriter);
+
+    // Splat pad into 1-D vector matching type of tile slice.
+    auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+    Operation *slice = rewriter.create<scf::IfOp>(
+        loc, rowIsActive,
+        [&](OpBuilder &b, Location loc) {
+          // If the row is active, emit a masked load where the predicate is
+          // 'numCols'. Pad is used for inactive elements, taken from
+          // passthru.
+          auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+              loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+              numColsOp, /*passthru=*/pad1DOp);
+          rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+        },
+        [&](OpBuilder &b, Location loc) {
+          // Inactive rows are filled with pad.
+          rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+        });
+
+    // TODO: If the load is vertical the transpose can't be done in-flight with
+    // a regular (SVE) maskedload. Propagate layout to
+    // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
+    // is currently broken.
+
+    // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+        tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
 /// slice using `arm_sme.store_tile_slice`.
 ///
@@ -273,7 +520,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
                TileVectorPrintOpConversion>(patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index d61f588941b408c..55ea56f42c96ed9 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME:                                                          %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0 : i32
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
+// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK:             %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK:               %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:               scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK:             } else {
+// CHECK:               scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK:             }
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.tile_store
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..644f90d950645b8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,212 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c4 = arith.constant 4 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+    memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0 : vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+//    0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns a dynamic memref. It's the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_f32 = arith.constant 1.0 : f32
+  %c10_f32 = arith.constant 10.0 : f32
+
+  %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+  %init = arith.constant 0.0 : f32
+  scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+    scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+      memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+      %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+      scf.yield %inner_val_next : f32
+    }
+    %val_next = arith.addf %val, %c10_f32 : f32
+    scf.yield %val_next : f32
+  }
+
+  return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
+  // non-zero offsets while remaining inbounds.
+  %vscale = vector.vscale
+  %svl_s = arith.muli %c4, %vscale : index
+  %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+  %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+
+  // 1.a. Read 2D vector from 2D memref.
+  //
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 22, 23
+  // CHECK-NEXT: ( 30, 31, 32, 33
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 1.b. Same as 1.a., but with non-zero offsets.
+  //
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 12, 13, 14, 15
+  // CHECK-NEXT: ( 22, 23, 24, 25
+  // CHECK-NEXT: ( 32, 33, 34, 35
+  // CHECK-NEXT: ( 42, 43, 44, 45
+  call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 2. Same as 1.a., but with mask and a pad of constant zero.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 0
+  // CHECK-NEXT: ( 10, 11, 12, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  call @transfer_read_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 3. Same as 1.a., but with mask and non-zero pad.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, -42
+  // CHECK-NEXT: ( 10, 11, 12, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  call @transfer_read_2d_mask_non_zero_pad(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 4. Same as 1.a., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, 20, 30
+  // CHECK-NEXT: ( 1, 11, 21, 31
+  // CHECK-NEXT: ( 2, 12, 22, 32
+  // CHECK-NEXT: ( 3, 13, 23, 33
+  call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Same as 2., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, 0, 0
+  // CHECK-NEXT: ( 1, 11, 0, 0
+  // CHECK-NEXT: ( 2, 12, 0, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  call @transfer_read_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 6. Same as 3., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, -42, -42
+  // CHECK-NEXT: ( 1, 11, -42, -42
+  // CHECK-NEXT: ( 2, 12, -42, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  call @transfer_read_2d_mask_non_zero_pad_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  memref.dealloc %A : memref<?x?xf32>
+
+  return
+}

>From 300dc50195d4bae44d1c8fde8cda8a7dda8e0ab1 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 1 Nov 2023 09:18:42 +0000
Subject: [PATCH 2/3] run clang-format

---
 mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index b46be34e7fff359..46e81bb935c406a 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -80,8 +80,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
     if (tileLoadOp.getMask())
-      return rewriter.notifyMatchFailure(
-          tileLoadOp, "op has mask, apply masked patterns");
+      return rewriter.notifyMatchFailure(tileLoadOp,
+                                         "op has mask, apply masked patterns");
 
     OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();

>From c805d053bf82fbe86b38a6190cbcc42c56b426fb Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 3 Nov 2023 11:27:43 +0000
Subject: [PATCH 3/3] Combine masks and replace if

---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 59 ++++++++-----------
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           | 16 ++---
 2 files changed, 33 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 46e81bb935c406a..b3561f725f63214 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -255,6 +255,7 @@ struct TileLoadOpWithMaskAndPadZeroConversion
 ///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
 ///  %num_rows = arith.constant 2 : index
 ///  %num_cols = arith.constant 4 : index
+///  %num_cols_i32 = arith.index_castui %num_cols : index to i32
 ///  %tile_id = arm_sme.get_tile_id : i32
 ///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
@@ -264,14 +265,13 @@ struct TileLoadOpWithMaskAndPadZeroConversion
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
 ///    %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
-///    %slice = scf.if %row_is_active -> vector<[4]xi32> {
-///      %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad
-///        : memref<?x?xi32>, vector<[4]xi1>,
-///          vector<[4]xi32> into vector<[4]xi32>
-///      scf.yield %slice : vector<[4]xi32>
-///    } else {
-///      scf.yield %pad_1d : vector<[4]xi32>
-///    }
+///    %row_is_active_i32 = arith.extsi %row_is_active : i1 to i32
+///    %mask = arith.andi %row_is_active_i32, %num_cols_i32 : i32
+///    %mask_index = arith.index_cast %mask : i32 to index
+///    %mask_1d = vector.create_mask %mask_index : vector<[4]xi1>
+///    %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad
+///      : memref<?x?xi32>, vector<[4]xi1>,
+///        vector<[4]xi32> into vector<[4]xi32>
 ///    // Insert slice into tile
 ///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
 ///      : vector<[4]xi32> into vector<[4]x[4]xi32>
@@ -312,11 +312,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numRows = createMaskOp.getOperands()[0];
     auto numCols = createMaskOp.getOperands()[1];
 
-    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto numColsOp =
-        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+    auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
+        loc, rewriter.getI32Type(), numCols);
 
     // Create 'arm_sme.get_tile' op.
     auto tileId = rewriter.create<arm_sme::GetTileID>(
@@ -343,8 +340,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
 
     auto tileSliceIndex = forOp.getInductionVar();
 
+    // Combine masks.
     auto rowIsActive = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+    auto rowIsActiveI32 = rewriter.create<arith::ExtSIOp>(
+        loc, rewriter.getI32Type(), rowIsActive);
+    auto mask = rewriter.create<arith::AndIOp>(loc, rowIsActiveI32, numColsI32);
+    auto maskIndex =
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), mask);
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
+        loc, predicateType, maskIndex.getResult());
 
     SmallVector<Value> memrefIndices;
     getMemrefIndices(tileLoadOp.getIndices(),
@@ -352,32 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
                      numTileSlices, memrefIndices, loc, rewriter);
 
     // Splat pad into 1-D vector matching type of tile slice.
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
     auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
 
-    Operation *slice = rewriter.create<scf::IfOp>(
-        loc, rowIsActive,
-        [&](OpBuilder &b, Location loc) {
-          // If the row is active, emit a masked load where the predicate is
-          // 'numCols'. Pad is used for inactive elements, taken from
-          // passthru.
-          auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
-              loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
-              numColsOp, /*passthru=*/pad1DOp);
-          rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
-        },
-        [&](OpBuilder &b, Location loc) {
-          // Inactive rows are filled with pad.
-          rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
-        });
-
-    // TODO: If the load is vertical the transpose can't be done in-flight with
-    // a regular (SVE) maskedload. Propagate layout to
-    // 'arm_sme.move_vector_to_tile_slice' below once it supports layout. This
-    // is currently broken.
+    auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+        loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D,
+        /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
     rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+        loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
         tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 55ea56f42c96ed9..7b073d37b2f771b 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -66,20 +66,20 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
-// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[NUM_COLS:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
 // CHECK-NEXT:        %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT:        %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32
+// CHECK-NEXT:        %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32
+// CHECK-NEXT:        %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
+// CHECK-NEXT:        %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
-// CHECK:             %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
-// CHECK:               %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK:               scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
-// CHECK:             } else {
-// CHECK:               scf.yield %[[PAD_1D]] : vector<[4]xi32>
-// CHECK:             }
-// CHECK:             arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index



More information about the Mlir-commits mailing list