[Mlir-commits] [mlir] [mlir][ArmSME] Update `OuterProductFusion` to account for recent changes (PR #102125)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 6 04:35:50 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-sme
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
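- Use `vector.interleave` rather than the `llvm.intr.vector.interleave2` LLVM intrinsic
- Remove the dependency on the LLVM dialect (drop `LLVM::LLVMDialect` from the pass's dependent dialects and `MLIRLLVMCommonConversion` from the link libraries)
- Remove the manual erasure of the fused outer products (they are now trivially dead, so they no longer need to be erased explicitly)
- Remove the comment explaining issues with the previous tile allocator
- Update the pipeline in `multi-tile-matmul-mixed-types.mlir` to use `-test-lower-to-arm-sme`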
---
Full diff: https://github.com/llvm/llvm-project/pull/102125.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+10-10)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+1-1)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (-1)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp (+5-44)
- (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+16-16)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir (+1-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 9178655f010c9..3f1776f57e4c7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -910,11 +910,11 @@ def FMopa2WayOp
The 2 outer products in the example above can be fused into a single outer
product as follows:
- ```mlir
- %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
- %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+ ```mlir
+ %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
+ %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- ```
+ ```
This is implemented in the `-arm-sme-outer-product-fusion` pass.
@@ -1167,13 +1167,13 @@ def SMopa4WayOp
product as follows:
```mlir
- %lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+ %lhs0 = vector.interleave %a0, %a2 : vector<[4]xi8> -> vector<[8]xi8>
+ %lhs1 = vector.interleave %a1, %a3 : vector<[4]xi8> -> vector<[8]xi8>
+ %lhs = vector.interleave %lhs0, %lhs1 : vector<[8]xi8> -> vector<[16]xi8>
- %rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
- %rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+ %rhs0 = vector.interleave %b0, %b2 : vector<[4]xi8> -> vector<[8]xi8>
+ %rhs1 = vector.interleave %b1, %b3 : vector<[4]xi8> -> vector<[8]xi8>
+ %rhs = vector.interleave %rhs0, %rhs1 : vector<[8]xi8> -> vector<[16]xi8>
%0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 921234daad1f1..45efabf5fe1b4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -180,7 +180,7 @@ def OuterProductFusion
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
}];
let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
- let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
+ let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}
def VectorLegalization
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f9b5080e82db..a29624468ba2d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRPass
MLIRArmSMEDialect
MLIRFuncDialect
- MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRIndexDialect
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 1e711678dc9ab..ee1e374b25b04 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -80,15 +79,6 @@ static LogicalResult isCompatible(PatternRewriter &rewriter,
return success();
}
-// Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
-static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
- Value lhs, Value rhs) {
- auto inputType = cast<VectorType>(lhs.getType());
- VectorType inputTypeX2 =
- VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
- return rewriter.create<LLVM::vector_interleave2>(loc, inputTypeX2, lhs, rhs);
-}
-
// Fuse two 'arm_sme.outerproduct' operations that are chained via the
// accumulator into 2-way outer product operation.
//
@@ -106,10 +96,8 @@ static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
//
// Becomes:
//
-// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
-// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
-// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
+// %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
class OuterProductFusion2Way
@@ -135,28 +123,7 @@ class OuterProductFusion2Way
if (!op1->hasOneUse()) {
// If the first outer product has uses other than as the input to another
- // outer product, it can't be erased after fusion. This is a problem when
- // it also has an accumulator as this will be used as the root for tile
- // allocation and since the widening outer product uses the same
- // accumulator it will get assigned the same tile ID, resulting in 3
- // outer products accumulating to the same tile and incorrect results.
- //
- // Example:
- //
- // %acc = arith.constant dense<0.0> ; root for tile allocation
- // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
- // vector.print %0 ; intermediary use, can't erase %0
- // %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
- //
- // After fusion and tile allocation
- //
- // %0 = arm_sme.zero {tile_id = 0 : i32}
- // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
- // vector.print %1
- // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
- //
- // No accumulator would be ok, but it's simpler to prevent this
- // altogether, since it has no benefit.
+ // outer product, it can't be erased after fusion.
return rewriter.notifyMatchFailure(op,
kMatchFailureOuterProductNotSingleUse);
}
@@ -169,7 +136,7 @@ class OuterProductFusion2Way
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
- return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
+ return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
};
auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -226,8 +193,6 @@ class OuterProductFusion2Way
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
- rewriter.eraseOp(op1);
-
return success();
}
@@ -319,7 +284,7 @@ class OuterProductFusion4Way
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
- return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
+ return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
};
auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -400,10 +365,6 @@ class OuterProductFusion4Way
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
- rewriter.eraseOp(op3);
- rewriter.eraseOp(op2);
- rewriter.eraseOp(op1);
-
return success();
}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 4887d611643fb..9000551783576 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -4,10 +4,10 @@
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
-// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[A0]], %[[A1]] : vector<[4]xf16> -> vector<[8]xf16>
+// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[B0]], %[[B1]] : vector<[4]xf16> -> vector<[8]xf16>
+// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
// CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_f16f16f32(
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
@@ -225,18 +225,18 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
// CHECK-SAME: %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>,
// CHECK-SAME: %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
-// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
-// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
-// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: %[[LHS0:.*]] = vector.interleave %[[A0]], %[[A2]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS1:.*]] = vector.interleave %[[A1]], %[[A3]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS0:.*]] = vector.interleave %[[B0]], %[[B2]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS1:.*]] = vector.interleave %[[B1]], %[[B3]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[LHS0]], %[[LHS1]] : vector<[8]xi8> -> vector<[16]xi8>
+// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[RHS0]], %[[RHS1]] : vector<[8]xi8> -> vector<[16]xi8>
+// CHECK-DAG: %[[LHS0_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS1_MASK:.*]] = vector.interleave %[[A1_MASK]], %[[A3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS0_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS1_MASK:.*]] = vector.interleave %[[B1_MASK]], %[[B3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[LHS0_MASK]], %[[LHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[RHS0_MASK]], %[[RHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_4way_signed_i8i8i32(
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index aabd9d2ce788e..5784ecbbe4014 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -1,11 +1,7 @@
// RUN: mlir-opt %s \
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN: -arm-sme-vector-legalization -canonicalize -cse \
-// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
-// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
-// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN: -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
``````````
</details>
https://github.com/llvm/llvm-project/pull/102125
More information about the Mlir-commits
mailing list