[llvm] dc2c9b0 - [Matrix] Propagate and use shape info for binary operators.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 27 07:51:17 PST 2019


Author: Florian Hahn
Date: 2019-12-27T15:50:47Z
New Revision: dc2c9b0fcf28d1d3d6f19ec15cb29dbbf1f43f9d

URL: https://github.com/llvm/llvm-project/commit/dc2c9b0fcf28d1d3d6f19ec15cb29dbbf1f43f9d
DIFF: https://github.com/llvm/llvm-project/commit/dc2c9b0fcf28d1d3d6f19ec15cb29dbbf1f43f9d.diff

LOG: [Matrix] Propagate and use shape info for binary operators.

This patch extends the current shape propagation and shape aware
lowering to also support binary operators. Those operators are uniform
with respect to their shape (shape of the input operands is the same as
the shape of their result).

Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D70898

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
    llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
    llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index d03b55756d3c..a9566422a832 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -281,6 +281,24 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  bool isUniformShape(Value *V) {
+    Instruction *I = dyn_cast<Instruction>(V);
+    if (!I)
+      return true; // Values that are not instructions count as uniform.
+
+    switch (I->getOpcode()) {
+    case Instruction::FAdd:
+    case Instruction::FSub:
+    case Instruction::FMul: // Scalar multiply.
+    case Instruction::Add:
+    case Instruction::Mul:
+    case Instruction::Sub:
+      return true; // Binary operators that are lowered column by column.
+    default:
+      return false; // Any other instruction is not treated as uniform.
+    }
+  }
+
   /// Returns true if shape information can be used for \p V. The supported
   /// instructions must match the instructions that can be lowered by this pass.
   bool supportsShapeInfo(Value *V) {
@@ -299,7 +317,7 @@ class LowerMatrixIntrinsics {
       default:
         return false;
       }
-    return isa<StoreInst>(Inst);
+    return isUniformShape(V) || isa<StoreInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -366,6 +384,15 @@ class LowerMatrixIntrinsics {
         if (OpShape != ShapeMap.end())
           setShapeInfo(Inst, OpShape->second);
         continue;
+      } else if (isUniformShape(Inst)) {
+        // Find the first operand that has a known shape and use that.
+        for (auto &Op : Inst->operands()) {
+          auto OpShape = ShapeMap.find(Op.get());
+          if (OpShape != ShapeMap.end()) {
+            Propagate |= setShapeInfo(Inst, OpShape->second);
+            break;
+          }
+        }
       }
 
       if (Propagate)
@@ -390,7 +417,9 @@ class LowerMatrixIntrinsics {
 
         Value *Op1;
         Value *Op2;
-        if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
+        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
+          Changed |= VisitBinaryOperator(BinOp);
+        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
           Changed |= VisitStore(&Inst, Op1, Op2, Builder);
       }
     }
@@ -673,6 +702,49 @@ class LowerMatrixIntrinsics {
     LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
     return true;
   }
+
+  /// Lower \p Inst column-wise if its shape is known; returns true if lowered.
+  bool VisitBinaryOperator(BinaryOperator *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false; // No shape recorded for Inst, nothing is changed.
+
+    Value *Lhs = Inst->getOperand(0);
+    Value *Rhs = Inst->getOperand(1);
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second; // Shape of Inst, also used for both operands.
+
+    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
+    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
+
+    // Apply the opcode of Inst to each pair of columns and collect the results.
+    ColumnMatrixTy Result;
+    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
+      switch (Inst->getOpcode()) {
+      case Instruction::Add:
+        return Builder.CreateAdd(LHS, RHS);
+      case Instruction::Mul:
+        return Builder.CreateMul(LHS, RHS);
+      case Instruction::Sub:
+        return Builder.CreateSub(LHS, RHS);
+      case Instruction::FAdd:
+        return Builder.CreateFAdd(LHS, RHS);
+      case Instruction::FMul:
+        return Builder.CreateFMul(LHS, RHS);
+      case Instruction::FSub:
+        return Builder.CreateFSub(LHS, RHS);
+      default:
+        llvm_unreachable("Unsupported binary operator for matrix");
+      }
+    };
+    for (unsigned C = 0; C < Shape.NumColumns; ++C)
+      Result.addColumn(
+          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
+
+    finalizeLowering(Inst, Result, Builder);
+    return true;
+  }
 };
 } // namespace
 

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
index 2edc3700dd3d..246f4e42c2fa 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/bigger-expressions-double.ll
@@ -462,15 +462,34 @@ define void @transpose_multiply_add(<9 x double>* %A.Ptr, <9 x double>* %B.Ptr,
 
 ; CHECK-NEXT:    [[TMP106:%.*]] = shufflevector <1 x double> [[TMP105]], <1 x double> undef, <3 x i32> <i32 0, i32 undef, i32 undef>
 ; CHECK-NEXT:    [[TMP107:%.*]] = shufflevector <3 x double> [[TMP97]], <3 x double> [[TMP106]], <3 x i32> <i32 0, i32 1, i32 3>
-; CHECK-NEXT:    [[TMP108:%.*]] = shufflevector <3 x double> [[TMP47]], <3 x double> [[TMP77]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP109:%.*]] = shufflevector <3 x double> [[TMP107]], <3 x double> undef, <6 x i32> <i32 0, i32 1, i32 2, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP110:%.*]] = shufflevector <6 x double> [[TMP108]], <6 x double> [[TMP109]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
 
-; Load %C and add result of multiply.
+; Load %C.
 
 ; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, <9 x double>* [[C_PTR:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = fadd <9 x double> [[C]], [[TMP110]]
-; CHECK-NEXT:    store <9 x double> [[RES]], <9 x double>* [[C_PTR]]
+
+; Extract columns from %C.
+
+; CHECK-NEXT:    [[SPLIT84:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[SPLIT85:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT86:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 6, i32 7, i32 8>
+
+; Add column vectors.
+
+; CHECK-NEXT:    [[TMP108:%.*]] = fadd <3 x double> [[SPLIT84]], [[TMP47]]
+; CHECK-NEXT:    [[TMP109:%.*]] = fadd <3 x double> [[SPLIT85]], [[TMP77]]
+; CHECK-NEXT:    [[TMP110:%.*]] = fadd <3 x double> [[SPLIT86]], [[TMP107]]
+
+; Store result columns.
+
+; CHECK-NEXT:    [[TMP111:%.*]] = bitcast <9 x double>* [[C_PTR]] to double*
+; CHECK-NEXT:    [[TMP112:%.*]] = bitcast double* [[TMP111]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[TMP108]], <3 x double>* [[TMP112]], align 8
+; CHECK-NEXT:    [[TMP113:%.*]] = getelementptr double, double* [[TMP111]], i32 3
+; CHECK-NEXT:    [[TMP114:%.*]] = bitcast double* [[TMP113]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[TMP109]], <3 x double>* [[TMP114]], align 8
+; CHECK-NEXT:    [[TMP115:%.*]] = getelementptr double, double* [[TMP111]], i32 6
+; CHECK-NEXT:    [[TMP116:%.*]] = bitcast double* [[TMP115]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[TMP110]], <3 x double>* [[TMP116]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
index 3c398b319c01..1092aa2837ca 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
@@ -42,3 +42,75 @@ entry:
 }
 
 declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)
+
+define <8 x double> @transpose_fadd(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fadd(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP16:%.*]] = fadd <4 x double> [[TMP7]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = fadd <4 x double> [[TMP15]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP18]]
+;
+entry:
+  %c  = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  %res = fadd <8 x double> %c, %a
+  ret <8 x double> %res
+}
+
+define <8 x double> @transpose_fmul(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP16:%.*]] = fmul <4 x double> [[TMP7]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = fmul <4 x double> [[TMP15]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x double> [[TMP18]]
+;
+entry:
+  %c  = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  %res = fmul <8 x double> %c, %a
+  ret <8 x double> %res
+}


        


More information about the llvm-commits mailing list