Index: webrtc/base/natsocketfactory.cc |
diff --git a/webrtc/base/natsocketfactory.cc b/webrtc/base/natsocketfactory.cc |
index d240527e4348fe2f74660de82d09375e96fb8c3c..95e64872199db3bf059614fcb8f8d7733e571e31 100644 |
--- a/webrtc/base/natsocketfactory.cc |
+++ b/webrtc/base/natsocketfactory.cc |
@@ -97,29 +97,20 @@ class NATSocket : public AsyncSocket, public sigslot::has_slots<> { |
return -1; |
} |
- int result; |
- socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); |
- result = (socket_) ? socket_->Bind(addr) : -1; |
- if (result >= 0) { |
- socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); |
- socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); |
- socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); |
- socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); |
- } else { |
- server_addr_.Clear(); |
- delete socket_; |
- socket_ = nullptr; |
- } |
- |
- return result; |
+ return BindInternal(addr); |
} |
int Connect(const SocketAddress& addr) override { |
- if (!socket_) { // socket must be bound, for now |
- return -1; |
+ int result = 0; |
+ // If we're not already bound (meaning |socket_| is null), bind to ANY |
+ // address. |
+ if (!socket_) { |
+ result = BindInternal(SocketAddress(GetAnyIP(family_), 0)); |
+ if (result < 0) { |
+ return result; |
+ } |
} |
- int result = 0; |
if (type_ == SOCK_STREAM) { |
result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_); |
} else { |
@@ -225,8 +216,16 @@ class NATSocket : public AsyncSocket, public sigslot::has_slots<> { |
AsyncSocket* Accept(SocketAddress* paddr) override { |
return socket_->Accept(paddr); |
} |
- int GetError() const override { return socket_->GetError(); } |
- void SetError(int error) override { socket_->SetError(error); } |
+ int GetError() const override { |
+ return socket_ ? socket_->GetError() : error_; |
+ } |
+ void SetError(int error) override { |
+ if (socket_) { |
+ socket_->SetError(error); |
+ } else { |
+ error_ = error; |
+ } |
+ } |
ConnState GetState() const override { |
return connected_ ? CS_CONNECTED : CS_CLOSED; |
} |
@@ -266,6 +265,26 @@ class NATSocket : public AsyncSocket, public sigslot::has_slots<> { |
} |
private: |
+ int BindInternal(const SocketAddress& addr) { |
+ RTC_DCHECK(!socket_); |
+ |
+ int result; |
+ socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); |
+ result = (socket_) ? socket_->Bind(addr) : -1; |
+ if (result >= 0) { |
+ socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); |
+ socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); |
+ socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); |
+ socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); |
+ } else { |
+ server_addr_.Clear(); |
+ delete socket_; |
+ socket_ = nullptr; |
+ } |
+ |
+ return result; |
+ } |
+ |
// Makes sure the buffer is at least the given size. |
void Grow(size_t new_size) { |
if (size_ < new_size) { |
@@ -302,6 +321,8 @@ class NATSocket : public AsyncSocket, public sigslot::has_slots<> { |
SocketAddress remote_addr_; |
SocketAddress server_addr_; // address of the NAT server |
AsyncSocket* socket_; |
+ // Need to hold error in case it occurs before the socket is created. |
+ int error_ = 0; |
char* buf_; |
size_t size_; |
}; |