[Mlir-commits] [mlir] 1a4e0c9 - [MLIR][Vector] Add canonicalization for interleave/deinterleave chain (#196979)
llvmlistbot at llvm.org
Mon May 18 02:43:25 PDT 2026
Author: Artem Kroviakov
Date: 2026-05-18T09:43:19Z
New Revision: 1a4e0c9ec987b01c7e895393f38bcf27d19e3486
URL: https://github.com/llvm/llvm-project/commit/1a4e0c9ec987b01c7e895393f38bcf27d19e3486
DIFF: https://github.com/llvm/llvm-project/commit/1a4e0c9ec987b01c7e895393f38bcf27d19e3486.diff
LOG: [MLIR][Vector] Add canonicalization for interleave/deinterleave chain (#196979)
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 28a8109cb59c0..5acf2b4ab7649 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -584,6 +584,7 @@ def Vector_InterleaveOp :
return ::llvm::cast<VectorType>(getResult().getType());
}
}];
+ let hasCanonicalizer = 1;
}
class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
@@ -2560,7 +2561,7 @@ def Vector_TypeCastOp :
}
def Vector_ConstantMaskOp :
- Vector_Op<"constant_mask", [Pure,
+ Vector_Op<"constant_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
@@ -2620,7 +2621,7 @@ def Vector_ConstantMaskOp :
}
def Vector_CreateMaskOp :
- Vector_Op<"create_mask", [Pure,
+ Vector_Op<"create_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
Arguments<(ins Variadic<Index>:$mask_dim_sizes)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9d58bc9172452..1297f4561b6b7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -8375,6 +8375,34 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
// InterleaveOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Removes an interleave/deinterleave round trip:
+///   %even, %odd = vector.deinterleave %x
+///   %y = vector.interleave %even, %odd      // all uses of %y become %x
+/// The pattern only fires when both interleave operands are produced by the
+/// *same* deinterleave, with result #0 as the lhs and result #1 as the rhs.
+/// Any other operand arrangement (swapped, or mixed from different
+/// deinterleaves) is left untouched.
+struct InterleaveDeinterleaveFolder : public OpRewritePattern<InterleaveOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = interleaveOp.getLhs();
+    auto deinterleaveOp = lhs.getDefiningOp<DeinterleaveOp>();
+    if (!deinterleaveOp)
+      return failure();
+
+    // Operand order matters: lhs must be result #0 and rhs result #1 of the
+    // very same deinterleave op.
+    bool isInOrderRoundTrip =
+        lhs == deinterleaveOp->getResult(0) &&
+        interleaveOp.getRhs() == deinterleaveOp->getResult(1);
+    if (!isInOrderRoundTrip)
+      return failure();
+
+    rewriter.replaceOp(interleaveOp, deinterleaveOp.getSource());
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the `vector.interleave` canonicalization patterns to `patterns`:
+/// currently only InterleaveDeinterleaveFolder.
+void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *ctx) {
+  patterns.add<InterleaveDeinterleaveFolder>(ctx);
+}
+
/// Unrolling shape for vector.interleave: the full shape of the result vector.
std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6aa92ab79a0dd..ed4c908c6e5f2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -4407,3 +4407,18 @@ func.func @no_fold_alltrue_mask_empty_body_scalar_result(
%result = vector.mask %all_true, %passthru { vector.yield %val : i32 } : vector<1xi1> -> i32
return %result : i32
}
+
+// -----
+
+// `InterleaveDeinterleaveFolder`, a canonicalization pattern on
+// `vector.interleave`, folds the round-trip identity
+//   interleave(deinterleave(x).even, deinterleave(x).odd) -> x
+// Once folded, the now-dead deinterleave is erased as well.
+
+// CHECK-LABEL: func @interleave_deinterleave_fold
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>)
+//   CHECK-NOT: vector.deinterleave
+//   CHECK-NOT: vector.interleave
+//       CHECK: return %[[ARG0]]
+func.func @interleave_deinterleave_fold(%arg0: vector<4xf32>) -> vector<4xf32> {
+  %even, %odd = vector.deinterleave %arg0 : vector<4xf32> -> vector<2xf32>
+  %result = vector.interleave %even, %odd : vector<2xf32> -> vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// -----
+
+// Negative test: the operands are swapped (odd, even). The interleave does not
+// reconstruct the original vector, so it must not be folded.
+
+// CHECK-LABEL: func @interleave_deinterleave_swapped_no_fold
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>)
+//       CHECK: vector.deinterleave %[[ARG0]]
+//       CHECK: %[[RES:.*]] = vector.interleave
+//       CHECK: return %[[RES]]
+func.func @interleave_deinterleave_swapped_no_fold(%arg0: vector<4xf32>) -> vector<4xf32> {
+  %even, %odd = vector.deinterleave %arg0 : vector<4xf32> -> vector<2xf32>
+  %result = vector.interleave %odd, %even : vector<2xf32> -> vector<4xf32>
+  return %result : vector<4xf32>
+}
+
+// -----
+
+// Negative test: the operands come from two different deinterleave ops, so the
+// pattern must not fire.
+
+// CHECK-LABEL: func @interleave_deinterleave_mismatched_sources_no_fold
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK: vector.deinterleave %[[ARG0]]
+//       CHECK: vector.deinterleave %[[ARG1]]
+//       CHECK: %[[RES:.*]] = vector.interleave
+//       CHECK: return %[[RES]]
+func.func @interleave_deinterleave_mismatched_sources_no_fold(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> vector<4xf32> {
+  %even0, %odd0 = vector.deinterleave %arg0 : vector<4xf32> -> vector<2xf32>
+  %even1, %odd1 = vector.deinterleave %arg1 : vector<4xf32> -> vector<2xf32>
+  %result = vector.interleave %even0, %odd1 : vector<2xf32> -> vector<4xf32>
+  return %result : vector<4xf32>
+}
More information about the Mlir-commits mailing list