[Mlir-commits] [mlir] [mlir][ArmSME] Fix loop bounds of masked loads/stores (PR #78983)

Cullen Rhodes llvmlistbot at llvm.org
Tue Jan 23 01:59:24 PST 2024


================
@@ -47,99 +48,103 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 
   if (rank == 2)
     outIndices.push_back(indices[1]);
-}
 
-/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
-/// using `arm_sme.load_tile_slice`.
-///
-///  BEFORE:
-///  ```mlir
-///  %tile = arm_sme.tile_load %src[%c0, %c0] :
-///    memref<?x?xi32>, vector<[4]x[4]xi32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-///  %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-///  %vscale = vector.vscale
-///  %c0 = arith.constant 0 : index
-///  %c1 = arith.constant 1 : index
-///  %min_svl_s = arith.constant 4 : index
-///  %svl_s = arith.muli %min_svl_s, %vscale : index
-///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
-///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
-///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///      %ptrue_s, %iter_tile, %tile_slice_idx
-///        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-///    scf.yield %tile_update : vector<[4]x[4]xi32>
-///  }
-///  ```
-struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
-  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
-                                PatternRewriter &rewriter) const override {
-    if (tileLoadOp.getMask())
-      return rewriter.notifyMatchFailure(tileLoadOp,
-                                         "op has mask, apply masked patterns");
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto loc = tileLoadOp.getLoc();
-    auto tileType = tileLoadOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-
-    // Allocate a new SME tile.
-    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-        rewriter, loc, tileType);
+  return outIndices;
+}
 
-    // Create a loop that loads each ZA tile slice from memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
-                                             step, ValueRange{initTile});
+/// Creates an scf.for for the load/store of an ArmSME tile.
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+    PatternRewriter &rewriter, Location loc, VectorType tileType,
+    ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
+    function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
+                       /*currentTile=*/Value)>
+        makeLoopBody) {
+  PatternRewriter::InsertionGuard guard(rewriter);
+
+  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+      loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+  auto vscale =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  auto predicateType =
+      VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+  // This describes both the number of ZA tile slices and the number of
+  // elements in a vector of SVL bits for a given element type (SVL_B,
+  // SVL_H,
+  // ..., SVL_Q).
+  auto numTileSlices =
+      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+  Value predicate;
+  Value upperBound = numTileSlices;
+  if (mask) {
+    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          loc, "unsupported mask op, only 'vector.create_mask' is "
+               "currently supported");
+
+    auto maskDim0 = createMaskOp.getOperands()[0];
+    auto maskDim1 = createMaskOp.getOperands()[1];
+
+    // The upper bound of the loop must be clamped at `numTileSlices` as
+    // `vector.create_mask` allows operands to be greater than the size of a
+    // dimension.
+    auto numRowI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), maskDim0);
+    auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), numTileSlices);
+    auto upperBoundI64 =
+        rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
+    upperBound = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), upperBoundI64);
+
+    predicate =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+  } else {
+    // No mask. Create an 'all true' predicate for the tile slice.
+    predicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+  }
 
-    rewriter.setInsertionPointToStart(forOp.getBody());
+  bool loopOverTile = bool(initTile);
+  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+                                           loopOverTile ? ValueRange{initTile}
+                                                        : ValueRange{});
 
-    // Create an 'all true' predicate for the tile slice.
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  Value tileSliceIndex = forOp.getInductionVar();
 
-    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
-    // tile.
-    SmallVector<Value> memrefIndices;
-    auto tileSliceIndex = forOp.getInductionVar();
-    getMemrefIndices(tileLoadOp.getIndices(),
-                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
-    auto currentTile = forOp.getRegionIterArg(0);
-    auto loadSlice =
-        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-            rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
-            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
-    rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
+  auto adjustedIndices = getMemrefIndices(
+      memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
+  auto nextTile =
+      makeLoopBody(tileSliceIndex, adjustedIndices, predicate,
+                   loopOverTile ? forOp.getRegionIterArg(0) : Value{});
----------------
c-rhodes wrote:

```suggestion
                   /*currentTile=*/loopOverTile ? forOp.getRegionIterArg(0) : Value{});
```

https://github.com/llvm/llvm-project/pull/78983


More information about the Mlir-commits mailing list