[Mlir-commits] [mlir] [mlir][ArmSME] Fix loop bounds of masked loads/stores (PR #78983)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jan 23 09:22:33 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/78983

>From b9bd637d21b404178be9c20e17fbbc35529864f9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jan 2024 16:36:05 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Refactor ArmSMEToSCF to use a shared
 loop-building helper (NFC)

This means the bug fix in the next patch only needs to change one
place, rather than three separate rewrites.

Note: `TileLoadOpWithMaskAndPadZeroConversion` has been merged into
`TileLoadOpConversion`, since after this change those two rewrites were
nearly identical.
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 348 +++++++-----------
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |   4 +-
 2 files changed, 136 insertions(+), 216 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index c9d7c0c313b5c8..85ff10387628e4 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -30,11 +30,12 @@ namespace {
 /// `outIndices`:
 ///   rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
 ///   rank 2: (indices[0] + tileSliceIndex, indices[1])
-void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
-                      Value tileSliceNumElts,
-                      SmallVectorImpl<Value> &outIndices, Location loc,
-                      PatternRewriter &rewriter) {
+SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
+                                       Value tileSliceIndex,
+                                       Value tileSliceNumElts, Location loc,
+                                       PatternRewriter &rewriter) {
   assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+  SmallVector<Value, 2> outIndices;
 
   auto tileSliceOffset = tileSliceIndex;
   if (rank == 1)
@@ -47,99 +48,92 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 
   if (rank == 2)
     outIndices.push_back(indices[1]);
-}
 
-/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
-/// using `arm_sme.load_tile_slice`.
-///
-///  BEFORE:
-///  ```mlir
-///  %tile = arm_sme.tile_load %src[%c0, %c0] :
-///    memref<?x?xi32>, vector<[4]x[4]xi32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-///  %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-///  %vscale = vector.vscale
-///  %c0 = arith.constant 0 : index
-///  %c1 = arith.constant 1 : index
-///  %min_svl_s = arith.constant 4 : index
-///  %svl_s = arith.muli %min_svl_s, %vscale : index
-///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
-///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
-///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///      %ptrue_s, %iter_tile, %tile_slice_idx
-///        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-///    scf.yield %tile_update : vector<[4]x[4]xi32>
-///  }
-///  ```
-struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
-  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
-                                PatternRewriter &rewriter) const override {
-    if (tileLoadOp.getMask())
-      return rewriter.notifyMatchFailure(tileLoadOp,
-                                         "op has mask, apply masked patterns");
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto loc = tileLoadOp.getLoc();
-    auto tileType = tileLoadOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-
-    // Allocate a new SME tile.
-    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-        rewriter, loc, tileType);
+  return outIndices;
+}
 
-    // Create a loop that loads each ZA tile slice from memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
-                                             step, ValueRange{initTile});
+/// Creates an scf.for for the load/store of an ArmSME tile.
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+    PatternRewriter &rewriter, Location loc, VectorType tileType,
+    ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
+    function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
+                       /*currentTile=*/Value)>
+        makeLoopBody) {
+  PatternRewriter::InsertionGuard guard(rewriter);
+
+  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+      loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+  auto vscale =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  auto predicateType =
+      VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+  // This describes both the number of ZA tile slices and the number of
+  // elements in a vector of SVL bits for a given element type (SVL_B,
+  // SVL_H, ..., SVL_Q).
+  auto numTileSlices =
+      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+  Value predicate;
+  Value upperBound;
+  if (mask) {
+    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          loc, "unsupported mask op, only 'vector.create_mask' is "
+               "currently supported");
+
+    auto maskDim0 = createMaskOp.getOperands()[0];
+    auto maskDim1 = createMaskOp.getOperands()[1];
+
+    upperBound = maskDim0;
+    predicate =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+  } else {
+    upperBound = numTileSlices;
+    // No mask. Create an 'all true' predicate for the tile slice.
+    predicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+  }
 
-    rewriter.setInsertionPointToStart(forOp.getBody());
+  bool hasCarriedArgs = bool(initTile);
+  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+                                           hasCarriedArgs ? ValueRange{initTile}
+                                                          : ValueRange{});
 
-    // Create an 'all true' predicate for the tile slice.
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  Value tileSliceIndex = forOp.getInductionVar();
 
-    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
-    // tile.
-    SmallVector<Value> memrefIndices;
-    auto tileSliceIndex = forOp.getInductionVar();
-    getMemrefIndices(tileLoadOp.getIndices(),
-                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
-    auto currentTile = forOp.getRegionIterArg(0);
-    auto loadSlice =
-        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-            rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
-            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
-    rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
+  auto adjustedIndices = getMemrefIndices(
+      memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
+  auto nextTile = makeLoopBody(
+      tileSliceIndex, adjustedIndices, predicate,
+      /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
 
-    rewriter.setInsertionPointAfter(forOp);
+  assert(bool(nextTile) == hasCarriedArgs);
+  if (nextTile)
+    rewriter.create<scf::YieldOp>(loc, nextTile);
 
-    // Replace 'arm_sme.tile_load' with the result.
-    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
+  return forOp;
+}
 
-    return success();
-  }
-};
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+    PatternRewriter &rewriter, Location loc, VectorType tileType,
+    ValueRange memrefIndices, int memrefRank, Value mask,
+    function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
+        makeLoopBody) {
+  return createLoadStoreForOverTileSlices(
+      rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
+      [&](Value index, ValueRange adjustedIndices, Value predicate,
+          Value) -> Value {
+        makeLoopBody(index, adjustedIndices, predicate);
+        return {};
+      });
+}
 
-/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
 ///
 ///  BEFORE:
 ///  ```mlir
@@ -168,77 +162,56 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  ```
 ///
 /// NOTE: Only mask of 'vector.create_mask' op is currently supported.
-struct TileLoadOpWithMaskAndPadZeroConversion
-    : public OpRewritePattern<arm_sme::TileLoadOp> {
+struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
+    auto mask = tileLoadOp.getMask();
 
-    auto maskOp = tileLoadOp.getMask();
-    if (!maskOp)
-      return rewriter.notifyMatchFailure(
-          tileLoadOp, "op has no mask, needs unmasked pattern");
-
-    auto padOp = tileLoadOp.getPadding();
-    assert(padOp && "expected padding when masking!");
+    Value initTile;
+    if (mask) {
+      auto padOp = tileLoadOp.getPadding();
+      assert(padOp && "expected padding when masking!");
 
-    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-    if (!createMaskOp)
-      return rewriter.notifyMatchFailure(
-          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
-                      "currently supported");
-
-    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
-    if (!constPadOp || constPadOp.getValue() !=
-                           rewriter.getZeroAttr(tileType.getElementType()))
-      return rewriter.notifyMatchFailure(
-          tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
-
-    auto numRows = createMaskOp.getOperands()[0];
-    auto numCols = createMaskOp.getOperands()[1];
-
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto numColsOp =
-        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+      auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+      if (!constPadOp || constPadOp.getValue() !=
+                             rewriter.getZeroAttr(tileType.getElementType()))
+        return rewriter.notifyMatchFailure(
+            tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
 
-    // Initialize tile with zero to satisfy padding. Inactive cols will be
-    // zeroed anyway since the loads use zeroing predication. For inactive rows
-    // however, no load will occur so these need to be zeroed.
-    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-        rewriter, loc, tileType);
+      // Initialize tile with zero to satisfy padding. Inactive cols will be
+      // zeroed anyway since the loads use zeroing predication. For inactive
+      // rows however, no load will occur so these need to be zeroed.
+      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+          rewriter, loc, tileType);
+    } else {
+      // Allocate a new SME tile.
+      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+          rewriter, loc, tileType);
+    }
 
     // Create a loop to load the active tile slices from memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto upperBound = numRows;
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
-                                             ValueRange{initTile});
-
-    rewriter.setInsertionPointToStart(forOp.getBody());
-
-    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
-    // tile.
-    SmallVector<Value> memrefIndices;
-    auto tileSliceIndex = forOp.getInductionVar();
-    auto currentTile = forOp.getRegionIterArg(0);
-    getMemrefIndices(tileLoadOp.getIndices(),
-                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
-                     upperBound, memrefIndices, loc, rewriter);
-    auto loadSlice =
-        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-            rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
-            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
-    rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
-
-    rewriter.setInsertionPointAfter(forOp);
+    auto forOp = createLoadStoreForOverTileSlices(
+        rewriter, loc, tileType, tileLoadOp.getIndices(),
+        tileLoadOp.getMemRefType().getRank(), mask, initTile,
+        [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
+            Value currentTile) -> Value {
+          // Create 'arm_sme.load_tile_slice' to load tile slice from memory
+          // into tile.
+          return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+              rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+              currentTile, memrefIndices, tileSliceIndex,
+              tileLoadOp.getLayout());
+        });
+
+    if (failed(forOp))
+      return forOp;
 
     // Replace 'arm_sme.tile_load' with the result.
-    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
+    rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
 
     return success();
   }
@@ -345,10 +318,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
         loc, predicateType, maskIndex.getResult());
 
-    SmallVector<Value> memrefIndices;
-    getMemrefIndices(tileLoadOp.getIndices(),
-                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
+    auto memrefIndices = getMemrefIndices(
+        tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
+        tileSliceIndex, numTileSlices, loc, rewriter);
 
     // Splat pad into 1-D vector matching type of tile slice.
     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
@@ -400,77 +372,25 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
   LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
                                 PatternRewriter &rewriter) const override {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto loc = tileStoreOp.getLoc();
-    auto tileType = tileStoreOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-
-    Value maskCols;
-    Value upperBound;
-    auto maskOp = tileStoreOp.getMask();
-    if (maskOp) {
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp)
-        return rewriter.notifyMatchFailure(
-            tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
-                         "currently supported");
-
-      auto numRows = createMaskOp.getOperands()[0];
-      auto numCols = createMaskOp.getOperands()[1];
-
-      upperBound = numRows;
-      maskCols =
-          rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
-    } else {
-      // Store all tile slices if no mask.
-      auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-          loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-      auto vscale =
-          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-      // This describes both the number of ZA tile slices and the number of
-      // elements in a vector of SVL bits for a given element type (SVL_B,
-      // SVL_H,
-      // ..., SVL_Q).
-      auto numTileSlices =
-          rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
-      upperBound = numTileSlices;
-      // Create an 'all true' predicate for the tile slice.
-      maskCols = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(predicateType, true));
-    }
-
     // Create a loop that stores each (active) active ZA tile slice from memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-
-    rewriter.setInsertionPointToStart(forOp.getBody());
-
-    SmallVector<Value> memrefIndices;
-    auto tileSliceIndex = forOp.getInductionVar();
-    getMemrefIndices(tileStoreOp.getIndices(),
-                     tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
-                     upperBound, memrefIndices, loc, rewriter);
-
-    tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
-        rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
-        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
-
-    return success();
+    return createLoadStoreForOverTileSlices(
+        rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
+        tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
+        tileStoreOp.getMask(),
+        [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
+          tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+              rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+              predicate, tileStoreOp.getBase(), memrefIndices,
+              tileStoreOp.getLayout());
+        });
   }
 };
 
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
-           TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
-          patterns.getContext());
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
+               TileStoreOpConversion>(patterns.getContext());
 }
 
 namespace {
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 5d79a0405114a2..292f9a4d411ff7 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -11,9 +11,9 @@
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
-// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
-// CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK-NEXT:      %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 // CHECK-NEXT:      scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>

>From 5c1bccbae611f846bb487805f02677495dad79e2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 22 Jan 2024 12:01:48 +0000
Subject: [PATCH 2/2] [mlir][ArmSME] Fix loop bounds of masked loads/stores

Previously, for masked tile loads/stores we used the first operand of the
`vector.create_mask` op directly as the upper bound of the `scf.for` over
the tile slices. This was incorrect: `create_mask` permits operands that
are larger than the corresponding vector dimension, so the loop's upper
bound must be clamped to the number of tile slices (i.e. computed as
min(mask dim 0, number of tile slices)).
---
 .../lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 13 ++++++++++++-
 .../Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir | 18 ++++++++++++++++--
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 85ff10387628e4..adf3aca91ba8b5 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -86,7 +86,18 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
     auto maskDim0 = createMaskOp.getOperands()[0];
     auto maskDim1 = createMaskOp.getOperands()[1];
 
-    upperBound = maskDim0;
+    // The upper bound of the loop must be clamped at `numTileSlices` as
+    // `vector.create_mask` allows operands to be greater than the size of a
+    // dimension.
+    auto numRowI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), maskDim0);
+    auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), numTileSlices);
+    auto upperBoundI64 =
+        rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
+    upperBound = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), upperBoundI64);
+
     predicate =
         rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
   } else {
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 292f9a4d411ff7..6c393bc38af9c7 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -39,10 +39,17 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
 // CHECK-SAME:                                                          %[[SRC:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
+// CHECK-DAG:     %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
 // CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
 // CHECK-DAG:     %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK-NEXT:      %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 // CHECK-NEXT:      scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
@@ -150,9 +157,16 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
 // CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
+// CHECK-DAG:     %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
 // CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
-// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] {
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK-NEXT:      arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {



More information about the Mlir-commits mailing list