[Mlir-commits] [mlir] [mlir] ArithToLLVM: fix memref bitcast lowering (PR #125148)
Ivan Butygin
llvmlistbot at llvm.org
Fri Jan 31 07:01:54 PST 2025
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/125148
>From 2fc606a5830d81c1479507fa4ea23e8fba73ad78 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 31 Jan 2025 02:54:17 +0100
Subject: [PATCH 1/2] [mlir] ArithToLLVM: fix memref bitcast lowering
`arith.bitcast` is allowed on memrefs, and such code can actually be generated by IREE's `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its operand types and will generate a bitcast between structs, which is illegal.
With opaque pointers, such a bitcast is a no-op for memrefs, so we can just add a type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern that removes the op if the converted types are the same.
---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 20 +++++++++++++++++++
.../Conversion/LLVMCommon/VectorPattern.cpp | 7 ++++++-
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 12 +++++++++++
3 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed898142936..b726faa92a03a0 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,23 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};
+/// No-op bitcast.
+struct IdentityBitcastLowering final
+ : public OpConversionPattern<arith::BitcastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ Value src = adaptor.getIn();
+ if (src.getType() != getTypeConverter()->convertType(op.getType()))
+ return rewriter.notifyMatchFailure(op, "Types are different");
+
+ rewriter.replaceOp(op, src);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -524,6 +541,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(
void mlir::arith::populateArithToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+ patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
+
// clang-format off
patterns.add<
AddFOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 626135c10a3e96..c9d3b57b0d596e 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}
+static bool isVectorCompatibleType(Type type) {
+ return isa<LLVM::LLVMArrayType, VectorType, IntegerType, FloatType>(type) &&
+ LLVM::isCompatibleType(type);
+}
+
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
@@ -111,7 +116,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
- if (!llvm::all_of(operands.getTypes(), isCompatibleType))
+ if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
return failure();
auto llvmNDVectorTy = operands[0].getType();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47cc..9a6c4bca88f3bf 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
+// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+// CHECK: return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+ %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+ func.return %2 : memref<?xbf16>
+}
>From 88d92bf15721d2bf6afa1a71e5854b6ed7b0d605 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 31 Jan 2025 15:58:48 +0100
Subject: [PATCH 2/2] add comments
---
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 6 ++++--
mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp | 2 ++
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index b726faa92a03a0..5c1afe8034c73b 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,7 +54,8 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};
-/// No-op bitcast.
+/// No-op bitcast. Forward the converted input value when the converted source
+/// and result types are identical.
struct IdentityBitcastLowering final
: public OpConversionPattern<arith::BitcastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -63,7 +64,8 @@ struct IdentityBitcastLowering final
matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Value src = adaptor.getIn();
- if (src.getType() != getTypeConverter()->convertType(op.getType()))
+ Type resultType = getTypeConverter()->convertType(op.getType());
+ if (src.getType() != resultType)
return rewriter.notifyMatchFailure(op, "Types are different");
rewriter.replaceOp(op, src);
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index c9d3b57b0d596e..e51363ca28505c 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -104,6 +104,8 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
}
static bool isVectorCompatibleType(Type type) {
+ // Limit `vectorOneToOneRewrite` to scalar and vector types (and to
+ // `LLVM::LLVMArrayType`, which has special handling).
return isa<LLVM::LLVMArrayType, VectorType, IntegerType, FloatType>(type) &&
LLVM::isCompatibleType(type);
}
More information about the Mlir-commits
mailing list