From 2f0eaa3d26099ebc9edd5e62db87722dac626692 Mon Sep 17 00:00:00 2001
From: Kirithika <kirithika@multicorewareinc.com>
Date: Tue, 20 Dec 2022 14:47:29 +0530
Subject: [PATCH] Readjust reference frames and NAL type based on temporal
 layer

---
 source/common/frame.cpp    |   1 +
 source/common/frame.h      |   2 +
 source/encoder/api.cpp     |   4 +-
 source/encoder/dpb.cpp     | 167 +++++++++++++++++++++++++++++++++++--
 source/encoder/dpb.h       |   9 +-
 source/encoder/encoder.cpp |   2 +
 source/x265.h              |   2 +-
 7 files changed, 174 insertions(+), 13 deletions(-)

diff --git a/source/common/frame.cpp b/source/common/frame.cpp
index 48c538714..768d69f34 100644
--- a/source/common/frame.cpp
+++ b/source/common/frame.cpp
@@ -75,6 +75,7 @@ Frame::Frame()
     m_prevMCSTF = NULL;
 
     m_tempLayer = 0;
+    m_sameLayerRefPic = false;
 }
 
 bool Frame::create(x265_param *param, float* quantOffsets)
diff --git a/source/common/frame.h b/source/common/frame.h
index c916f7714..fcd0031bc 100644
--- a/source/common/frame.h
+++ b/source/common/frame.h
@@ -163,6 +163,8 @@ public:
     /*Frame's temporal layer info*/
     uint8_t                m_tempLayer;
     int8_t                 m_gopId;
+    bool                   m_sameLayerRefPic;
+
     Frame();
 
     bool create(x265_param *param, float* quantOffsets);
diff --git a/source/encoder/api.cpp b/source/encoder/api.cpp
index 5f1bba67b..9b53e62ae 100644
--- a/source/encoder/api.cpp
+++ b/source/encoder/api.cpp
@@ -1297,7 +1297,7 @@ FILE* x265_csvlog_open(const x265_param* param)
             if (param->csvLogLevel)
             {
                 fprintf(csvfp, "Encode Order, Type, POC, QP, Bits, Scenecut, ");
-                if (param->bEnableTemporalSubLayers > 2)
+                if (!!param->bEnableTemporalSubLayers)
                     fprintf(csvfp, "Temporal Sub Layer ID, ");
                 if (param->csvLogLevel >= 2)
                     fprintf(csvfp, "I/P cost ratio, ");
@@ -1412,7 +1412,7 @@ void x265_csvlog_frame(const x265_param* param, const x265_picture* pic)
     const x265_frame_stats* frameStats = &pic->frameData;
     fprintf(param->csvfpt, "%d, %c-SLICE, %4d, %2.2lf, %10d, %d,", frameStats->encoderOrder, frameStats->sliceType, frameStats->poc,
                                                                    frameStats->qp, (int)frameStats->bits, frameStats->bScenecut);
-    if (param->bEnableTemporalSubLayers > 2)
+    if (!!param->bEnableTemporalSubLayers)
         fprintf(param->csvfpt, "%d,", frameStats->tLayer);
     if (param->csvLogLevel >= 2)
         fprintf(param->csvfpt, "%.2f,", frameStats->ipCostRatio);
diff --git a/source/encoder/dpb.cpp b/source/encoder/dpb.cpp
index bfe6f2290..24d3cd202 100644
--- a/source/encoder/dpb.cpp
+++ b/source/encoder/dpb.cpp
@@ -150,12 +150,13 @@ void DPB::prepareEncode(Frame *newFrame)
     {
         newFrame->m_encData->m_bHasReferences = false;
 
+        newFrame->m_tempLayer = (newFrame->m_param->bEnableTemporalSubLayers && !m_bTemporalSublayer) ? 1 : newFrame->m_tempLayer;
         // Adjust NAL type for unreferenced B frames (change from _R "referenced"
         // to _N "non-referenced" NAL unit type)
         switch (slice->m_nalUnitType)
         {
         case NAL_UNIT_CODED_SLICE_TRAIL_R:
-            slice->m_nalUnitType = m_bTemporalSublayer ? NAL_UNIT_CODED_SLICE_TSA_N : NAL_UNIT_CODED_SLICE_TRAIL_N;
+            slice->m_nalUnitType = newFrame->m_param->bEnableTemporalSubLayers ? NAL_UNIT_CODED_SLICE_TSA_N : NAL_UNIT_CODED_SLICE_TRAIL_N;
             break;
         case NAL_UNIT_CODED_SLICE_RADL_R:
             slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_RADL_N;
@@ -176,13 +177,94 @@ void DPB::prepareEncode(Frame *newFrame)
 
     m_picList.pushFront(*newFrame);
 
+    if (m_bTemporalSublayer && getTemporalLayerNonReferenceFlag())
+    {
+        switch (slice->m_nalUnitType)
+        {
+        case NAL_UNIT_CODED_SLICE_TRAIL_R:
+            slice->m_nalUnitType =  NAL_UNIT_CODED_SLICE_TRAIL_N;
+            break;
+        case NAL_UNIT_CODED_SLICE_RADL_R:
+            slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_RADL_N;
+            break;
+        case NAL_UNIT_CODED_SLICE_RASL_R:
+            slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_RASL_N;
+            break;
+        default:
+            break;
+        }
+    }
     // Do decoding refresh marking if any
     decodingRefreshMarking(pocCurr, slice->m_nalUnitType);
 
-    computeRPS(pocCurr, slice->isIRAP(), &slice->m_rps, slice->m_sps->maxDecPicBuffering[newFrame->m_tempLayer]);
-
+    computeRPS(pocCurr, newFrame->m_tempLayer, slice->isIRAP(), &slice->m_rps, slice->m_sps->maxDecPicBuffering[newFrame->m_tempLayer]);
+    bool isTSAPic = (slice->m_nalUnitType == NAL_UNIT_CODED_SLICE_TSA_N) || (slice->m_nalUnitType == NAL_UNIT_CODED_SLICE_TSA_R);
     // Mark pictures in m_piclist as unreferenced if they are not included in RPS
-    applyReferencePictureSet(&slice->m_rps, pocCurr);
+    applyReferencePictureSet(&slice->m_rps, pocCurr, newFrame->m_tempLayer, isTSAPic);
+
+
+    if (m_bTemporalSublayer && newFrame->m_tempLayer > 0
+        && !(slice->m_nalUnitType == NAL_UNIT_CODED_SLICE_RADL_N     // Check if not a leading picture
+            || slice->m_nalUnitType == NAL_UNIT_CODED_SLICE_RADL_R
+            || slice->m_nalUnitType == NAL_UNIT_CODED_SLICE_RASL_N
+            || slice->m_nalUnitType == NAL_UNIT_CODED_SLICE_RASL_R)
+        )
+    {
+        if (isTemporalLayerSwitchingPoint(pocCurr, newFrame->m_tempLayer) || (slice->m_sps->maxTempSubLayers == 1))
+        {
+            if (getTemporalLayerNonReferenceFlag())
+            {
+                slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_TSA_N;
+            }
+            else
+            {
+                slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_TSA_R;
+            }
+        }
+        else if (isStepwiseTemporalLayerSwitchingPoint(&slice->m_rps, pocCurr, newFrame->m_tempLayer))
+        {
+            bool isSTSA = true;
+            int id = newFrame->m_gopOffset % x265_gop_ra_length[newFrame->m_gopId];
+            for (int ii = id; (ii < x265_gop_ra_length[newFrame->m_gopId] && isSTSA == true); ii++)
+            {
+                int tempIdRef = x265_gop_ra[newFrame->m_gopId][ii].layer;
+                if (tempIdRef == newFrame->m_tempLayer)
+                {
+                    for (int jj = 0; jj < slice->m_rps.numberOfPositivePictures + slice->m_rps.numberOfNegativePictures; jj++)
+                    {
+                        if (slice->m_rps.bUsed[jj])
+                        {
+                            int refPoc = x265_gop_ra[newFrame->m_gopId][ii].poc_offset + slice->m_rps.deltaPOC[jj];
+                            int kk = 0;
+                            for (kk = 0; kk < x265_gop_ra_length[newFrame->m_gopId]; kk++)
+                            {
+                                if (x265_gop_ra[newFrame->m_gopId][kk].poc_offset == refPoc)
+                                {
+                                    break;
+                                }
+                            }
+                            if (x265_gop_ra[newFrame->m_gopId][kk].layer >= newFrame->m_tempLayer)
+                            {
+                                isSTSA = false;
+                                break;
+                            }
+                        }
+                    }
+                }
+            }
+            if (isSTSA == true)
+            {
+                if (getTemporalLayerNonReferenceFlag())
+                {
+                    slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_STSA_N;
+                }
+                else
+                {
+                    slice->m_nalUnitType = NAL_UNIT_CODED_SLICE_STSA_R;
+                }
+            }
+        }
+    }
 
     if (slice->m_sliceType != I_SLICE)
         slice->m_numRefIdx[0] = x265_clip3(1, newFrame->m_param->maxNumReferences, slice->m_rps.numberOfNegativePictures);
@@ -226,7 +308,7 @@ void DPB::prepareEncode(Frame *newFrame)
     }
 }
 
-void DPB::computeRPS(int curPoc, bool isRAP, RPS * rps, unsigned int maxDecPicBuffer)
+void DPB::computeRPS(int curPoc, int tempId, bool isRAP, RPS * rps, unsigned int maxDecPicBuffer)
 {
     unsigned int poci = 0, numNeg = 0, numPos = 0;
 
@@ -236,7 +318,7 @@ void DPB::computeRPS(int curPoc, bool isRAP, RPS * rps, unsigned int maxDecPicBu
     {
         if ((iterPic->m_poc != curPoc) && iterPic->m_encData->m_bHasReferences)
         {
-            if ((m_lastIDR >= curPoc) || (m_lastIDR <= iterPic->m_poc))
+            if ((!m_bTemporalSublayer || (iterPic->m_tempLayer <= tempId)) && ((m_lastIDR >= curPoc) || (m_lastIDR <= iterPic->m_poc)))
             {
                     rps->poc[poci] = iterPic->m_poc;
                     rps->deltaPOC[poci] = rps->poc[poci] - curPoc;
@@ -255,6 +337,18 @@ void DPB::computeRPS(int curPoc, bool isRAP, RPS * rps, unsigned int maxDecPicBu
     rps->sortDeltaPOC();
 }
 
+/* Returns true when m_bHasReferences is clear on the frame at the head of m_picList; otherwise sets its m_sameLayerRefPic and returns false. */
+bool DPB::getTemporalLayerNonReferenceFlag()
+{
+    Frame* headPic = m_picList.first();
+    const bool isNonReference = !headPic->m_encData->m_bHasReferences;
+
+    // Still referenced: record that it may be used as a same-layer reference picture
+    if (!isNonReference)
+        headPic->m_sameLayerRefPic = true;
+    return isNonReference;
+}
+
 /* Marking reference pictures when an IDR/CRA is encountered. */
 void DPB::decodingRefreshMarking(int pocCurr, NalUnitType nalUnitType)
 {
@@ -304,7 +398,7 @@ void DPB::decodingRefreshMarking(int pocCurr, NalUnitType nalUnitType)
 }
 
 /** Function for applying picture marking based on the Reference Picture Set */
-void DPB::applyReferencePictureSet(RPS *rps, int curPoc)
+void DPB::applyReferencePictureSet(RPS *rps, int curPoc, int tempId, bool isTSAPicture)
 {
     // loop through all pictures in the reference picture buffer
     Frame* iterFrame = m_picList.first();
@@ -325,9 +419,68 @@ void DPB::applyReferencePictureSet(RPS *rps, int curPoc)
             }
             if (!referenced)
                 iterFrame->m_encData->m_bHasReferences = false;
+
+            if (m_bTemporalSublayer)
+            {
+                //check that pictures of higher temporal layers are not used
+                assert(referenced == 0 || iterFrame->m_encData->m_bHasReferences == false || iterFrame->m_tempLayer <= tempId);
+
+                //check that pictures of higher or equal temporal layer are not in the RPS if the current picture is a TSA picture
+                if (isTSAPicture)
+                {
+                    assert(referenced == 0 || iterFrame->m_tempLayer < tempId);
+                }
+                //check that pictures marked as temporal layer non-reference pictures are not used for reference
+                if (iterFrame->m_tempLayer == tempId)
+                {
+                    assert(referenced == 0 || iterFrame->m_sameLayerRefPic == true);
+                }
+            }
+        }
+        iterFrame = iterFrame->m_next;
+    }
+}
+
+/* Reports whether the picture with POC curPoc is a temporal layer switching point:
+ * true when no other picture in m_picList that is still flagged as referenced
+ * has a temporal layer greater than or equal to tempId. */
+bool DPB::isTemporalLayerSwitchingPoint(int curPoc, int tempId)
+{
+    for (Frame* pic = m_picList.first(); pic; pic = pic->m_next)
+    {
+        // the current picture itself never blocks a switching point
+        if (pic->m_poc == curPoc)
+            continue;
+
+        const bool stillReferenced = pic->m_encData->m_bHasReferences;
+        if (stillReferenced && pic->m_tempLayer >= tempId)
+            return false;
+    }
+    return true;
+}
+
+bool DPB::isStepwiseTemporalLayerSwitchingPoint(RPS *rps, int curPoc, int tempId)
+{
+    // Walk m_picList: any other still-referenced picture that is a used RPS entry at layer >= tempId rules out a stepwise switching point
+    Frame* iterFrame = m_picList.first();
+    while (iterFrame)
+    {
+        if (iterFrame->m_poc != curPoc && iterFrame->m_encData->m_bHasReferences)
+        {
+            for (int i = 0; i < rps->numberOfPositivePictures + rps->numberOfNegativePictures; i++)
+            {
+                if ((iterFrame->m_poc == curPoc + rps->deltaPOC[i]) && rps->bUsed[i])
+                {
+                    if (iterFrame->m_tempLayer >= tempId)
+                    {
+                        return false;
+                    }
+                }
+            }
         }
         iterFrame = iterFrame->m_next;
     }
+    return true;
 }
 
 /* deciding the nal_unit_type */
diff --git a/source/encoder/dpb.h b/source/encoder/dpb.h
index e47d54d61..2cc7df778 100644
--- a/source/encoder/dpb.h
+++ b/source/encoder/dpb.h
@@ -66,7 +66,7 @@ public:
         m_bRefreshPending = false;
         m_frameDataFreeList = NULL;
         m_bOpenGOP = param->bOpenGOP;
-        m_bTemporalSublayer = !!param->bEnableTemporalSubLayers;
+        m_bTemporalSublayer = (param->bEnableTemporalSubLayers > 2);
     }
 
     ~DPB();
@@ -77,10 +77,13 @@ public:
 
 protected:
 
-    void computeRPS(int curPoc, bool isRAP, RPS * rps, unsigned int maxDecPicBuffer);
+    void computeRPS(int curPoc,int tempId, bool isRAP, RPS * rps, unsigned int maxDecPicBuffer);
 
-    void applyReferencePictureSet(RPS *rps, int curPoc);
+    void applyReferencePictureSet(RPS *rps, int curPoc, int tempId, bool isTSAPicture);
+    bool getTemporalLayerNonReferenceFlag();
     void decodingRefreshMarking(int pocCurr, NalUnitType nalUnitType);
+    bool isTemporalLayerSwitchingPoint(int curPoc, int tempId);
+    bool isStepwiseTemporalLayerSwitchingPoint(RPS *rps, int curPoc, int tempId);
 
     NalUnitType getNalUnitType(int curPoc, bool bIsKeyFrame);
 };
diff --git a/source/encoder/encoder.cpp b/source/encoder/encoder.cpp
index 51068a875..64a4e231c 100644
--- a/source/encoder/encoder.cpp
+++ b/source/encoder/encoder.cpp
@@ -1637,6 +1637,8 @@ int Encoder::encode(const x265_picture* pic_in, x265_picture* pic_out)
             inFrame->m_lowres.satdCost = (int64_t)-1;
             inFrame->m_lowresInit = false;
             inFrame->m_isInsideWindow = 0;
+            inFrame->m_tempLayer = 0;
+            inFrame->m_sameLayerRefPic = 0;
         }
 
         /* Copy input picture into a Frame and PicYuv, send to lookahead */
diff --git a/source/x265.h b/source/x265.h
index 86d324d10..81df146e8 100644
--- a/source/x265.h
+++ b/source/x265.h
@@ -60,7 +60,7 @@ typedef enum
     NAL_UNIT_CODED_SLICE_TRAIL_N = 0,
     NAL_UNIT_CODED_SLICE_TRAIL_R,
     NAL_UNIT_CODED_SLICE_TSA_N,
-    NAL_UNIT_CODED_SLICE_TLA_R,
+    NAL_UNIT_CODED_SLICE_TSA_R,
     NAL_UNIT_CODED_SLICE_STSA_N,
     NAL_UNIT_CODED_SLICE_STSA_R,
     NAL_UNIT_CODED_SLICE_RADL_N,
-- 
2.28.0.windows.1

