[Mlir-commits] [mlir] 79010e2 - [mlir] ArithToLLVM: fix memref bitcast lowering (#125148)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 12 03:19:16 PST 2025


Author: Ivan Butygin
Date: 2025-02-12T14:19:13+03:00
New Revision: 79010e2e4d0e27ee87887bfaef2c32e908c92a8e

URL: https://github.com/llvm/llvm-project/commit/79010e2e4d0e27ee87887bfaef2c32e908c92a8e
DIFF: https://github.com/llvm/llvm-project/commit/79010e2e4d0e27ee87887bfaef2c32e908c92a8e.diff

LOG: [mlir] ArithToLLVM: fix memref bitcast lowering (#125148)

`arith.bitcast` is allowed on memrefs, and such code can actually be
generated by IREE's `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its operand types
and will generate a bitcast between structs, which is illegal.

With opaque pointers, this is a no-op for memrefs, so we can
just add a type check in `LLVM::detail::vectorOneToOneRewrite` and add a
separate pattern that removes the op if the converted types are the same.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed89814293..5c1afe8034c73 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
+/// No-op bitcast. Forward the input value unchanged if the converted source
+/// and destination types are the same.
+struct IdentityBitcastLowering final
+    : public OpConversionPattern<arith::BitcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Value src = adaptor.getIn();
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (src.getType() != resultType)
+      return rewriter.notifyMatchFailure(op, "Types are different");
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -524,6 +543,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
+
   // clang-format off
   patterns.add<
     AddFOpLowering,

diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index bf3f31729c3da..fe4781138fa29 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,6 +103,14 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
+static bool isVectorCompatibleType(Type type) {
+  // Limit `vectorOneToOneRewrite` to LLVM-compatible scalar (integer, float,
+  // pointer) and vector types, plus `LLVM::LLVMArrayType` (special-cased).
+  return isa<LLVM::LLVMArrayType, LLVM::LLVMPointerType, VectorType,
+             IntegerType, FloatType>(type) &&
+         LLVM::isCompatibleType(type);
+}
+
 LogicalResult LLVM::detail::vectorOneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
     ArrayRef<NamedAttribute> targetAttrs,
@@ -111,7 +119,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
-  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
+  if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
     return failure();
 
   auto llvmNDVectorTy = operands[0].getType();

diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47c..9a6c4bca88f3b 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
+
+// -----
+// The memref bitcast is a no-op: both types lower to the same LLVM struct.
+// CHECK-LABEL: func @memref_bitcast
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
+//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+//       CHECK:   return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+  func.return %2 : memref<?xbf16>
+}


        


More information about the Mlir-commits mailing list