[llvm] 0cc38ac - [Matrix] Propagate shape information through fneg

Francis Visoiu Mistrih via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 22 14:35:13 PST 2021


Author: Francis Visoiu Mistrih
Date: 2021-01-22T14:34:28-08:00
New Revision: 0cc38acfc4e1dcdc2a9b6287bc93eef57acfd105

URL: https://github.com/llvm/llvm-project/commit/0cc38acfc4e1dcdc2a9b6287bc93eef57acfd105
DIFF: https://github.com/llvm/llvm-project/commit/0cc38acfc4e1dcdc2a9b6287bc93eef57acfd105.diff

LOG: [Matrix] Propagate shape information through fneg

Similar to binary operators like fadd/fmul/fsub, propagate shape info
through unary operators (fneg is currently the only one in LLVM IR).

Differential Revision: https://reviews.llvm.org/D95252

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
    llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index aa9894ca32b3..812922c49cfa 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -488,6 +488,7 @@ class LowerMatrixIntrinsics {
     case Instruction::FAdd:
     case Instruction::FSub:
     case Instruction::FMul: // Scalar multiply.
+    case Instruction::FNeg:
     case Instruction::Add:
     case Instruction::Mul:
     case Instruction::Sub:
@@ -724,6 +725,8 @@ class LowerMatrixIntrinsics {
       Value *Op2;
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
         Changed |= VisitBinaryOperator(BinOp);
+      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+        Changed |= VisitUnaryOperator(UnOp);
       if (match(Inst, m_Load(m_Value(Op1))))
         Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -1499,6 +1502,40 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Lower \p Inst vector-wise if it has shape info; returns true if lowered.
+  bool VisitUnaryOperator(UnaryOperator *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    Value *Op = Inst->getOperand(0);
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second;
+
+    MatrixTy Result;
+    MatrixTy M = getMatrix(Op, Shape, Builder);
+
+    // Emit Inst's opcode on one vector; FNeg is the only opcode handled.
+    auto BuildVectorOp = [&Builder, Inst](Value *Op) {
+      switch (Inst->getOpcode()) {
+      case Instruction::FNeg:
+        return Builder.CreateFNeg(Op);
+      default:
+        llvm_unreachable("Unsupported unary operator for matrix");
+      }
+    };
+
+    for (unsigned I = 0; I < Shape.getNumVectors(); ++I) // shadows iterator I
+      Result.addVector(BuildVectorOp(M.getVector(I)));
+
+    finalizeLowering(Inst,
+                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                             Result.getNumVectors()),
+                     Builder);
+    return true;
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linarized by starting at an expression leaf and
   /// linearizing bottom up.

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
index 584f47d8530b..0860d7adb608 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -93,4 +93,48 @@ entry:
   ret <8 x double> %c
 }
 
+define <8 x double> @load_fneg_transpose(<8 x double>* %A.Ptr) {
+; CHECK-LABEL: @load_fneg_transpose(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x double>* [[A_PTR:%.*]] to double*
+; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast double* [[TMP0]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 2
+; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, double* [[TMP0]], i64 4
+; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 6
+; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x double> [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x double> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP3:%.*]] = fneg <2 x double> [[COL_LOAD5]]
+; CHECK-NEXT:    [[TMP4:%.*]] = fneg <2 x double> [[COL_LOAD8]]
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x double> undef, double [[TMP5]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x double> [[TMP6]], double [[TMP7]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <4 x double> [[TMP8]], double [[TMP9]], i64 2
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x double> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <4 x double> [[TMP10]], double [[TMP11]], i64 3
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <4 x double> undef, double [[TMP13]], i64 0
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <4 x double> [[TMP14]], double [[TMP15]], i64 1
+; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP18:%.*]] = insertelement <4 x double> [[TMP16]], double [[TMP17]], i64 2
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <2 x double> [[TMP4]], i64 1
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x double> [[TMP18]], double [[TMP19]], i64 3
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <4 x double> [[TMP12]], <4 x double> [[TMP20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP21]]
+;
+entry:
+  %a = load <8 x double>, <8 x double>* %A.Ptr, align 8
+  %neg = fneg <8 x double> %a
+  %c  = call <8 x double> @llvm.matrix.transpose(<8 x double> %neg, i32 2, i32 4)
+  ret <8 x double> %c
+}
 declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
index 51796f3d1b60..6095957a6559 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
@@ -114,3 +114,37 @@ entry:
   %res = fmul <8 x double> %c, %a
   ret <8 x double> %res
 }
+
+define <8 x double> @transpose_fneg(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fneg(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[TMP16:%.*]] = fneg <4 x double> [[TMP7]]
+; CHECK-NEXT:    [[TMP17:%.*]] = fneg <4 x double> [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP18]]
+;
+entry:
+  %c  = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  %res = fneg <8 x double> %c
+  ret <8 x double> %res
+}


        


More information about the llvm-commits mailing list