[llvm] bf7eb48 - [Matrix] RAUW should only replace an instruction in ShapeMap if supportsShapeInfo

Adam Nemet via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 27 11:36:49 PDT 2021


Author: Adam Nemet
Date: 2021-07-27T11:36:13-07:00
New Revision: bf7eb48454872cce85da42e0c006fa214e38fc0a

URL: https://github.com/llvm/llvm-project/commit/bf7eb48454872cce85da42e0c006fa214e38fc0a
DIFF: https://github.com/llvm/llvm-project/commit/bf7eb48454872cce85da42e0c006fa214e38fc0a.diff

LOG: [Matrix] RAUW should only replace an instruction in ShapeMap if supportsShapeInfo

When an instruction is replaced in optimizeTransposes, RAUW will also replace
it in the ShapeMap (ShapeMap is a ValueMap so that uses are updated).  In
finalizeLowering, however, we skip updating uses if they are in the ShapeMap,
since they will be lowered separately, at which point we pick up the lowered
operands.

In the testcase, since we replaced the doubled transpose with the shuffle, the
shuffle ended up in the ShapeMap.  When we lowered the columnwise-load, its use
in the shuffle was not updated.  Then, when we removed the original
columnwise-load, we changed that use to undef.  I.e. we ended up with:

```
%shuf = shufflevector <8 x double> undef, <8 x double> poison, <6 x i32>
                                   ^^^^^
                                  <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
```

Besides the fix itself, I have fortified this last bit.  When we change uses
to undef while removing instructions, we now track the undefed instructions to
make sure we eventually remove those too.  This would have caught the issue at
compile time.

Differential Revision: https://reviews.llvm.org/D106714

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 8030a613c817b..ab75cd3f566b4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -685,6 +685,19 @@ class LowerMatrixIntrinsics {
 
   /// Try moving transposes in order to fold them away or into multiplies.
   void optimizeTransposes() {
+    auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
+      // We need to remove Old from the ShapeMap otherwise RAUW will replace it
+      // with New. We should only add New if it supportsShapeInfo so we insert
+      // it conditionally instead.
+      auto S = ShapeMap.find(&Old);
+      if (S != ShapeMap.end()) {
+        ShapeMap.erase(S);
+        if (supportsShapeInfo(New))
+          ShapeMap.insert({New, S->second});
+      }
+      Old.replaceAllUsesWith(New);
+    };
+
     // First sink all transposes inside matmuls, hoping that we end up with NN,
     // NT or TN variants.
     for (BasicBlock &BB : reverse(Func)) {
@@ -717,7 +730,7 @@ class LowerMatrixIntrinsics {
           Value *TATA;
           if (match(TA,
                     m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
-            I.replaceAllUsesWith(TATA);
+            ReplaceAllUsesWith(I, TATA);
             EraseFromParent(&I);
             EraseFromParent(TA);
           }
@@ -740,8 +753,7 @@ class LowerMatrixIntrinsics {
             NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
                                                    K->getZExtValue(),
                                                    R->getZExtValue(), "mmul");
-            setShapeInfo(NewInst, {C, R});
-            I.replaceAllUsesWith(NewInst);
+            ReplaceAllUsesWith(I, NewInst);
             EraseFromParent(&I);
             EraseFromParent(TA);
           }
@@ -774,8 +786,7 @@ class LowerMatrixIntrinsics {
           setShapeInfo(M, {C, R});
           Value *NewInst = Builder.CreateMatrixTranspose(M, R->getZExtValue(),
                                                          C->getZExtValue());
-          setShapeInfo(NewInst, {C, R});
-          I->replaceAllUsesWith(NewInst);
+          ReplaceAllUsesWith(*I, NewInst);
           if (I->use_empty())
             I->eraseFromParent();
           if (A->use_empty())
@@ -879,10 +890,30 @@ class LowerMatrixIntrinsics {
 
     // Delete the instructions backwards, as it has a reduced likelihood of
     // having to update as many def-use and use-def chains.
+    //
+    // Because we add to ToRemove during fusion we can't guarantee that defs
+    // are before uses.  Change uses to undef temporarily as these should get
+    // removed as well.
+    //
+    // For verification, we keep track of where we changed uses to undefs in
+    // UndefedInsts and then check that we in fact remove them.
+    SmallSet<Instruction *, 16> UndefedInsts;
     for (auto *Inst : reverse(ToRemove)) {
-      if (!Inst->use_empty())
-        Inst->replaceAllUsesWith(UndefValue::get(Inst->getType()));
+      for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
+        Use &U = *I++;
+        if (auto *Undefed = dyn_cast<Instruction>(U.getUser()))
+          UndefedInsts.insert(Undefed);
+        U.set(UndefValue::get(Inst->getType()));
+      }
       Inst->eraseFromParent();
+      UndefedInsts.erase(Inst);
+    }
+    if (!UndefedInsts.empty()) {
+      // If we didn't remove all undefed instructions, it's a hard error.
+      dbgs() << "Undefed but present instructions:\n";
+      for (auto *I : UndefedInsts)
+        dbgs() << *I << "\n";
+      llvm_unreachable("Undefed but instruction not removed");
     }
 
     return Changed;

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
index 25c27a960b4ba..2a7ac8278d5e2 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
@@ -986,6 +986,32 @@ entry:
   ret <4 x float> %m
 }
 
+define <6 x double> @transpose_of_transpose_of_non_matrix_op(double* %a) {
+; CHECK-LABEL: @transpose_of_transpose_of_non_matrix_op(
+; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[A:%.*]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[A]], i64 4
+; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, double* [[A]], i64 8
+; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[A]], i64 12
+; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x double> [[COL_LOAD5]], <2 x double> [[COL_LOAD8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[SHUF:%.*]] = shufflevector <8 x double> [[TMP3]], <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
+; CHECK-NEXT:    ret <6 x double> [[SHUF]]
+;
+  %load = call <8 x double> @llvm.matrix.column.major.load.v8f64(double* %a, i64 4, i1 false, i32 2, i32 4)
+  %shuf = shufflevector <8 x double> %load, <8 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 4, i32 5, i32 6>
+  %t = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %shuf, i32 3, i32 2)
+  %tt = call <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double> %t, i32 2, i32 3)
+  ret <6 x double> %tt
+}
+
 declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
 declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
 declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)
@@ -995,3 +1021,4 @@ declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
 declare <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double>, i32, i32)
 declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32)
 declare <4 x float> @llvm.matrix.transpose.v4f32(<4 x float>, i32, i32)
+declare <8 x double> @llvm.matrix.column.major.load.v8f64(double*, i64, i1, i32, i32)


        


More information about the llvm-commits mailing list