[PATCH] D94372: [X86][AMX] Prohibit pointer cast on load.

LuoYuanke via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 11 19:43:12 PST 2021


LuoYuanke updated this revision to Diff 315981.
LuoYuanke added a comment.

Address Pengfei's comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D94372/new/

https://reviews.llvm.org/D94372

Files:
  llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
  llvm/test/Transforms/InstCombine/X86/x86-amx-load-store.ll


Index: llvm/test/Transforms/InstCombine/X86/x86-amx-load-store.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InstCombine/X86/x86-amx-load-store.ll
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -instcombine -S < %s | FileCheck %s
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; Prohibit pointer cast for amx: the <256 x i32> load feeding a bitcast to x86_amx must not be rewritten as a load of x86_amx through a bitcast pointer.
+define dso_local void @test_amx_load_store(<256 x i32>* %src, i8* %dst) {
+; CHECK-LABEL: @test_amx_load_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VEC:%.*]] = load <256 x i32>, <256 x i32>* [[SRC:%.*]], align 64
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <256 x i32> [[VEC]] to x86_amx
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* [[DST:%.*]], i64 64, x86_amx [[BC]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vec = load <256 x i32>, <256 x i32>* %src, align 64
+  %bc = bitcast <256 x i32> %vec to x86_amx
+  tail call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %dst, i64 64, x86_amx %bc)
+  ret void
+}
+
+; Prohibit pointer cast for amx: the bitcast from x86_amx feeding the store must not be folded away into a store of x86_amx through a bitcast pointer.
+define dso_local void @test_amx_load_store2(<256 x i32>* %dst, i8* %src) {
+; CHECK-LABEL: @test_amx_load_store2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* [[SRC:%.*]], i64 64)
+; CHECK-NEXT:    [[BC:%.*]] = bitcast x86_amx [[AMX]] to <256 x i32>
+; CHECK-NEXT:    store <256 x i32> [[BC]], <256 x i32>* [[DST:%.*]], align 1024
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* %src, i64 64)
+  %bc = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %bc, <256 x i32>* %dst
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
Index: llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -589,7 +589,16 @@
   // Fold away bit casts of the loaded value by loading the desired type.
   // Note that we should not do this for pointer<->integer casts,
   // because that would result in type punning.
-  if (LI.hasOneUse())
+  if (LI.hasOneUse()) {
+    // Don't transform when the type is x86_amx, it makes the pass that lower
+    // x86_amx type happy.
+    if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
+      assert(!LI.getType()->isX86_AMXTy() &&
+             "load from x86_amx* should not happen!");
+      if (BC->getType()->isX86_AMXTy())
+        return nullptr;
+    }
+
     if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
       if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() ==
                                     CI->getDestTy()->isPtrOrPtrVectorTy())
@@ -599,6 +608,7 @@
           IC.eraseInstFromFunction(*CI);
           return &LI;
         }
+  }
 
   // FIXME: We should also canonicalize loads of vectors when their elements are
   // cast to other types.
@@ -1114,10 +1124,12 @@
 
   // Fold away bit casts of the stored value by storing the original type.
   if (auto *BC = dyn_cast<BitCastInst>(V)) {
+    assert(!BC->getType()->isX86_AMXTy() &&
+           "store to x86_amx* should not happen!");
     V = BC->getOperand(0);
-    // Don't transform when the type is x86_amx, it make the pass that lower
+    // Don't transform when the type is x86_amx, it makes the pass that lower
     // x86_amx type happy.
-    if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+    if (V->getType()->isX86_AMXTy())
       return false;
     if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) {
       combineStoreToNewValue(IC, SI, V);


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D94372.315981.patch
Type: text/x-patch
Size: 3919 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20210112/1d831239/attachment.bin>


More information about the llvm-commits mailing list