[Mlir-commits] [mlir] fb8eb42 - [mlir][ArmSME] Fix loop bounds of masked loads/stores (#78983)

llvmlistbot at llvm.org
Fri Jan 26 01:39:47 PST 2024


Author: Benjamin Maxwell
Date: 2024-01-26T09:39:43Z
New Revision: fb8eb4251ac0ac27ab0125caf1abe3dbb93732f4

URL: https://github.com/llvm/llvm-project/commit/fb8eb4251ac0ac27ab0125caf1abe3dbb93732f4
DIFF: https://github.com/llvm/llvm-project/commit/fb8eb4251ac0ac27ab0125caf1abe3dbb93732f4.diff

LOG: [mlir][ArmSME] Fix loop bounds of masked loads/stores (#78983)

Previously, for masked tile loads/stores, we used the first dimension
operand of the `vector.create_mask` operation directly as the upper
bound of the `scf.for` over the tile slices. This was incorrect:
`create_mask` allows an operand to be greater than the size of the
corresponding vector dimension, in which case the loop bound must be
clamped to the number of tile slices.
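
For reference, a minimal sketch of the new bound computation (the
constant 100 and the SSA names are illustrative, not from the patch;
shapes assume the vector<[4]x[4]xi32> tile used in the tests below):

    // Hypothetical create_mask operand exceeding the tile's first dimension.
    %num_rows = arith.constant 100 : index
    // Number of tile slices for a vector<[4]x[4]xi32> tile: 4 * vscale.
    %c4 = arith.constant 4 : index
    %vscale = vector.vscale
    %num_tile_slices = arith.muli %c4, %vscale : index
    // Clamp the bound: with vscale = 1 the tile has only 4 slices, so the
    // loop must run 4 iterations, not 100. The patch performs the minsi on
    // i64 values, casting from and back to index.
    %num_rows_i64 = arith.index_cast %num_rows : index to i64
    %num_tile_slices_i64 = arith.index_cast %num_tile_slices : index to i64
    %ub_i64 = arith.minsi %num_rows_i64, %num_tile_slices_i64 : i64
    %ub = arith.index_cast %ub_i64 : i64 to index
    // %ub replaces %num_rows as the scf.for upper bound.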

Added: 
    

Modified: 
    mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
    mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 056d71030bd7986..16b61c282749cf2 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -85,7 +85,18 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
     auto maskDim0 = createMaskOp.getOperands()[0];
     auto maskDim1 = createMaskOp.getOperands()[1];
 
-    upperBound = maskDim0;
+    // The upper bound of the loop must be clamped at `numTileSlices` as
+    // `vector.create_mask` allows operands to be greater than the size of a
+    // dimension.
+    auto numRowI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), maskDim0);
+    auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(), numTileSlices);
+    auto upperBoundI64 =
+        rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
+    upperBound = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), upperBoundI64);
+
     predicate =
         rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
   } else {
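
The updated test below checks exactly this clamp sequence. Assuming the
test's usual setup, it can be reproduced with something like
`mlir-opt -convert-arm-sme-to-scf mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir`;
the precise pass flags come from the test's RUN line rather than this
commit.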

diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 292f9a4d411ff7c..6c393bc38af9c74 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -39,10 +39,17 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
 // CHECK-SAME:                                                          %[[SRC:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
+// CHECK-DAG:     %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
 // CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
 // CHECK-DAG:     %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK-NEXT:      %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 // CHECK-NEXT:      scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
@@ -150,9 +157,16 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
 // CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
+// CHECK-DAG:     %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
+// CHECK-DAG:     %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
 // CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
-// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] {
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK-NEXT:      arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {

