OLD | NEW |
| (Empty) |
1 /* | |
2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. | |
3 * | |
4 * Use of this source code is governed by a BSD-style license | |
5 * that can be found in the LICENSE file in the root of the source | |
6 * tree. An additional intellectual property rights grant can be found | |
7 * in the file PATENTS. All contributing project authors may | |
8 * be found in the AUTHORS file in the root of the source tree. | |
9 */ | |
10 | |
11 #include "webrtc/base/win32.h" | |
12 #define SECURITY_WIN32 | |
13 #include <security.h> | |
14 #include <schannel.h> | |
15 | |
16 #include <algorithm> | |
17 #include <iomanip> | |
18 #include <vector> | |
19 | |
20 #include "webrtc/base/common.h" | |
21 #include "webrtc/base/logging.h" | |
22 #include "webrtc/base/schanneladapter.h" | |
23 #include "webrtc/base/sec_buffer.h" | |
24 #include "webrtc/base/thread.h" | |
25 | |
26 namespace rtc { | |
27 | |
28 ///////////////////////////////////////////////////////////////////////////// | |
29 // SChannelAdapter | |
30 ///////////////////////////////////////////////////////////////////////////// | |
31 | |
32 extern const ConstantLabel SECURITY_ERRORS[]; | |
33 | |
34 const ConstantLabel SCHANNEL_BUFFER_TYPES[] = { | |
35 KLABEL(SECBUFFER_EMPTY), // 0 | |
36 KLABEL(SECBUFFER_DATA), // 1 | |
37 KLABEL(SECBUFFER_TOKEN), // 2 | |
38 KLABEL(SECBUFFER_PKG_PARAMS), // 3 | |
39 KLABEL(SECBUFFER_MISSING), // 4 | |
40 KLABEL(SECBUFFER_EXTRA), // 5 | |
41 KLABEL(SECBUFFER_STREAM_TRAILER), // 6 | |
42 KLABEL(SECBUFFER_STREAM_HEADER), // 7 | |
43 KLABEL(SECBUFFER_MECHLIST), // 11 | |
44 KLABEL(SECBUFFER_MECHLIST_SIGNATURE), // 12 | |
45 KLABEL(SECBUFFER_TARGET), // 13 | |
46 KLABEL(SECBUFFER_CHANNEL_BINDINGS), // 14 | |
47 LASTLABEL | |
48 }; | |
49 | |
50 void DescribeBuffer(LoggingSeverity severity, const char* prefix, | |
51 const SecBuffer& sb) { | |
52 LOG_V(severity) | |
53 << prefix | |
54 << "(" << sb.cbBuffer | |
55 << ", " << FindLabel(sb.BufferType & ~SECBUFFER_ATTRMASK, | |
56 SCHANNEL_BUFFER_TYPES) | |
57 << ", " << sb.pvBuffer << ")"; | |
58 } | |
59 | |
60 void DescribeBuffers(LoggingSeverity severity, const char* prefix, | |
61 const SecBufferDesc* sbd) { | |
62 if (!LOG_CHECK_LEVEL_V(severity)) | |
63 return; | |
64 LOG_V(severity) << prefix << "("; | |
65 for (size_t i=0; i<sbd->cBuffers; ++i) { | |
66 DescribeBuffer(severity, " ", sbd->pBuffers[i]); | |
67 } | |
68 LOG_V(severity) << ")"; | |
69 } | |
70 | |
71 const ULONG SSL_FLAGS_DEFAULT = ISC_REQ_ALLOCATE_MEMORY | |
72 | ISC_REQ_CONFIDENTIALITY | |
73 | ISC_REQ_EXTENDED_ERROR | |
74 | ISC_REQ_INTEGRITY | |
75 | ISC_REQ_REPLAY_DETECT | |
76 | ISC_REQ_SEQUENCE_DETECT | |
77 | ISC_REQ_STREAM; | |
78 //| ISC_REQ_USE_SUPPLIED_CREDS; | |
79 | |
80 typedef std::vector<char> SChannelBuffer; | |
81 | |
82 struct SChannelAdapter::SSLImpl { | |
83 CredHandle cred; | |
84 CtxtHandle ctx; | |
85 bool cred_init, ctx_init; | |
86 SChannelBuffer inbuf, outbuf, readable; | |
87 SecPkgContext_StreamSizes sizes; | |
88 | |
89 SSLImpl() : cred_init(false), ctx_init(false) { } | |
90 }; | |
91 | |
92 SChannelAdapter::SChannelAdapter(AsyncSocket* socket) | |
93 : SSLAdapter(socket), state_(SSL_NONE), mode_(SSL_MODE_TLS), | |
94 restartable_(false), signal_close_(false), message_pending_(false), | |
95 impl_(new SSLImpl) { | |
96 } | |
97 | |
98 SChannelAdapter::~SChannelAdapter() { | |
99 Cleanup(); | |
100 } | |
101 | |
102 void | |
103 SChannelAdapter::SetMode(SSLMode mode) { | |
104 // SSL_MODE_DTLS isn't supported. | |
105 ASSERT(mode == SSL_MODE_TLS); | |
106 mode_ = mode; | |
107 } | |
108 | |
109 int | |
110 SChannelAdapter::StartSSL(const char* hostname, bool restartable) { | |
111 if (state_ != SSL_NONE) | |
112 return -1; | |
113 | |
114 if (mode_ != SSL_MODE_TLS) | |
115 return -1; | |
116 | |
117 ssl_host_name_ = hostname; | |
118 restartable_ = restartable; | |
119 | |
120 if (socket_->GetState() != Socket::CS_CONNECTED) { | |
121 state_ = SSL_WAIT; | |
122 return 0; | |
123 } | |
124 | |
125 state_ = SSL_CONNECTING; | |
126 if (int err = BeginSSL()) { | |
127 Error("BeginSSL", err, false); | |
128 return err; | |
129 } | |
130 | |
131 return 0; | |
132 } | |
133 | |
134 int | |
135 SChannelAdapter::BeginSSL() { | |
136 LOG(LS_VERBOSE) << "BeginSSL: " << ssl_host_name_; | |
137 ASSERT(state_ == SSL_CONNECTING); | |
138 | |
139 SECURITY_STATUS ret; | |
140 | |
141 SCHANNEL_CRED sc_cred = { 0 }; | |
142 sc_cred.dwVersion = SCHANNEL_CRED_VERSION; | |
143 //sc_cred.dwMinimumCipherStrength = 128; // Note: use system default | |
144 sc_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_AUTO_CRED_VALIDATION; | |
145 | |
146 ret = AcquireCredentialsHandle(NULL, const_cast<LPTSTR>(UNISP_NAME), | |
147 SECPKG_CRED_OUTBOUND, NULL, &sc_cred, NULL, | |
148 NULL, &impl_->cred, NULL); | |
149 if (ret != SEC_E_OK) { | |
150 LOG(LS_ERROR) << "AcquireCredentialsHandle error: " | |
151 << ErrorName(ret, SECURITY_ERRORS); | |
152 return ret; | |
153 } | |
154 impl_->cred_init = true; | |
155 | |
156 if (LOG_CHECK_LEVEL(LS_VERBOSE)) { | |
157 SecPkgCred_CipherStrengths cipher_strengths = { 0 }; | |
158 ret = QueryCredentialsAttributes(&impl_->cred, | |
159 SECPKG_ATTR_CIPHER_STRENGTHS, | |
160 &cipher_strengths); | |
161 if (SUCCEEDED(ret)) { | |
162 LOG(LS_VERBOSE) << "SChannel cipher strength: " | |
163 << cipher_strengths.dwMinimumCipherStrength << " - " | |
164 << cipher_strengths.dwMaximumCipherStrength; | |
165 } | |
166 | |
167 SecPkgCred_SupportedAlgs supported_algs = { 0 }; | |
168 ret = QueryCredentialsAttributes(&impl_->cred, | |
169 SECPKG_ATTR_SUPPORTED_ALGS, | |
170 &supported_algs); | |
171 if (SUCCEEDED(ret)) { | |
172 LOG(LS_VERBOSE) << "SChannel supported algorithms:"; | |
173 for (DWORD i=0; i<supported_algs.cSupportedAlgs; ++i) { | |
174 ALG_ID alg_id = supported_algs.palgSupportedAlgs[i]; | |
175 PCCRYPT_OID_INFO oinfo = CryptFindOIDInfo(CRYPT_OID_INFO_ALGID_KEY, | |
176 &alg_id, 0); | |
177 LPCWSTR alg_name = (NULL != oinfo) ? oinfo->pwszName : L"Unknown"; | |
178 LOG(LS_VERBOSE) << " " << ToUtf8(alg_name) << " (" << alg_id << ")"; | |
179 } | |
180 CSecBufferBase::FreeSSPI(supported_algs.palgSupportedAlgs); | |
181 } | |
182 } | |
183 | |
184 ULONG flags = SSL_FLAGS_DEFAULT, ret_flags = 0; | |
185 if (ignore_bad_cert()) | |
186 flags |= ISC_REQ_MANUAL_CRED_VALIDATION; | |
187 | |
188 CSecBufferBundle<2, CSecBufferBase::FreeSSPI> sb_out; | |
189 ret = InitializeSecurityContextA(&impl_->cred, NULL, | |
190 const_cast<char*>(ssl_host_name_.c_str()), | |
191 flags, 0, 0, NULL, 0, | |
192 &impl_->ctx, sb_out.desc(), | |
193 &ret_flags, NULL); | |
194 if (SUCCEEDED(ret)) | |
195 impl_->ctx_init = true; | |
196 return ProcessContext(ret, NULL, sb_out.desc()); | |
197 } | |
198 | |
199 int | |
200 SChannelAdapter::ContinueSSL() { | |
201 LOG(LS_VERBOSE) << "ContinueSSL"; | |
202 ASSERT(state_ == SSL_CONNECTING); | |
203 | |
204 SECURITY_STATUS ret; | |
205 | |
206 CSecBufferBundle<2> sb_in; | |
207 sb_in[0].BufferType = SECBUFFER_TOKEN; | |
208 sb_in[0].cbBuffer = static_cast<unsigned long>(impl_->inbuf.size()); | |
209 sb_in[0].pvBuffer = &impl_->inbuf[0]; | |
210 //DescribeBuffers(LS_VERBOSE, "Input Buffer ", sb_in.desc()); | |
211 | |
212 ULONG flags = SSL_FLAGS_DEFAULT, ret_flags = 0; | |
213 if (ignore_bad_cert()) | |
214 flags |= ISC_REQ_MANUAL_CRED_VALIDATION; | |
215 | |
216 CSecBufferBundle<2, CSecBufferBase::FreeSSPI> sb_out; | |
217 ret = InitializeSecurityContextA(&impl_->cred, &impl_->ctx, | |
218 const_cast<char*>(ssl_host_name_.c_str()), | |
219 flags, 0, 0, sb_in.desc(), 0, | |
220 NULL, sb_out.desc(), | |
221 &ret_flags, NULL); | |
222 return ProcessContext(ret, sb_in.desc(), sb_out.desc()); | |
223 } | |
224 | |
225 int | |
226 SChannelAdapter::ProcessContext(long int status, _SecBufferDesc* sbd_in, | |
227 _SecBufferDesc* sbd_out) { | |
228 if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED && | |
229 status != SEC_E_INCOMPLETE_MESSAGE) { | |
230 LOG(LS_ERROR) | |
231 << "InitializeSecurityContext error: " | |
232 << ErrorName(status, SECURITY_ERRORS); | |
233 } | |
234 //if (sbd_in) | |
235 // DescribeBuffers(LS_VERBOSE, "Input Buffer ", sbd_in); | |
236 //if (sbd_out) | |
237 // DescribeBuffers(LS_VERBOSE, "Output Buffer ", sbd_out); | |
238 | |
239 if (status == SEC_E_INCOMPLETE_MESSAGE) { | |
240 // Wait for more input from server. | |
241 return Flush(); | |
242 } | |
243 | |
244 if (FAILED(status)) { | |
245 // We can't continue. Common errors: | |
246 // SEC_E_CERT_EXPIRED - Typically, this means the computer clock is wrong. | |
247 return status; | |
248 } | |
249 | |
250 // Note: we check both input and output buffers for SECBUFFER_EXTRA. | |
251 // Experience shows it appearing in the input, but the documentation claims | |
252 // it should appear in the output. | |
253 size_t extra = 0; | |
254 if (sbd_in) { | |
255 for (size_t i=0; i<sbd_in->cBuffers; ++i) { | |
256 SecBuffer& buffer = sbd_in->pBuffers[i]; | |
257 if (buffer.BufferType == SECBUFFER_EXTRA) { | |
258 extra += buffer.cbBuffer; | |
259 } | |
260 } | |
261 } | |
262 if (sbd_out) { | |
263 for (size_t i=0; i<sbd_out->cBuffers; ++i) { | |
264 SecBuffer& buffer = sbd_out->pBuffers[i]; | |
265 if (buffer.BufferType == SECBUFFER_EXTRA) { | |
266 extra += buffer.cbBuffer; | |
267 } else if (buffer.BufferType == SECBUFFER_TOKEN) { | |
268 impl_->outbuf.insert(impl_->outbuf.end(), | |
269 reinterpret_cast<char*>(buffer.pvBuffer), | |
270 reinterpret_cast<char*>(buffer.pvBuffer) + buffer.cbBuffer); | |
271 } | |
272 } | |
273 } | |
274 | |
275 if (extra) { | |
276 ASSERT(extra <= impl_->inbuf.size()); | |
277 size_t consumed = impl_->inbuf.size() - extra; | |
278 memmove(&impl_->inbuf[0], &impl_->inbuf[consumed], extra); | |
279 impl_->inbuf.resize(extra); | |
280 } else { | |
281 impl_->inbuf.clear(); | |
282 } | |
283 | |
284 if (SEC_I_CONTINUE_NEEDED == status) { | |
285 // Send data to server and wait for response. | |
286 // Note: ContinueSSL will result in a Flush, anyway. | |
287 return impl_->inbuf.empty() ? Flush() : ContinueSSL(); | |
288 } | |
289 | |
290 if (SEC_E_OK == status) { | |
291 LOG(LS_VERBOSE) << "QueryContextAttributes"; | |
292 status = QueryContextAttributes(&impl_->ctx, SECPKG_ATTR_STREAM_SIZES, | |
293 &impl_->sizes); | |
294 if (FAILED(status)) { | |
295 LOG(LS_ERROR) << "QueryContextAttributes error: " | |
296 << ErrorName(status, SECURITY_ERRORS); | |
297 return status; | |
298 } | |
299 | |
300 state_ = SSL_CONNECTED; | |
301 | |
302 if (int err = DecryptData()) { | |
303 return err; | |
304 } else if (int err = Flush()) { | |
305 return err; | |
306 } else { | |
307 // If we decrypted any data, queue up a notification here | |
308 PostEvent(); | |
309 // Signal our connectedness | |
310 AsyncSocketAdapter::OnConnectEvent(this); | |
311 } | |
312 return 0; | |
313 } | |
314 | |
315 if (SEC_I_INCOMPLETE_CREDENTIALS == status) { | |
316 // We don't support client authentication in schannel. | |
317 return status; | |
318 } | |
319 | |
320 // We don't expect any other codes | |
321 ASSERT(false); | |
322 return status; | |
323 } | |
324 | |
325 int | |
326 SChannelAdapter::DecryptData() { | |
327 SChannelBuffer& inbuf = impl_->inbuf; | |
328 SChannelBuffer& readable = impl_->readable; | |
329 | |
330 while (!inbuf.empty()) { | |
331 CSecBufferBundle<4> in_buf; | |
332 in_buf[0].BufferType = SECBUFFER_DATA; | |
333 in_buf[0].cbBuffer = static_cast<unsigned long>(inbuf.size()); | |
334 in_buf[0].pvBuffer = &inbuf[0]; | |
335 | |
336 //DescribeBuffers(LS_VERBOSE, "Decrypt In ", in_buf.desc()); | |
337 SECURITY_STATUS status = DecryptMessage(&impl_->ctx, in_buf.desc(), 0, 0); | |
338 //DescribeBuffers(LS_VERBOSE, "Decrypt Out ", in_buf.desc()); | |
339 | |
340 // Note: We are explicitly treating SEC_E_OK, SEC_I_CONTEXT_EXPIRED, and | |
341 // any other successful results as continue. | |
342 if (SUCCEEDED(status)) { | |
343 size_t data_len = 0, extra_len = 0; | |
344 for (size_t i=0; i<in_buf.desc()->cBuffers; ++i) { | |
345 if (in_buf[i].BufferType == SECBUFFER_DATA) { | |
346 data_len += in_buf[i].cbBuffer; | |
347 readable.insert(readable.end(), | |
348 reinterpret_cast<char*>(in_buf[i].pvBuffer), | |
349 reinterpret_cast<char*>(in_buf[i].pvBuffer) + in_buf[i].cbBuffer); | |
350 } else if (in_buf[i].BufferType == SECBUFFER_EXTRA) { | |
351 extra_len += in_buf[i].cbBuffer; | |
352 } | |
353 } | |
354 // There is a bug on Win2K where SEC_I_CONTEXT_EXPIRED is misclassified. | |
355 if ((data_len == 0) && (inbuf[0] == 0x15)) { | |
356 status = SEC_I_CONTEXT_EXPIRED; | |
357 } | |
358 if (extra_len) { | |
359 size_t consumed = inbuf.size() - extra_len; | |
360 memmove(&inbuf[0], &inbuf[consumed], extra_len); | |
361 inbuf.resize(extra_len); | |
362 } else { | |
363 inbuf.clear(); | |
364 } | |
365 // TODO: Handle SEC_I_CONTEXT_EXPIRED to do clean shutdown | |
366 if (status != SEC_E_OK) { | |
367 LOG(LS_INFO) << "DecryptMessage returned continuation code: " | |
368 << ErrorName(status, SECURITY_ERRORS); | |
369 } | |
370 continue; | |
371 } | |
372 | |
373 if (status == SEC_E_INCOMPLETE_MESSAGE) { | |
374 break; | |
375 } else { | |
376 return status; | |
377 } | |
378 } | |
379 | |
380 return 0; | |
381 } | |
382 | |
383 void | |
384 SChannelAdapter::Cleanup() { | |
385 if (impl_->ctx_init) | |
386 DeleteSecurityContext(&impl_->ctx); | |
387 if (impl_->cred_init) | |
388 FreeCredentialsHandle(&impl_->cred); | |
389 delete impl_; | |
390 } | |
391 | |
392 void | |
393 SChannelAdapter::PostEvent() { | |
394 // Check if there's anything notable to signal | |
395 if (impl_->readable.empty() && !signal_close_) | |
396 return; | |
397 | |
398 // Only one post in the queue at a time | |
399 if (message_pending_) | |
400 return; | |
401 | |
402 if (Thread* thread = Thread::Current()) { | |
403 message_pending_ = true; | |
404 thread->Post(this); | |
405 } else { | |
406 LOG(LS_ERROR) << "No thread context available for SChannelAdapter"; | |
407 ASSERT(false); | |
408 } | |
409 } | |
410 | |
411 void | |
412 SChannelAdapter::Error(const char* context, int err, bool signal) { | |
413 LOG(LS_WARNING) << "SChannelAdapter::Error(" | |
414 << context << ", " | |
415 << ErrorName(err, SECURITY_ERRORS) << ")"; | |
416 state_ = SSL_ERROR; | |
417 SetError(err); | |
418 if (signal) | |
419 AsyncSocketAdapter::OnCloseEvent(this, err); | |
420 } | |
421 | |
422 int | |
423 SChannelAdapter::Read() { | |
424 char buffer[4096]; | |
425 SChannelBuffer& inbuf = impl_->inbuf; | |
426 while (true) { | |
427 int ret = AsyncSocketAdapter::Recv(buffer, sizeof(buffer)); | |
428 if (ret > 0) { | |
429 inbuf.insert(inbuf.end(), buffer, buffer + ret); | |
430 } else if (GetError() == EWOULDBLOCK) { | |
431 return 0; // Blocking | |
432 } else { | |
433 return GetError(); | |
434 } | |
435 } | |
436 } | |
437 | |
438 int | |
439 SChannelAdapter::Flush() { | |
440 int result = 0; | |
441 size_t pos = 0; | |
442 SChannelBuffer& outbuf = impl_->outbuf; | |
443 while (pos < outbuf.size()) { | |
444 int sent = AsyncSocketAdapter::Send(&outbuf[pos], outbuf.size() - pos); | |
445 if (sent > 0) { | |
446 pos += sent; | |
447 } else if (GetError() == EWOULDBLOCK) { | |
448 break; // Blocking | |
449 } else { | |
450 result = GetError(); | |
451 break; | |
452 } | |
453 } | |
454 if (int remainder = static_cast<int>(outbuf.size() - pos)) { | |
455 memmove(&outbuf[0], &outbuf[pos], remainder); | |
456 outbuf.resize(remainder); | |
457 } else { | |
458 outbuf.clear(); | |
459 } | |
460 return result; | |
461 } | |
462 | |
463 // | |
464 // AsyncSocket Implementation | |
465 // | |
466 | |
467 int | |
468 SChannelAdapter::Send(const void* pv, size_t cb) { | |
469 switch (state_) { | |
470 case SSL_NONE: | |
471 return AsyncSocketAdapter::Send(pv, cb); | |
472 | |
473 case SSL_WAIT: | |
474 case SSL_CONNECTING: | |
475 SetError(EWOULDBLOCK); | |
476 return SOCKET_ERROR; | |
477 | |
478 case SSL_CONNECTED: | |
479 break; | |
480 | |
481 case SSL_ERROR: | |
482 default: | |
483 return SOCKET_ERROR; | |
484 } | |
485 | |
486 size_t written = 0; | |
487 SChannelBuffer& outbuf = impl_->outbuf; | |
488 while (written < cb) { | |
489 const size_t encrypt_len = std::min<size_t>(cb - written, | |
490 impl_->sizes.cbMaximumMessage); | |
491 | |
492 CSecBufferBundle<4> out_buf; | |
493 out_buf[0].BufferType = SECBUFFER_STREAM_HEADER; | |
494 out_buf[0].cbBuffer = impl_->sizes.cbHeader; | |
495 out_buf[1].BufferType = SECBUFFER_DATA; | |
496 out_buf[1].cbBuffer = static_cast<unsigned long>(encrypt_len); | |
497 out_buf[2].BufferType = SECBUFFER_STREAM_TRAILER; | |
498 out_buf[2].cbBuffer = impl_->sizes.cbTrailer; | |
499 | |
500 size_t packet_len = out_buf[0].cbBuffer | |
501 + out_buf[1].cbBuffer | |
502 + out_buf[2].cbBuffer; | |
503 | |
504 SChannelBuffer message; | |
505 message.resize(packet_len); | |
506 out_buf[0].pvBuffer = &message[0]; | |
507 out_buf[1].pvBuffer = &message[out_buf[0].cbBuffer]; | |
508 out_buf[2].pvBuffer = &message[out_buf[0].cbBuffer + out_buf[1].cbBuffer]; | |
509 | |
510 memcpy(out_buf[1].pvBuffer, | |
511 static_cast<const char*>(pv) + written, | |
512 encrypt_len); | |
513 | |
514 //DescribeBuffers(LS_VERBOSE, "Encrypt In ", out_buf.desc()); | |
515 SECURITY_STATUS res = EncryptMessage(&impl_->ctx, 0, out_buf.desc(), 0); | |
516 //DescribeBuffers(LS_VERBOSE, "Encrypt Out ", out_buf.desc()); | |
517 | |
518 if (FAILED(res)) { | |
519 Error("EncryptMessage", res, false); | |
520 return SOCKET_ERROR; | |
521 } | |
522 | |
523 // We assume that the header and data segments do not change length, | |
524 // or else encrypting the concatenated packet in-place is wrong. | |
525 ASSERT(out_buf[0].cbBuffer == impl_->sizes.cbHeader); | |
526 ASSERT(out_buf[1].cbBuffer == static_cast<unsigned long>(encrypt_len)); | |
527 | |
528 // However, the length of the trailer may change due to padding. | |
529 ASSERT(out_buf[2].cbBuffer <= impl_->sizes.cbTrailer); | |
530 | |
531 packet_len = out_buf[0].cbBuffer | |
532 + out_buf[1].cbBuffer | |
533 + out_buf[2].cbBuffer; | |
534 | |
535 written += encrypt_len; | |
536 outbuf.insert(outbuf.end(), &message[0], &message[packet_len-1]+1); | |
537 } | |
538 | |
539 if (int err = Flush()) { | |
540 state_ = SSL_ERROR; | |
541 SetError(err); | |
542 return SOCKET_ERROR; | |
543 } | |
544 | |
545 return static_cast<int>(written); | |
546 } | |
547 | |
548 int | |
549 SChannelAdapter::Recv(void* pv, size_t cb) { | |
550 switch (state_) { | |
551 case SSL_NONE: | |
552 return AsyncSocketAdapter::Recv(pv, cb); | |
553 | |
554 case SSL_WAIT: | |
555 case SSL_CONNECTING: | |
556 SetError(EWOULDBLOCK); | |
557 return SOCKET_ERROR; | |
558 | |
559 case SSL_CONNECTED: | |
560 break; | |
561 | |
562 case SSL_ERROR: | |
563 default: | |
564 return SOCKET_ERROR; | |
565 } | |
566 | |
567 SChannelBuffer& readable = impl_->readable; | |
568 if (readable.empty()) { | |
569 SetError(EWOULDBLOCK); | |
570 return SOCKET_ERROR; | |
571 } | |
572 size_t read = std::min(cb, readable.size()); | |
573 memcpy(pv, &readable[0], read); | |
574 if (size_t remaining = readable.size() - read) { | |
575 memmove(&readable[0], &readable[read], remaining); | |
576 readable.resize(remaining); | |
577 } else { | |
578 readable.clear(); | |
579 } | |
580 | |
581 PostEvent(); | |
582 return static_cast<int>(read); | |
583 } | |
584 | |
585 int | |
586 SChannelAdapter::Close() { | |
587 if (!impl_->readable.empty()) { | |
588 LOG(WARNING) << "SChannelAdapter::Close with readable data"; | |
589 // Note: this isn't strictly an error, but we're using it temporarily to | |
590 // track bugs. | |
591 //ASSERT(false); | |
592 } | |
593 if (state_ == SSL_CONNECTED) { | |
594 DWORD token = SCHANNEL_SHUTDOWN; | |
595 CSecBufferBundle<1> sb_in; | |
596 sb_in[0].BufferType = SECBUFFER_TOKEN; | |
597 sb_in[0].cbBuffer = sizeof(token); | |
598 sb_in[0].pvBuffer = &token; | |
599 ApplyControlToken(&impl_->ctx, sb_in.desc()); | |
600 // TODO: In theory, to do a nice shutdown, we need to begin shutdown | |
601 // negotiation with more calls to InitializeSecurityContext. Since the | |
602 // socket api doesn't support nice shutdown at this point, we don't bother. | |
603 } | |
604 Cleanup(); | |
605 impl_ = new SSLImpl; | |
606 state_ = restartable_ ? SSL_WAIT : SSL_NONE; | |
607 signal_close_ = false; | |
608 message_pending_ = false; | |
609 return AsyncSocketAdapter::Close(); | |
610 } | |
611 | |
612 Socket::ConnState | |
613 SChannelAdapter::GetState() const { | |
614 if (signal_close_) | |
615 return CS_CONNECTED; | |
616 ConnState state = socket_->GetState(); | |
617 if ((state == CS_CONNECTED) | |
618 && ((state_ == SSL_WAIT) || (state_ == SSL_CONNECTING))) | |
619 state = CS_CONNECTING; | |
620 return state; | |
621 } | |
622 | |
623 void | |
624 SChannelAdapter::OnConnectEvent(AsyncSocket* socket) { | |
625 LOG(LS_VERBOSE) << "SChannelAdapter::OnConnectEvent"; | |
626 if (state_ != SSL_WAIT) { | |
627 ASSERT(state_ == SSL_NONE); | |
628 AsyncSocketAdapter::OnConnectEvent(socket); | |
629 return; | |
630 } | |
631 | |
632 state_ = SSL_CONNECTING; | |
633 if (int err = BeginSSL()) { | |
634 Error("BeginSSL", err); | |
635 } | |
636 } | |
637 | |
638 void | |
639 SChannelAdapter::OnReadEvent(AsyncSocket* socket) { | |
640 if (state_ == SSL_NONE) { | |
641 AsyncSocketAdapter::OnReadEvent(socket); | |
642 return; | |
643 } | |
644 | |
645 if (int err = Read()) { | |
646 Error("Read", err); | |
647 return; | |
648 } | |
649 | |
650 if (impl_->inbuf.empty()) | |
651 return; | |
652 | |
653 if (state_ == SSL_CONNECTED) { | |
654 if (int err = DecryptData()) { | |
655 Error("DecryptData", err); | |
656 } else if (!impl_->readable.empty()) { | |
657 AsyncSocketAdapter::OnReadEvent(this); | |
658 } | |
659 } else if (state_ == SSL_CONNECTING) { | |
660 if (int err = ContinueSSL()) { | |
661 Error("ContinueSSL", err); | |
662 } | |
663 } | |
664 } | |
665 | |
666 void | |
667 SChannelAdapter::OnWriteEvent(AsyncSocket* socket) { | |
668 if (state_ == SSL_NONE) { | |
669 AsyncSocketAdapter::OnWriteEvent(socket); | |
670 return; | |
671 } | |
672 | |
673 if (int err = Flush()) { | |
674 Error("Flush", err); | |
675 return; | |
676 } | |
677 | |
678 // See if we have more data to write | |
679 if (!impl_->outbuf.empty()) | |
680 return; | |
681 | |
682 // Buffer is empty, submit notification | |
683 if (state_ == SSL_CONNECTED) { | |
684 AsyncSocketAdapter::OnWriteEvent(socket); | |
685 } | |
686 } | |
687 | |
688 void | |
689 SChannelAdapter::OnCloseEvent(AsyncSocket* socket, int err) { | |
690 if ((state_ == SSL_NONE) || impl_->readable.empty()) { | |
691 AsyncSocketAdapter::OnCloseEvent(socket, err); | |
692 return; | |
693 } | |
694 | |
695 // If readable is non-empty, then we have a pending Message | |
696 // that will allow us to signal close (eventually). | |
697 signal_close_ = true; | |
698 } | |
699 | |
700 void | |
701 SChannelAdapter::OnMessage(Message* pmsg) { | |
702 if (!message_pending_) | |
703 return; // This occurs when socket is closed | |
704 | |
705 message_pending_ = false; | |
706 if (!impl_->readable.empty()) { | |
707 AsyncSocketAdapter::OnReadEvent(this); | |
708 } else if (signal_close_) { | |
709 signal_close_ = false; | |
710 AsyncSocketAdapter::OnCloseEvent(this, 0); // TODO: cache this error? | |
711 } | |
712 } | |
713 | |
714 } // namespace rtc | |
OLD | NEW |