[Mlir-commits] [mlir] [mlir][ArmSME] Fix loop bounds of masked loads/stores (PR #78983)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jan 23 09:25:01 PST 2024


================
@@ -400,77 +383,25 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
   LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
                                 PatternRewriter &rewriter) const override {
-    OpBuilder::InsertionGuard g(rewriter);
-    auto loc = tileStoreOp.getLoc();
-    auto tileType = tileStoreOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-
-    Value maskCols;
-    Value upperBound;
-    auto maskOp = tileStoreOp.getMask();
-    if (maskOp) {
-      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp)
-        return rewriter.notifyMatchFailure(
-            tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
-                         "currently supported");
-
-      auto numRows = createMaskOp.getOperands()[0];
-      auto numCols = createMaskOp.getOperands()[1];
-
-      upperBound = numRows;
-      maskCols =
-          rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
-    } else {
-      // Store all tile slices if no mask.
-      auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-          loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-      auto vscale =
-          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-      // This describes both the number of ZA tile slices and the number of
-      // elements in a vector of SVL bits for a given element type (SVL_B,
-      // SVL_H,
-      // ..., SVL_Q).
-      auto numTileSlices =
-          rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
-      upperBound = numTileSlices;
-      // Create an 'all true' predicate for the tile slice.
-      maskCols = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(predicateType, true));
-    }
-
     // Create a loop that stores each (active) ZA tile slice to memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-
-    rewriter.setInsertionPointToStart(forOp.getBody());
-
-    SmallVector<Value> memrefIndices;
-    auto tileSliceIndex = forOp.getInductionVar();
-    getMemrefIndices(tileStoreOp.getIndices(),
-                     tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
-                     upperBound, memrefIndices, loc, rewriter);
-
-    tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
-        rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
-        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
-
-    return success();
+    return createLoadStoreForOverTileSlices(
+        rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
+        tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
+        tileStoreOp.getMask(),
+        [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
+          tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+              rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+              predicate, tileStoreOp.getBase(), memrefIndices,
+              tileStoreOp.getLayout());
+        });
   }
 };
 
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
-           TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
-          patterns.getContext());
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
+               TileStoreOpConversion>(patterns.getContext());
----------------
MacDue wrote:

See: #79172. 

https://github.com/llvm/llvm-project/pull/78983


More information about the Mlir-commits mailing list