Index: webrtc/call/rtp_demuxer.cc |
diff --git a/webrtc/call/rtp_demuxer.cc b/webrtc/call/rtp_demuxer.cc |
index 916e612c17ca394573a451bb8b1a9081f5fed660..b2c9510ccabc3905074afd64263487079b2c04eb 100644 |
--- a/webrtc/call/rtp_demuxer.cc |
+++ b/webrtc/call/rtp_demuxer.cc |
@@ -11,12 +11,15 @@ |
#include "webrtc/base/checks.h" |
#include "webrtc/call/rtp_demuxer.h" |
#include "webrtc/call/rtp_packet_sink_interface.h" |
+#include "webrtc/modules/rtp_rtcp/source/rtp_header_extensions.h" |
#include "webrtc/modules/rtp_rtcp/source/rtp_packet_received.h" |
namespace webrtc { |
namespace { |
+constexpr size_t kMaxProcessedSsrcs = 1000; // Prevent memory overuse. |
+ |
template <typename Key, typename Value> |
bool MultimapAssociationExists(const std::multimap<Key, Value>& multimap, |
Key key, |
@@ -27,6 +30,21 @@ bool MultimapAssociationExists(const std::multimap<Key, Value>& multimap, |
[val](Reference elem) { return elem.second == val; }); |
} |
+template <typename Key, typename Value> |
+size_t RemoveFromMultimapByValue(std::multimap<Key, Value*>* multimap, |
+ const Value* value) { |
+ size_t count = 0; |
+ for (auto it = multimap->begin(); it != multimap->end();) { |
+ if (it->second == value) { |
+ it = multimap->erase(it); |
+ ++count; |
+ } else { |
+ ++it; |
+ } |
+ } |
+ return count; |
+} |
+ |
} // namespace |
RtpDemuxer::RtpDemuxer() {} |
@@ -36,33 +54,71 @@ RtpDemuxer::~RtpDemuxer() { |
} |
void RtpDemuxer::AddSink(uint32_t ssrc, RtpPacketSinkInterface* sink) { |
+ RecordSsrcToSinkAssociation(ssrc, sink); |
+} |
+ |
+void RtpDemuxer::AddSink(const std::string& rsid, |
+ RtpPacketSinkInterface* sink) { |
+ RTC_DCHECK(StreamId::IsLegalName(rsid)); |
RTC_DCHECK(sink); |
- RTC_DCHECK(!MultimapAssociationExists(sinks_, ssrc, sink)); |
- sinks_.emplace(ssrc, sink); |
+ RTC_DCHECK(!MultimapAssociationExists(rsid_sinks_, rsid, sink)); |
+ |
+ rsid_sinks_.emplace(rsid, sink); |
+ |
+ // This RSID might now map to an SSRC which we saw earlier. |
+ processed_ssrcs_.clear(); |
} |
-size_t RtpDemuxer::RemoveSink(const RtpPacketSinkInterface* sink) { |
+bool RtpDemuxer::RemoveSink(const RtpPacketSinkInterface* sink) { |
RTC_DCHECK(sink); |
- size_t count = 0; |
- for (auto it = sinks_.begin(); it != sinks_.end(); ) { |
- if (it->second == sink) { |
- it = sinks_.erase(it); |
- ++count; |
- } else { |
- ++it; |
- } |
+ return (RemoveFromMultimapByValue(&sinks_, sink) + |
+ RemoveFromMultimapByValue(&rsid_sinks_, sink)) > 0; |
+} |
+ |
+void RtpDemuxer::RecordSsrcToSinkAssociation(uint32_t ssrc, |
+ RtpPacketSinkInterface* sink) { |
+ RTC_DCHECK(sink); |
+ // The association might already have been set by a different |
+ // configuration source. |
+ if (!MultimapAssociationExists(sinks_, ssrc, sink)) { |
+ sinks_.emplace(ssrc, sink); |
} |
- return count; |
} |
bool RtpDemuxer::OnRtpPacket(const RtpPacketReceived& packet) { |
- bool found = false; |
+ FindSsrcAssociations(packet); |
auto it_range = sinks_.equal_range(packet.Ssrc()); |
for (auto it = it_range.first; it != it_range.second; ++it) { |
- found = true; |
it->second->OnRtpPacket(packet); |
} |
- return found; |
+ return it_range.first != it_range.second; |
+} |
+ |
+void RtpDemuxer::FindSsrcAssociations(const RtpPacketReceived& packet) { |
+ // Avoid expensive string comparisons for RSID by looking the sinks up only |
+ // by SSRC whenever possible. |
+ if (processed_ssrcs_.find(packet.Ssrc()) != processed_ssrcs_.cend()) { |
+ return; |
+ } |
+ |
+ // RSID-based associations: |
+ std::string rsid; |
+ if (packet.GetExtension<RtpStreamId>(&rsid)) { |
+ // All streams associated with this RSID need to be marked as associated |
+ // with this SSRC (if they aren't already). |
+ auto it_range = rsid_sinks_.equal_range(rsid); |
+ for (auto it = it_range.first; it != it_range.second; ++it) { |
+ RecordSsrcToSinkAssociation(packet.Ssrc(), it->second); |
+ } |
+ |
+ // To prevent memory-overuse attacks, forget this RSID. Future packets |
+ // with this RSID, but a different SSRC, will not spawn new associations. |
+ rsid_sinks_.erase(it_range.first, it_range.second); |
+ } |
+ |
+ if (processed_ssrcs_.size() < kMaxProcessedSsrcs) { // Prevent memory overuse |
+ processed_ssrcs_.insert(packet.Ssrc()); // Avoid re-examining in-depth. |
+ } |
} |
} // namespace webrtc |