[llvm] [msan] Handle NEON bfmmla (PR #176264)

Thurston Dang via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 21 11:12:17 PST 2026


https://github.com/thurstond updated https://github.com/llvm/llvm-project/pull/176264

>From e840cd6d1d2a18e278aefbaa866201cdfc54c20e Mon Sep 17 00:00:00 2001
From: Thurston Dang <thurston at google.com>
Date: Thu, 15 Jan 2026 21:56:25 +0000
Subject: [PATCH 1/5] [msan] Handle NEON bfmmla

This patch adapts handleNEONMatrixMultiply() (used for integer matrix
multiply: smmla/ummla/usmmla) to floating-point (bfmmla).
---
 .../Instrumentation/MemorySanitizer.cpp       | 93 +++++++++++--------
 .../aarch64-bf16-dotprod-intrinsics.ll        | 27 +++---
 .../MemorySanitizer/AArch64/aarch64-matmul.ll |  4 +-
 3 files changed, 71 insertions(+), 53 deletions(-)

diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 5c0f4599c9473..d98f86a85e41b 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -5354,12 +5354,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     }
   }
 
-  // <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8
-  //               (<4 x i32> %R, <16 x i8> %X, <16 x i8> %Y)
-  // <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8
-  //               (<4 x i32> %R, <16 x i8> %X, <16 x i8> %Y)
-  // <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8
-  //               (<4 x i32> R%, <16 x i8> %X, <16 x i8> %Y)
+  // Integer matrix multiplication:
+  // - <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8
+  //                 (<4 x i32> %R, <16 x i8> %X, <16 x i8> %Y)
+  // - <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8
+  //                 (<4 x i32> %R, <16 x i8> %X, <16 x i8> %Y)
+  // - <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8
+  //                 (<4 x i32> R%, <16 x i8> %X, <16 x i8> %Y)
   //
   // Note:
   // - < 4 x *> is a 2x2 matrix
@@ -5377,14 +5378,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   // TODO: consider allowing multiplication of zero with an uninitialized value
   //       to result in an initialized value.
   //
-  // TODO: handle floating-point matrix multiply using ummla on the shadows:
-  //   case Intrinsic::aarch64_neon_bfmmla:
-  //     handleNEONMatrixMultiply(I, /*ARows=*/ 2, /*ACols=*/ 4,
-  //                                 /*BRows=*/ 4, /*BCols=*/ 2);
-  //
-  void handleNEONMatrixMultiply(IntrinsicInst &I, unsigned int ARows,
-                                unsigned int ACols, unsigned int BRows,
-                                unsigned int BCols) {
+  // Floating-point matrix multiplication:
+  // - <4 x float> @llvm.aarch64.neon.bfmmla
+  //                   (<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b)
+  // Although there are half as many elements of %a and %b compared to the
+  // integer case, each element is twice the bit-width. Thus, we can reuse the
+  // shadow propagation logic if we cast the shadows to the same type as the
+  // integer case, and apply ummla to the shadows.
+  void handleNEONMatrixMultiply(IntrinsicInst &I) {
     IRBuilder<> IRB(&I);
 
     assert(I.arg_size() == 3);
@@ -5402,47 +5403,65 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     [[maybe_unused]] FixedVectorType *ATy = cast<FixedVectorType>(A->getType());
     [[maybe_unused]] FixedVectorType *BTy = cast<FixedVectorType>(B->getType());
 
-    assert(ACols == BRows);
-    assert(ATy->getNumElements() == ARows * ACols);
-    assert(BTy->getNumElements() == BRows * BCols);
-    assert(RTy->getNumElements() == ARows * BCols);
+    Value *ShadowR = getShadow(&I, 0);
+    Value *ShadowA = getShadow(&I, 1);
+    Value *ShadowB = getShadow(&I, 2);
+
+    // We will use ummla to compute the shadow. These are the types it expects.
+    // These are also the types of the corresponding shadows.
+    FixedVectorType *ExpectedRTy = FixedVectorType::get(IntegerType::get(*MS.C, 32), 4);
+    FixedVectorType *ExpectedATy = FixedVectorType::get(IntegerType::get(*MS.C, 8), 16);
+    FixedVectorType *ExpectedBTy = FixedVectorType::get(IntegerType::get(*MS.C, 8), 16);
 
-    LLVM_DEBUG(dbgs() << "### R: " << *RTy->getElementType() << "\n");
-    LLVM_DEBUG(dbgs() << "### A: " << *ATy->getElementType() << "\n");
     if (RTy->getElementType()->isIntegerTy()) {
-      // Types are not identical e.g., <4 x i32> %R, <16 x i8> %A
+      // Types of R and A/B are not identical e.g., <4 x i32> %R, <16 x i8> %A
       assert(ATy->getElementType()->isIntegerTy());
+
+      assert(RTy == ExpectedRTy);
+      assert(ATy == ExpectedATy);
+      assert(BTy == ExpectedBTy);
     } else {
-      assert(RTy->getElementType()->isFloatingPointTy());
       assert(ATy->getElementType()->isFloatingPointTy());
+      assert(BTy->getElementType()->isFloatingPointTy());
+
+      // Technically, what we care about is that:
+      //   getShadowTy(RTy)->canLosslesslyBitCastTo(ExpectedRTy) etc.
+      // but checking the original (non-shadow) types is equivalent.
+      assert(RTy->canLosslesslyBitCastTo(ExpectedRTy));
+      assert(ATy->canLosslesslyBitCastTo(ExpectedATy));
+      assert(BTy->canLosslesslyBitCastTo(ExpectedBTy));
+
+      ShadowA = IRB.CreateBitCast(ShadowA, getShadowTy(ExpectedATy));
+      ShadowB = IRB.CreateBitCast(ShadowB, getShadowTy(ExpectedBTy));
     }
     assert(ATy->getElementType() == BTy->getElementType());
 
-    Value *ShadowR = getShadow(&I, 0);
-    Value *ShadowA = getShadow(&I, 1);
-    Value *ShadowB = getShadow(&I, 2);
+    // From this point on, use Expected{R,A,B}Ty.
 
     // If the value is fully initialized, the shadow will be 000...001.
     // Otherwise, the shadow will be all zero.
     // (This is the opposite of how we typically handle shadows.)
-    ShadowA = IRB.CreateZExt(IRB.CreateICmpEQ(ShadowA, getCleanShadow(A)),
-                             ShadowA->getType());
-    ShadowB = IRB.CreateZExt(IRB.CreateICmpEQ(ShadowB, getCleanShadow(B)),
-                             ShadowB->getType());
+    ShadowA = IRB.CreateZExt(IRB.CreateICmpEQ(ShadowA, getCleanShadow(ExpectedATy)),
+                             getShadowTy(ExpectedATy));
+    ShadowB = IRB.CreateZExt(IRB.CreateICmpEQ(ShadowB, getCleanShadow(ExpectedBTy)),
+                             getShadowTy(ExpectedBTy));
 
     Value *ShadowAB = IRB.CreateIntrinsic(
-        I.getType(), I.getIntrinsicID(), {getCleanShadow(R), ShadowA, ShadowB});
+        ExpectedRTy, Intrinsic::aarch64_neon_ummla,
+        {getCleanShadow(ExpectedRTy), ShadowA, ShadowB});
 
+    // ummla multiplies a 2x8 matrix with an 8x2 matrix. If all entries of the
+    // input matrices are equal to 0x1, all entries of the output matrix will
+    // be 0x8.
     Value *FullyInit = ConstantVector::getSplat(
-        RTy->getElementCount(),
-        ConstantInt::get(cast<VectorType>(getShadowTy(R))->getElementType(),
-                         ACols));
+        ExpectedRTy->getElementCount(),
+        ConstantInt::get(ExpectedRTy->getElementType(), 0x8));
 
     ShadowAB = IRB.CreateSExt(IRB.CreateICmpNE(ShadowAB, FullyInit),
                               ShadowAB->getType());
 
-    ShadowR = IRB.CreateSExt(IRB.CreateICmpNE(ShadowR, getCleanShadow(R)),
-                             ShadowR->getType());
+    ShadowR = IRB.CreateSExt(IRB.CreateICmpNE(ShadowR, getCleanShadow(ExpectedRTy)),
+                             ExpectedRTy);
 
     setShadow(&I, IRB.CreateOr(ShadowAB, ShadowR));
     setOriginForNaryOp(I);
@@ -6827,8 +6846,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     case Intrinsic::aarch64_neon_smmla:
     case Intrinsic::aarch64_neon_ummla:
     case Intrinsic::aarch64_neon_usmmla:
-      handleNEONMatrixMultiply(I, /*ARows=*/2, /*ACols=*/8, /*BRows=*/8,
-                               /*BCols=*/2);
+    case Intrinsic::aarch64_neon_bfmmla:
+      handleNEONMatrixMultiply(I);
       break;
 
     default:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
index aa771e3cb2fc0..0aa9813cb7a28 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -239,21 +239,20 @@ define <4 x float> @test_vbfmmlaq_f32(<4 x float> %r, <8 x bfloat> %a, <8 x bflo
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 16), align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <8 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 32), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i32> [[TMP0]] to i128
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i128 [[TMP3]], 0
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <8 x i16> [[TMP1]] to i128
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <8 x i16> [[TMP2]] to i128
-; CHECK-NEXT:    [[_MSCMP2:%.*]] = icmp ne i128 [[TMP5]], 0
-; CHECK-NEXT:    [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
-; CHECK-NEXT:    br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
-; CHECK:       [[BB6]]:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR5]]
-; CHECK-NEXT:    unreachable
-; CHECK:       [[BB7]]:
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <8 x i16> [[TMP1]] to <16 x i8>
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <8 x i16> [[TMP2]] to <16 x i8>
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq <16 x i8> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i1> [[TMP5]] to <16 x i8>
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq <16 x i8> [[TMP4]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = zext <16 x i1> [[TMP7]] to <16 x i8>
+; CHECK-NEXT:    [[TMP9:%.*]] = call <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32> zeroinitializer, <16 x i8> [[TMP6]], <16 x i8> [[TMP8]])
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <4 x i32> [[TMP9]], splat (i32 8)
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <4 x i1> [[TMP10]] to <4 x i32>
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <4 x i32> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i1> [[TMP12]] to <4 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = or <4 x i32> [[TMP11]], [[TMP13]]
 ; CHECK-NEXT:    [[VBFMMLAQ_V3_I:%.*]] = call <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float> [[R]], <8 x bfloat> [[A]], <8 x bfloat> [[B]])
-; CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <4 x i32> [[TMP14]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <4 x float> [[VBFMMLAQ_V3_I]]
 ;
 entry:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-matmul.ll b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-matmul.ll
index b5d6b627366eb..7a782b2cfd36d 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-matmul.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-matmul.ll
@@ -24,7 +24,7 @@ define <4 x i32> @smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) sa
 ; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i1> [[TMP3]] to <16 x i8>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq <16 x i8> [[TMP2]], zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i1> [[TMP5]] to <16 x i8>
-; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x i32> @llvm.aarch64.neon.smmla.v4i32.v16i8(<4 x i32> zeroinitializer, <16 x i8> [[TMP4]], <16 x i8> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32> zeroinitializer, <16 x i8> [[TMP4]], <16 x i8> [[TMP6]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp ne <4 x i32> [[TMP7]], splat (i32 8)
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i1> [[TMP8]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <4 x i32> [[TMP0]], zeroinitializer
@@ -78,7 +78,7 @@ define <4 x i32> @usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) s
 ; CHECK-NEXT:    [[TMP4:%.*]] = zext <16 x i1> [[TMP3]] to <16 x i8>
 ; CHECK-NEXT:    [[TMP5:%.*]] = icmp eq <16 x i8> [[TMP2]], zeroinitializer
 ; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i1> [[TMP5]] to <16 x i8>
-; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8(<4 x i32> zeroinitializer, <16 x i8> [[TMP4]], <16 x i8> [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8(<4 x i32> zeroinitializer, <16 x i8> [[TMP4]], <16 x i8> [[TMP6]])
 ; CHECK-NEXT:    [[TMP8:%.*]] = icmp ne <4 x i32> [[TMP7]], splat (i32 8)
 ; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i1> [[TMP8]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <4 x i32> [[TMP0]], zeroinitializer

>From f2eb06c07d9fc5c6fd770f037cfcca691952b2b8 Mon Sep 17 00:00:00 2001
From: Thurston Dang <thurston at google.com>
Date: Thu, 15 Jan 2026 22:58:43 +0000
Subject: [PATCH 2/5] Fix comment

---
 .../MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll   | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
index 0aa9813cb7a28..c87e583d7713f 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -6,7 +6,6 @@
 ; Strictly handled:
 ; - llvm.aarch64.neon.bfdot.v2f32.v4bf16
 ; - llvm.aarch64.neon.bfdot.v4f32.v8bf16
-; - llvm.aarch64.neon.bfmmla
 ; - llvm.aarch64.neon.bfmlalb
 ; - llvm.aarch64.neon.bfmlalt
 ;

>From 148a0e93b0aafa6ad24c8151a80dcedd7e28dc34 Mon Sep 17 00:00:00 2001
From: Thurston Dang <thurston at google.com>
Date: Thu, 15 Jan 2026 22:59:05 +0000
Subject: [PATCH 3/5] clang-format

---
 .../Instrumentation/MemorySanitizer.cpp       | 29 +++++++++++--------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index d98f86a85e41b..09c6ce7c63fed 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -5409,9 +5409,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
 
     // We will use ummla to compute the shadow. These are the types it expects.
     // These are also the types of the corresponding shadows.
-    FixedVectorType *ExpectedRTy = FixedVectorType::get(IntegerType::get(*MS.C, 32), 4);
-    FixedVectorType *ExpectedATy = FixedVectorType::get(IntegerType::get(*MS.C, 8), 16);
-    FixedVectorType *ExpectedBTy = FixedVectorType::get(IntegerType::get(*MS.C, 8), 16);
+    FixedVectorType *ExpectedRTy =
+        FixedVectorType::get(IntegerType::get(*MS.C, 32), 4);
+    FixedVectorType *ExpectedATy =
+        FixedVectorType::get(IntegerType::get(*MS.C, 8), 16);
+    FixedVectorType *ExpectedBTy =
+        FixedVectorType::get(IntegerType::get(*MS.C, 8), 16);
 
     if (RTy->getElementType()->isIntegerTy()) {
       // Types of R and A/B are not identical e.g., <4 x i32> %R, <16 x i8> %A
@@ -5441,14 +5444,16 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     // If the value is fully initialized, the shadow will be 000...001.
     // Otherwise, the shadow will be all zero.
     // (This is the opposite of how we typically handle shadows.)
-    ShadowA = IRB.CreateZExt(IRB.CreateICmpEQ(ShadowA, getCleanShadow(ExpectedATy)),
-                             getShadowTy(ExpectedATy));
-    ShadowB = IRB.CreateZExt(IRB.CreateICmpEQ(ShadowB, getCleanShadow(ExpectedBTy)),
-                             getShadowTy(ExpectedBTy));
+    ShadowA =
+        IRB.CreateZExt(IRB.CreateICmpEQ(ShadowA, getCleanShadow(ExpectedATy)),
+                       getShadowTy(ExpectedATy));
+    ShadowB =
+        IRB.CreateZExt(IRB.CreateICmpEQ(ShadowB, getCleanShadow(ExpectedBTy)),
+                       getShadowTy(ExpectedBTy));
 
-    Value *ShadowAB = IRB.CreateIntrinsic(
-        ExpectedRTy, Intrinsic::aarch64_neon_ummla,
-        {getCleanShadow(ExpectedRTy), ShadowA, ShadowB});
+    Value *ShadowAB =
+        IRB.CreateIntrinsic(ExpectedRTy, Intrinsic::aarch64_neon_ummla,
+                            {getCleanShadow(ExpectedRTy), ShadowA, ShadowB});
 
     // ummla multiplies a 2x8 matrix with an 8x2 matrix. If all entries of the
     // input matrices are equal to 0x1, all entries of the output matrix will
@@ -5460,8 +5465,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     ShadowAB = IRB.CreateSExt(IRB.CreateICmpNE(ShadowAB, FullyInit),
                               ShadowAB->getType());
 
-    ShadowR = IRB.CreateSExt(IRB.CreateICmpNE(ShadowR, getCleanShadow(ExpectedRTy)),
-                             ExpectedRTy);
+    ShadowR = IRB.CreateSExt(
+        IRB.CreateICmpNE(ShadowR, getCleanShadow(ExpectedRTy)), ExpectedRTy);
 
     setShadow(&I, IRB.CreateOr(ShadowAB, ShadowR));
     setOriginForNaryOp(I);

>From b473cc6b10b043339e34cf4b13febb0072c6e766 Mon Sep 17 00:00:00 2001
From: Thurston Dang <thurston at google.com>
Date: Tue, 20 Jan 2026 22:25:34 +0000
Subject: [PATCH 4/5] Add sketch proof

---
 .../Instrumentation/MemorySanitizer.cpp       | 38 ++++++++++++++++---
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 09c6ce7c63fed..a3aebdb80980d 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -5360,11 +5360,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   // - <4 x i32> @llvm.aarch64.neon.ummla.v4i32.v16i8
   //                 (<4 x i32> %R, <16 x i8> %X, <16 x i8> %Y)
   // - <4 x i32> @llvm.aarch64.neon.usmmla.v4i32.v16i8
-  //                 (<4 x i32> R%, <16 x i8> %X, <16 x i8> %Y)
+  //                 (<4 x i32> %R, <16 x i8> %X, <16 x i8> %Y)
   //
   // Note:
-  // - < 4 x *> is a 2x2 matrix
-  // - <16 x *> is a 2x8 matrix and 8x2 matrix respectively
+  // - <4 x i32> is a 2x2 matrix
+  // - <16 x i8> %X and %Y are 2x8 and 8x2 matrices respectively
+  //
+  //   2x8 %X                                8x2 %Y
+  //   [ X01 X02 X03 X04 X05 X06 X07 X08 ]   [ Y01 Y09 ]
+  //   [ X09 X10 X11 X12 X13 X14 X15 X16 ] x [ Y02 Y10 ]
+  //                                         [ Y03 Y11 ]
+  //                                         [ Y04 Y12 ]
+  //                                         [ Y05 Y13 ]
+  //                                         [ Y06 Y14 ]
+  //                                         [ Y07 Y15 ]
+  //                                         [ Y08 Y16 ]
   //
   // The general shadow propagation approach is:
   // 1) get the shadows of the input matrices %X and %Y
@@ -5380,11 +5390,27 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   //
   // Floating-point matrix multiplication:
   // - <4 x float> @llvm.aarch64.neon.bfmmla
-  //                   (<4 x float> %r, <8 x bfloat> %a, <8 x bfloat> %b)
-  // Although there are half as many elements of %a and %b compared to the
+  //                   (<4 x float> %R, <8 x bfloat> %X, <8 x bfloat> %Y)
+  //   %X and %Y are 2x4 and 4x2 matrices respectively
+  //
+  // Although there are half as many elements of %X and %Y compared to the
   // integer case, each element is twice the bit-width. Thus, we can reuse the
   // shadow propagation logic if we cast the shadows to the same type as the
-  // integer case, and apply ummla to the shadows.
+  // integer case, and apply ummla to the shadows:
+  //
+  //   2x4 %X                                4x2 %Y
+  //   [ A01:A02 A03:A04 A05:A06 A07:A08 ]   [ B01:B02 B09:B10 ]
+  //   [ A09:A10 A11:A12 A13:A14 A15:A16 ] x [ B03:B04 B11:B12 ]
+  //                                         [ B05:B06 B13:B14 ]
+  //                                         [ B07:B08 B15:B16 ]
+  //
+  // For example, consider multiplying the first row of %X with the first
+  // column of %Y. We want to know if
+  // A01:A02*B01:B02 + A03:A04*B03:B04 + A05:A06*B05:B06 + A07:A08*B07:B08 is
+  // fully initialized, which will be true if and only if (A01, A02, ..., A08)
+  // and (B01, B02, ..., B08) are each fully initialized. This latter condition
+  // is equivalent to what is tested by the instrumentation for the integer
+  // form.
   void handleNEONMatrixMultiply(IntrinsicInst &I) {
     IRBuilder<> IRB(&I);
 

>From a2a756b9c3d460b3e86b67379b94771fd7935f3d Mon Sep 17 00:00:00 2001
From: Thurston Dang <thurston at google.com>
Date: Wed, 21 Jan 2026 19:11:54 +0000
Subject: [PATCH 5/5] Fix comment

---
 .../MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
index 3f60f5fe0c279..c77887260cc7d 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -4,9 +4,6 @@
 ; Forked from llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
 ;
 ; Strictly handled:
-; - llvm.aarch64.neon.bfdot.v2f32.v4bf16
-; - llvm.aarch64.neon.bfdot.v4f32.v8bf16
-; - llvm.aarch64.neon.bfmmla
 ; - llvm.aarch64.neon.bfmlalb
 ; - llvm.aarch64.neon.bfmlalt
 ;



More information about the llvm-commits mailing list