[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 31 09:49:19 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

When unrolling the reduction dimension of something like a matmul for SME, you can end up with transposed reads of illegal types, like so:

```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>
```

Here, `vector<[8]x4xf32>` is an illegal type: there is no way to lower a scalable vector of fixed-length vectors. However, since the final type `vector<4x[8]xf32>` is legal, we can instead lift the transpose to memory (producing a strided memref) and eliminate all the illegal types, as shown below.

```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>
```

---
Full diff: https://github.com/llvm/llvm-project/pull/80170.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+160-3) 
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+75) 


``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..a3db2d2395528 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,8 +7,6 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
 //
 // Note: In the context of this pass 'tile' always refers to an SME tile.
 //
@@ -19,6 +17,7 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
 // Common match failure reasons.
 static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
     "op vector size is not multiple of SME tiles");
@@ -338,13 +341,181 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// The read may optionally be extended (arith.extsi/extui/extf) before the
+/// transpose; the same extension is then applied to the new legal read.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %illegalRead = vector.transfer_read %memref[%a, %b]
+///                  : memref<?x?xf32>, vector<[8]x4xf32>
+///  %legalType = vector.transpose %illegalRead, [1, 0]
+///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %legalType = vector.transfer_read %transpose[%c0, %c0]
+///                  : memref<?x?xf32>, vector<4x[8]xf32>
+///  ```
+struct LiftIllegalVectorTransposeToMemory
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  /// Returns true if a scalable dimension of `vType` precedes (is outer to) a
+  /// fixed-size dimension, e.g. `vector<[8]x4xf32>`.
+  static bool isIllegalVectorType(VectorType vType) {
+    bool seenFixedDim = false;
+    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+      seenFixedDim |= !scalableFlag;
+      if (seenFixedDim && scalableFlag)
+        return true;
+    }
+    return false;
+  }
+
+  /// Returns the input of `op` if it is an arith.extsi/extui/extf, else a null
+  /// Value. `op` may be null (when the transpose source is a block argument),
+  /// so the null-tolerant `dyn_cast_if_present` must be used.
+  static Value getExtensionSource(Operation *op) {
+    if (auto signExtend = dyn_cast_if_present<arith::ExtSIOp>(op))
+      return signExtend.getIn();
+    if (auto zeroExtend = dyn_cast_if_present<arith::ExtUIOp>(op))
+      return zeroExtend.getIn();
+    if (auto floatExtend = dyn_cast_if_present<arith::ExtFOp>(op))
+      return floatExtend.getIn();
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "expected transpose from illegal type to legal type");
+
+    // Look through an optional arith extension to find the transfer_read.
+    Value maybeRead = transposeOp.getVector();
+    auto *transposeSourceOp = maybeRead.getDefiningOp();
+    Operation *extendOp = nullptr;
+    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+      maybeRead = extendSource;
+      extendOp = transposeSourceOp;
+    }
+
+    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+    if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibly extended) transfer_read");
+
+    if (!illegalRead.getPermutationMap().isIdentity())
+      return rewriter.notifyMatchFailure(
+          illegalRead, "expected read to have identity permutation map");
+
+    // The subview/transpose below are memref ops, so bail out on tensor
+    // reads. Match failures must be reported before any IR is created.
+    if (!isa<MemRefType>(illegalRead.getSource().getType()))
+      return rewriter.notifyMatchFailure(
+          illegalRead, "expected read from a memref");
+
+    auto loc = transposeOp.getLoc();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Create a subview that matches the size of the illegal read vector type.
+    auto readType = illegalRead.getVectorType();
+    auto readSizes = llvm::map_to_vector(
+        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+        [&](auto dim) -> Value {
+          auto [size, isScalable] = dim;
+          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+          if (!isScalable)
+            return dimSize;
+          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+        });
+    SmallVector<Value> strides(readType.getRank(), Value(one));
+    auto readSubview = rewriter.create<memref::SubViewOp>(
+        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+        strides);
+
+    // Apply the transpose to all values/attributes of the transfer_read:
+    // - The mask.
+    Value mask = illegalRead.getMask();
+    if (mask) {
+      // Note: The transpose for the mask should fold into the
+      // vector.create_mask/constant_mask op, which will then become legal.
+      mask = rewriter.create<vector::TransposeOp>(loc, mask,
+                                                  transposeOp.getPermutation());
+    }
+    // - The source memref.
+    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+        transposeOp.getPermutation(), getContext());
+    auto transposedSubview = rewriter.create<memref::TransposeOp>(
+        loc, readSubview, AffineMapAttr::get(transposeMap));
+    // - The `in_bounds` attribute.
+    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+    if (inBoundsAttr) {
+      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+                                            inBoundsAttr.end());
+      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+    }
+
+    VectorType legalReadType =
+        VectorType::Builder(resultType)
+            .setElementType(readType.getElementType());
+    // Note: The indices are all zero as the subview is already offset.
+    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
+    Value legalRead = rewriter.create<vector::TransferReadOp>(
+        loc, legalReadType, transposedSubview, readIndices,
+        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
+        inBoundsAttr);
+
+    // Replace the transpose with the new read, extending the result if
+    // necessary.
+    rewriter.replaceOp(transposeOp, [&]() -> Value {
+      if (!extendOp)
+        return legalRead;
+      if (isa<arith::ExtSIOp>(extendOp))
+        return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
+      if (isa<arith::ExtUIOp>(extendOp))
+        return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
+      if (isa<arith::ExtFOp>(extendOp))
+        return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+      return legalRead;
+    }());
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
     OneToNTypeConverter converter;
     RewritePatternSet patterns(context);
-
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
         [](VectorType vectorType,
@@ -358,6 +529,7 @@ struct VectorLegalizationPass
           return success();
         });

+    patterns.add<LiftIllegalVectorTransposeToMemory>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..2317930d3d061 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,78 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
   vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%[[C0]], %[[C0]]], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%[[C0]], %[[C0]]], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
+// CHECK-SAME:                                                     %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%[[C0]], %[[C0]]], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+  // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
+  // CHECK-NEXT: return %[[EXT_TYPE]]
+  %pad = arith.constant 0 : i8
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
+  %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
+  %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
+  return %legalType : vector<4x[8]xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/80170


More information about the Mlir-commits mailing list