[Mlir-commits] [mlir] Reland [mlir] ArithToLLVM: fix memref bitcast lowering (#125148) (PR #126939)

Ivan Butygin llvmlistbot at llvm.org
Wed Feb 12 08:30:32 PST 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/126939

Reland https://github.com/llvm/llvm-project/pull/125148

Limiting the vector pattern caused issues with the lowering of `select` on complex types, which were not caught because lit tests for that case were missing. Keep the pattern as is for now and instead give `IdentityBitcastLowering` a higher benefit so that it always runs before the vector pattern.

From 07b3df3518182ac1f9da5840eaa2744801137581 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 12 Feb 2025 12:19:13 +0100
Subject: [PATCH] [mlir] ArithToLLVM: fix memref bitcast lowering (#125148)

`arith.bitcast` is allowed on memrefs, and such code can actually be
generated by the IREE `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types
and will generate a bitcast between structs, which is illegal.

With opaque pointers this is a no-op for memrefs, so we can just add a
separate pattern which removes the op if the converted types are the same.
---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 25 +++++++++++++++++
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 28 ++++++++++++++++++-
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed89814293..ced18a48766bf 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
+/// No-op bitcast: if the converted source and result types are identical,
+/// replace the op with its already-converted input; otherwise fail to match.
+struct IdentityBitcastLowering final
+    : public OpConversionPattern<arith::BitcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Value src = adaptor.getIn();
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (src.getType() != resultType)
+      return rewriter.notifyMatchFailure(op, "Types are different");
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -524,6 +543,12 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+  // Set a higher pattern benefit for IdentityBitcastLowering so it will run
+  // before BitcastOpLowering.
+  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
+                                        /*patternBenefit*/ 10);
+
   // clang-format off
   patterns.add<
     AddFOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47c..7daf4ef8717bc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -577,12 +577,26 @@ func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
 // -----
 
 // CHECK-LABEL: @select
+//  CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
 func.func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 {
-  // CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32
+  // CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARG1]], %[[ARG2]] : i1, i32
+  // CHECK: return %[[RES]]
   %0 = arith.select %arg0, %arg1, %arg2 : i32
   return %0 : i32
 }
 
+// CHECK-LABEL: @select_complex
+//  CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: complex<f32>, %[[ARG2:.*]]: complex<f32>)
+func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>) -> complex<f32> {
+  // CHECK-DAG: %[[ARGC1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : complex<f32> to !llvm.struct<(f32, f32)>
+  // CHECK-DAG: %[[ARGC2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : complex<f32> to !llvm.struct<(f32, f32)>
+  //     CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARGC1]], %[[ARGC2]] : i1, !llvm.struct<(f32, f32)>
+  //     CHECK: %[[RESC:.*]] = builtin.unrealized_conversion_cast %[[RES]] : !llvm.struct<(f32, f32)> to complex<f32>
+  //     CHECK: return %[[RESC]]
+  %0 = arith.select %arg0, %arg1, %arg2 : complex<f32>
+  return %0 : complex<f32>
+}
+
 // -----
 
 // CHECK-LABEL: @ceildivsi
@@ -727,3 +741,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
+//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+//       CHECK:   return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+  func.return %2 : memref<?xbf16>
+}



More information about the Mlir-commits mailing list