Commit 704421b7 by ziyue

完善

parent 581ebfad
...@@ -49,4 +49,3 @@ else() ...@@ -49,4 +49,3 @@ else()
endif() endif()
target_link_libraries(MediaServer jsoncpp ${LINK_LIB_LIST}) target_link_libraries(MediaServer jsoncpp ${LINK_LIB_LIST})
message(${LINK_LIB_LIST})
//
// Created by xueyuegui on 19-12-7.
//
#include "dtls_transport.h"
#include <iostream>
DtlsTransport::DtlsTransport(bool is_server) : is_server_(is_server) {
dtls_transport_.reset(new RTC::DtlsTransport(this));
}
DtlsTransport::~DtlsTransport() {}
void DtlsTransport::Start() {
if (is_server_) {
dtls_transport_->Run(RTC::DtlsTransport::Role::SERVER);
} else {
dtls_transport_->Run(RTC::DtlsTransport::Role::CLIENT);
}
}
void DtlsTransport::Close() {}
void DtlsTransport::OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) {}
void DtlsTransport::OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport,
RTC::CryptoSuite srtp_crypto_suite,
uint8_t *srtpLocalKey, size_t srtpLocalKeyLen,
uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen,
std::string &remoteCert) {
std::string client_key;
std::string server_key;
server_key.assign((char *) srtpLocalKey, srtpLocalKeyLen);
client_key.assign((char *) srtpRemoteKey, srtpRemoteKeyLen);
if (is_server_) {
// If we are server, we swap the keys
client_key.swap(server_key);
}
if (handshake_completed_callback_) {
handshake_completed_callback_(client_key, server_key, srtp_crypto_suite);
}
}
void DtlsTransport::OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) {
if (handshake_failed_callback_) {
handshake_failed_callback_();
}
}
void DtlsTransport::OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) {}
void DtlsTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport,
const uint8_t *data, size_t len) {
if (output_callback_) {
output_callback_((char *) data, len);
}
}
void DtlsTransport::OutputData(char *buf, size_t len) {
if (output_callback_) {
output_callback_(buf, len);
}
}
void DtlsTransport::OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport,
const uint8_t *data, size_t len) {}
bool DtlsTransport::IsDtlsPacket(const char *buf, size_t len) {
return RTC::DtlsTransport::IsDtls((uint8_t *) buf, len);
}
void DtlsTransport::InputData(char *buf, size_t len) {
dtls_transport_->ProcessDtlsData((uint8_t *) buf, len);
}
//
// Created by xueyuegui on 19-12-7.
//
#ifndef MYWEBRTC_MYDTLSTRANSPORT_H
#define MYWEBRTC_MYDTLSTRANSPORT_H
#include <functional>
#include <memory>
#include "rtc_dtls_transport.h"
class DtlsTransport : RTC::DtlsTransport::Listener {
public:
typedef std::shared_ptr<DtlsTransport> Ptr;
DtlsTransport(bool bServer);
~DtlsTransport();
void Start();
void Close();
void InputData(char *buf, size_t len);
void OutputData(char *buf, size_t len);
static bool IsDtlsPacket(const char *buf, size_t len);
std::string GetMyFingerprint() {
auto finger_prints = dtls_transport_->GetLocalFingerprints();
for (size_t i = 0; i < finger_prints.size(); i++) {
if (finger_prints[i].algorithm == RTC::DtlsTransport::FingerprintAlgorithm::SHA256) {
return finger_prints[i].value;
}
}
return "";
};
void SetHandshakeCompletedCB(std::function<void(std::string clientKey, std::string serverKey, RTC::CryptoSuite)> cb) {
handshake_completed_callback_ = std::move(cb);
}
void SetHandshakeFailedCB(std::function<void()> cb) { handshake_failed_callback_ = std::move(cb); }
void SetOutPutCB(std::function<void(char *buf, size_t len)> cb) { output_callback_ = std::move(cb); }
/* Pure virtual methods inherited from RTC::DtlsTransport::Listener. */
public:
void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override;
void OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, RTC::CryptoSuite srtpCryptoSuite, uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) override;
void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override;
void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override;
void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data,size_t len) override;
void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override;
private:
bool is_server_ = false;
std::function<void()> handshake_failed_callback_;
std::shared_ptr<RTC::DtlsTransport> dtls_transport_;
std::function<void(char *buf, size_t len)> output_callback_;
std::function<void(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite)> handshake_completed_callback_;
};
#endif// MYWEBRTC_MYDTLSTRANSPORT_H
#define MS_CLASS "RTC::IceServer"
// #define MS_LOG_DEV_LEVEL 3
#include <utility>
#include "ice_server.h" #include "ice_server.h"
#include <iostream> namespace RTC
{
static constexpr size_t StunSerializeBufferSize{65536}; /* Static. */
static uint8_t StunSerializeBuffer[StunSerializeBufferSize];
static constexpr size_t StunSerializeBufferSize{ 65536 };
IceServer::IceServer() {} static uint8_t StunSerializeBuffer[StunSerializeBufferSize];
IceServer::~IceServer() {} /* Instance methods. */
IceServer::IceServer(const std::string &username_fragment, const std::string &password) IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password)
: username_fragment_(username_fragment), password_(password) {} : listener(listener), usernameFragment(usernameFragment), password(password)
{
void IceServer::ProcessStunPacket(RTC::StunPacket *packet, sockaddr_in *remote_address) { MS_TRACE();
// Must be a Binding method. }
if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) {
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) { void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple)
ELOG_WARN("unknown method %#.3x in STUN Request => 400", {
static_cast<unsigned int>(packet->GetMethod())); MS_TRACE();
ELOG_WARN("unknown method %#.3x in STUN Request => 400",
static_cast<unsigned int>(packet->GetMethod())); // Must be a Binding method.
// Reply 400. if (packet->GetMethod() != RTC::StunPacket::Method::BINDING)
RTC::StunPacket *response = packet->CreateErrorResponse(400); {
response->Serialize(StunSerializeBuffer); if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
if (send_callback_) { {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); MS_WARN_TAG(
} ice,
delete response; "unknown method %#.3x in STUN Request => 400",
} else { static_cast<unsigned int>(packet->GetMethod()));
ELOG_WARN("ignoring STUN Indication or Response with unknown method %#.3x",
static_cast<unsigned int>(packet->GetMethod())); // Reply 400.
} RTC::StunPacket* response = packet->CreateErrorResponse(400);
return;
} response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
// Must use FINGERPRINT (optional for ICE STUN indications).
if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) { delete response;
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) { }
ELOG_WARN("STUN Binding Request without FINGERPRINT => 400"); else
// Reply 400. {
RTC::StunPacket *response = packet->CreateErrorResponse(400); MS_WARN_TAG(
response->Serialize(StunSerializeBuffer); ice,
if (send_callback_) { "ignoring STUN Indication or Response with unknown method %#.3x",
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); static_cast<unsigned int>(packet->GetMethod()));
} }
delete response;
} else { return;
ELOG_WARN("ignoring STUN Binding Response without FINGERPRINT"); }
}
return; // Must use FINGERPRINT (optional for ICE STUN indications).
} if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION)
{
switch (packet->GetClass()) { if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
case RTC::StunPacket::Class::REQUEST: { {
// USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400");
if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) ||
packet->GetUsername().empty()) { // Reply 400.
ELOG_WARN("mising required attributes in STUN Binding Request => 400"); RTC::StunPacket* response = packet->CreateErrorResponse(400);
// Reply 400. response->Serialize(StunSerializeBuffer);
RTC::StunPacket *response = packet->CreateErrorResponse(400); this->listener->OnIceServerSendStunPacket(this, response, tuple);
response->Serialize(StunSerializeBuffer);
if (send_callback_) { delete response;
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); }
} else
delete response; {
return; MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT");
} }
// Check authentication. return;
switch (packet->CheckAuthentication(this->username_fragment_, this->password_)) { }
case RTC::StunPacket::Authentication::OK: {
if (!this->old_password_.empty()) { switch (packet->GetClass())
ELOG_DEBUG("kNew ICE credentials applied"); {
this->old_username_fragment_.clear(); case RTC::StunPacket::Class::REQUEST:
this->old_password_.clear(); {
} // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required.
break; if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty())
} {
MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400");
case RTC::StunPacket::Authentication::UNAUTHORIZED: {
// We may have changed our username_fragment_ and password_, so check // Reply 400.
// the old ones. RTC::StunPacket* response = packet->CreateErrorResponse(400);
// clang-format off
if (!this->old_username_fragment_.empty() && response->Serialize(StunSerializeBuffer);
!this->old_password_.empty() && this->listener->OnIceServerSendStunPacket(this, response, tuple);
packet->CheckAuthentication(this->old_username_fragment_, this->old_password_) ==
RTC::StunPacket::Authentication::OK) { delete response;
ELOG_DEBUG("using old ICE credentials");
break; return;
} }
ELOG_WARN("wrong authentication in STUN Binding Request => 401");
// Reply 401. // Check authentication.
RTC::StunPacket *response = packet->CreateErrorResponse(401); switch (packet->CheckAuthentication(this->usernameFragment, this->password))
response->Serialize(StunSerializeBuffer); {
if (send_callback_) { case RTC::StunPacket::Authentication::OK:
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); {
} if (!this->oldPassword.empty())
delete response; {
return; MS_DEBUG_TAG(ice, "new ICE credentials applied");
}
this->oldUsernameFragment.clear();
case RTC::StunPacket::Authentication::BAD_REQUEST: { this->oldPassword.clear();
ELOG_WARN("cannot check authentication in STUN Binding Request => 400"); }
// Reply 400.
RTC::StunPacket *response = packet->CreateErrorResponse(400); break;
response->Serialize(StunSerializeBuffer); }
if (send_callback_) {
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); case RTC::StunPacket::Authentication::UNAUTHORIZED:
} {
delete response; // We may have changed our usernameFragment and password, so check
return; // the old ones.
} // clang-format off
} if (
!this->oldUsernameFragment.empty() &&
!this->oldPassword.empty() &&
packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK
)
// clang-format on
{
MS_DEBUG_TAG(ice, "using old ICE credentials");
break;
}
MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401");
// Reply 401.
RTC::StunPacket* response = packet->CreateErrorResponse(401);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
case RTC::StunPacket::Authentication::BAD_REQUEST:
{
MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
}
#if 0 #if 0
// NOTE: Should be rejected with 487, but this makes Chrome happy: // The remote peer must be ICE controlling.
// https://bugs.chromium.org/p/webrtc/issues/detail?id=7478 if (packet->GetIceControlled())
// The remote peer must be ICE controlling. {
if (packet->GetIceControlled()) { MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487");
MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487");
// Reply 487 (Role Conflict). // Reply 487 (Role Conflict).
RTC::StunPacket *response = packet->CreateErrorResponse(487); RTC::StunPacket* response = packet->CreateErrorResponse(487);
response->Serialize(StunSerializeBuffer);
if (send_callback_) { response->Serialize(StunSerializeBuffer);
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); this->listener->OnIceServerSendStunPacket(this, response, tuple);
}
delete response; delete response;
return;
} return;
}
#endif #endif
ELOG_DEBUG("processing STUN Binding Request [Priority:%d, UseCandidate:%s]", MS_DEBUG_DEV(
static_cast<uint32_t>(packet->GetPriority()), "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]",
(packet->HasUseCandidate() ? "true" : "false")); static_cast<uint32_t>(packet->GetPriority()),
// Create a success response. packet->HasUseCandidate() ? "true" : "false");
RTC::StunPacket *response = packet->CreateSuccessResponse();
// Add XOR-MAPPED-ADDRESS. // Create a success response.
// response->SetXorMappedAddress(tuple->GetRemoteAddress()); RTC::StunPacket* response = packet->CreateSuccessResponse();
response->SetXorMappedAddress((struct sockaddr *) remote_address);
// Authenticate the response. // Add XOR-MAPPED-ADDRESS.
if (this->old_password_.empty()) { response->SetXorMappedAddress(tuple);
response->Authenticate(this->password_);
} else { // Authenticate the response.
response->Authenticate(this->old_password_); if (this->oldPassword.empty())
} response->Authenticate(this->password);
else
// Send back. response->Authenticate(this->oldPassword);
response->Serialize(StunSerializeBuffer);
if (send_callback_) { // Send back.
send_callback_((char *) StunSerializeBuffer, response->GetSize(), remote_address); response->Serialize(StunSerializeBuffer);
} this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
// Handle the tuple. delete response;
HandleTuple(remote_address, packet->HasUseCandidate());
break; // Handle the tuple.
} HandleTuple(tuple, packet->HasUseCandidate());
case RTC::StunPacket::Class::INDICATION: { break;
ELOG_DEBUG("STUN Binding Indication processed"); }
break;
} case RTC::StunPacket::Class::INDICATION:
{
case RTC::StunPacket::Class::SUCCESS_RESPONSE: { MS_DEBUG_TAG(ice, "STUN Binding Indication processed");
ELOG_DEBUG("STUN Binding Success Response processed");
break; break;
} }
case RTC::StunPacket::Class::ERROR_RESPONSE: { case RTC::StunPacket::Class::SUCCESS_RESPONSE:
ELOG_DEBUG("STUN Binding Error Response processed"); {
break; MS_DEBUG_TAG(ice, "STUN Binding Success Response processed");
}
} break;
} }
void IceServer::HandleTuple(sockaddr_in *remote_address, bool has_use_candidate) {
remote_address_ = *remote_address; case RTC::StunPacket::Class::ERROR_RESPONSE:
if (has_use_candidate) { {
this->state = IceState::kCompleted; MS_DEBUG_TAG(ice, "STUN Binding Error Response processed");
}
if (ice_server_completed_callback_) { break;
ice_server_completed_callback_(); }
ice_server_completed_callback_ = nullptr; }
} }
}
bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const
const std::string &IceServer::GetUsernameFragment() const { return this->username_fragment_; } {
MS_TRACE();
const std::string &IceServer::GetPassword() const { return this->password_; }
return HasTuple(tuple) != nullptr;
inline void IceServer::SetUsernameFragment(const std::string &username_fragment) { }
this->old_username_fragment_ = this->username_fragment_;
this->username_fragment_ = username_fragment; void IceServer::RemoveTuple(RTC::TransportTuple* tuple)
} {
MS_TRACE();
inline void IceServer::SetPassword(const std::string &password) {
this->old_password_ = this->password_; RTC::TransportTuple* removedTuple{ nullptr };
this->password_ = password;
} // Find the removed tuple.
auto it = this->tuples.begin();
inline IceServer::IceState IceServer::GetState() const { return this->state; }
\ No newline at end of file for (; it != this->tuples.end(); ++it)
{
RTC::TransportTuple* storedTuple = std::addressof(*it);
if (memcmp(storedTuple, tuple, sizeof (RTC::TransportTuple)) == 0)
{
removedTuple = storedTuple;
break;
}
}
// If not found, ignore.
if (!removedTuple)
return;
// Remove from the list of tuples.
this->tuples.erase(it);
// If this is not the selected tuple, stop here.
if (removedTuple != this->selectedTuple)
return;
// Otherwise this was the selected tuple.
this->selectedTuple = nullptr;
// Mark the first tuple as selected tuple (if any).
if (this->tuples.begin() != this->tuples.end())
{
SetSelectedTuple(std::addressof(*this->tuples.begin()));
}
// Or just emit 'disconnected'.
else
{
// Update state.
this->state = IceState::DISCONNECTED;
// Notify the listener.
this->listener->OnIceServerDisconnected(this);
}
}
void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple)
{
MS_TRACE();
MS_ASSERT(
this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple");
auto* storedTuple = HasTuple(tuple);
MS_ASSERT(
storedTuple,
"cannot force the selected tuple if the given tuple was not already a valid tuple");
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
}
void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate)
{
MS_TRACE();
switch (this->state)
{
case IceState::NEW:
{
// There should be no tuples.
MS_ASSERT(
this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size());
// There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple");
if (!hasUseCandidate)
{
MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::CONNECTED;
// Notify the listener.
this->listener->OnIceServerConnected(this);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::DISCONNECTED:
{
// There should be no tuples.
MS_ASSERT(
this->tuples.empty(),
"state is 'disconnected' but there are %zu tuples",
this->tuples.size());
// There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple");
if (!hasUseCandidate)
{
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::CONNECTED;
// Notify the listener.
this->listener->OnIceServerConnected(this);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::CONNECTED:
{
// There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples");
// There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple");
if (!hasUseCandidate)
{
// If a new tuple store it.
if (!HasTuple(tuple))
AddTuple(tuple);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'");
auto* storedTuple = HasTuple(tuple);
// If a new tuple store it.
if (!storedTuple)
storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::COMPLETED:
{
// There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples");
// There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple");
if (!hasUseCandidate)
{
// If a new tuple store it.
if (!HasTuple(tuple))
AddTuple(tuple);
}
else
{
auto* storedTuple = HasTuple(tuple);
// If a new tuple store it.
if (!storedTuple)
storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
}
break;
}
}
}
inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple)
{
MS_TRACE();
// Add the new tuple at the beginning of the list.
this->tuples.push_front(*tuple);
auto* storedTuple = std::addressof(*this->tuples.begin());
// Return the address of the inserted tuple.
return storedTuple;
}
inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const
{
MS_TRACE();
// If there is no selected tuple yet then we know that the tuples list
// is empty.
if (!this->selectedTuple)
return nullptr;
// Check the current selected tuple.
if (memcmp(selectedTuple, tuple, sizeof (RTC::TransportTuple)) == 0)
return this->selectedTuple;
// Otherwise check other stored tuples.
for (const auto& it : this->tuples)
{
auto* storedTuple = const_cast<RTC::TransportTuple*>(std::addressof(it));
if (memcmp(storedTuple, tuple, sizeof (RTC::TransportTuple)) == 0)
return storedTuple;
}
return nullptr;
}
inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple)
{
MS_TRACE();
// If already the selected tuple do nothing.
if (storedTuple == this->selectedTuple)
return;
this->selectedTuple = storedTuple;
// Notify the listener.
this->listener->OnIceServerSelectedTuple(this, this->selectedTuple);
}
} // namespace RTC
#pragma once #ifndef MS_RTC_ICE_SERVER_HPP
#define MS_RTC_ICE_SERVER_HPP
#include "stun_packet.h"
#include "logger.h"
#include <list>
#include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "logger.h" namespace RTC
#include "stun_packet.h" {
using TransportTuple = struct sockaddr;
class IceServer
{
public:
enum class IceState
{
NEW = 1,
CONNECTED,
COMPLETED,
DISCONNECTED
};
public:
class Listener
{
public:
virtual ~Listener() = default;
public:
/**
* These callbacks are guaranteed to be called before ProcessStunPacket()
* returns, so the given pointers are still usable.
*/
virtual void OnIceServerSendStunPacket(
const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerSelectedTuple(
const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0;
};
public:
IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password);
public:
void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple);
const std::string& GetUsernameFragment() const
{
return this->usernameFragment;
}
const std::string& GetPassword() const
{
return this->password;
}
IceState GetState() const
{
return this->state;
}
RTC::TransportTuple* GetSelectedTuple() const
{
return this->selectedTuple;
}
void SetUsernameFragment(const std::string& usernameFragment)
{
this->oldUsernameFragment = this->usernameFragment;
this->usernameFragment = usernameFragment;
}
void SetPassword(const std::string& password)
{
this->oldPassword = this->password;
this->password = password;
}
bool IsValidTuple(const RTC::TransportTuple* tuple) const;
void RemoveTuple(RTC::TransportTuple* tuple);
// This should be just called in 'connected' or completed' state
// and the given tuple must be an already valid tuple.
void ForceSelectedTuple(const RTC::TransportTuple* tuple);
private:
void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate);
/**
* Store the given tuple and return its stored address.
*/
RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple);
/**
* If the given tuple exists return its stored address, nullptr otherwise.
*/
RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const;
/**
* Set the given tuple as the selected tuple.
* NOTE: The given tuple MUST be already stored within the list.
*/
void SetSelectedTuple(RTC::TransportTuple* storedTuple);
private:
// Passed by argument.
Listener* listener{ nullptr };
// Others.
std::string usernameFragment;
std::string password;
std::string oldUsernameFragment;
std::string oldPassword;
IceState state{ IceState::NEW };
std::list<RTC::TransportTuple> tuples;
RTC::TransportTuple* selectedTuple{ nullptr };
};
} // namespace RTC
typedef std::function<void(char *buf, size_t len, struct sockaddr_in *remote_address)> UdpSendCallback; #endif
class IceServer {
public:
enum class IceState { kNew = 1, kConnect, kCompleted, kDisconnected };
typedef std::shared_ptr<IceServer> Ptr;
IceServer();
IceServer(const std::string &username_fragment, const std::string &password);
const std::string &GetUsernameFragment() const;
const std::string &GetPassword() const;
void SetUsernameFragment(const std::string &username_fragment);
void SetPassword(const std::string &password);
IceState GetState() const;
void ProcessStunPacket(RTC::StunPacket *packet, struct sockaddr_in *remote_address);
void HandleTuple(struct sockaddr_in *remote_address, bool has_use_candidate);
~IceServer();
void SetSendCB(UdpSendCallback send_cb) { send_callback_ = send_cb; }
void SetIceServerCompletedCB(std::function<void()> cb) { ice_server_completed_callback_ = cb; };
struct sockaddr_in *GetSelectAddr() {
return &remote_address_;
}
private:
UdpSendCallback send_callback_;
std::function<void()> ice_server_completed_callback_;
std::string username_fragment_;
std::string password_;
std::string old_username_fragment_;
std::string old_password_;
IceState state{IceState::kNew};
struct sockaddr_in remote_address_;
};
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#define MS_DEBUG_2TAGS(tag1, tag2,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) #define MS_DEBUG_2TAGS(tag1, tag2,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__)
#define MS_WARN_2TAGS(tag1, tag2,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__) #define MS_WARN_2TAGS(tag1, tag2,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__)
#define MS_DEBUG_TAG(tag,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) #define MS_DEBUG_TAG(tag,fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__)
#define MS_ASSERT(con, log) assert(con) #define MS_ASSERT(con, fmt, ...) do{if(!(con)) { printf("assert failed:%s" fmt "\n", #con, ##__VA_ARGS__);} assert(con); } while(false);
#define MS_ABORT(fmt, ...) do{ printf("abort:" fmt "\n", ##__VA_ARGS__); abort(); } while(false); #define MS_ABORT(fmt, ...) do{ printf("abort:" fmt "\n", ##__VA_ARGS__); abort(); } while(false);
#define MS_WARN_TAG(tag,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__) #define MS_WARN_TAG(tag,fmt, ...) printf("warn:" fmt "\n", ##__VA_ARGS__)
#define MS_DEBUG_DEV(fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__) #define MS_DEBUG_DEV(fmt, ...) printf("debug:" fmt "\n", ##__VA_ARGS__)
\ No newline at end of file
...@@ -2,62 +2,56 @@ ...@@ -2,62 +2,56 @@
// #define MS_LOG_DEV_LEVEL 3 // #define MS_LOG_DEV_LEVEL 3
#include "rtc_dtls_transport.h" #include "rtc_dtls_transport.h"
#include "logger.h"
#include <openssl/asn1.h> #include <openssl/asn1.h>
#include <openssl/bn.h> #include <openssl/bn.h>
#include <openssl/err.h> #include <openssl/err.h>
#include <openssl/evp.h> #include <openssl/evp.h>
#include <openssl/rsa.h> #include <openssl/rsa.h>
#include <cstdio> // std::sprintf(), std::fopen()
#include <cstdio> // std::sprintf(), std::fopen() #include <cstring> // std::memcpy(), std::strcmp()
#include <cstring> // std::memcpy(), std::strcmp()
#define LOG_OPENSSL_ERROR(desc) \
#include "logger.h" do \
{ \
typedef struct { if (ERR_peek_error() == 0) \
long tv_sec; MS_ERROR("OpenSSL error [desc:'%s']", desc); \
long tv_usec; else \
} uv_timeval_t; { \
int64_t err; \
#define LOG_OPENSSL_ERROR(desc) \ while ((err = ERR_get_error()) != 0) \
do { \ { \
if (ERR_peek_error() == 0) \ MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \
MS_ERROR("OpenSSL error [desc:'%s']", desc); \ } \
else { \ ERR_clear_error(); \
int64_t err; \ } \
while ((err = ERR_get_error()) != 0) { \ } while (false)
MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \
} \
ERR_clear_error(); \
} \
} while (false)
/* Static methods for OpenSSL callbacks. */ /* Static methods for OpenSSL callbacks. */
inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) { inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/)
MS_TRACE(); {
MS_TRACE();
// Always valid since DTLS certificates are self-signed.
return 1;
}
inline static void onSslInfo(const SSL* ssl, int where, int ret) { // Always valid since DTLS certificates are self-signed.
static_cast<RTC::DtlsTransport*>(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); return 1;
} }
inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) { inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs)
if (timerUs == 0) {
return 100000; if (timerUs == 0)
else if (timerUs >= 4000000) return 100000;
return 4000000; else if (timerUs >= 4000000)
else return 4000000;
return 2 * timerUs; else
return 2 * timerUs;
} }
namespace RTC { namespace RTC
/* Static. */ {
/* Static. */
// clang-format off // clang-format off
static constexpr int DtlsMtu{ 1350 }; static constexpr int DtlsMtu{ 1350 };
static constexpr int SslReadBufferSize{ 65536 }; static constexpr int SslReadBufferSize{ 65536 };
// AES-HMAC: http://tools.ietf.org/html/rfc3711 // AES-HMAC: http://tools.ietf.org/html/rfc3711
...@@ -71,15 +65,15 @@ namespace RTC { ...@@ -71,15 +65,15 @@ namespace RTC {
static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 };
static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength };
// clang-format on // clang-format on
/* Class variables. */ /* Class variables. */
X509* DtlsTransport::certificate{nullptr}; X509* DtlsTransport::certificate{ nullptr };
EVP_PKEY* DtlsTransport::privateKey{nullptr}; EVP_PKEY* DtlsTransport::privateKey{ nullptr };
SSL_CTX* DtlsTransport::sslCtx{nullptr}; SSL_CTX* DtlsTransport::sslCtx{ nullptr };
uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize]; uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize];
// clang-format off // clang-format off
std::map<std::string, DtlsTransport::FingerprintAlgorithm> DtlsTransport::string2FingerprintAlgorithm = std::map<std::string, DtlsTransport::FingerprintAlgorithm> DtlsTransport::string2FingerprintAlgorithm =
{ {
{ "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 },
...@@ -105,1219 +99,1376 @@ uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize]; ...@@ -105,1219 +99,1376 @@ uint8_t DtlsTransport::sslReadBuffer[SslReadBufferSize];
std::vector<DtlsTransport::Fingerprint> DtlsTransport::localFingerprints; std::vector<DtlsTransport::Fingerprint> DtlsTransport::localFingerprints;
std::vector<DtlsTransport::SrtpCryptoSuiteMapEntry> DtlsTransport::srtpCryptoSuites = std::vector<DtlsTransport::SrtpCryptoSuiteMapEntry> DtlsTransport::srtpCryptoSuites =
{ {
{ RTC::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" },
{ RTC::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" },
{ RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" },
{ RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" }
}; };
// clang-format on // clang-format on
/* Class methods. */ /* Class methods. */
void DtlsTransport::ClassInit() { void DtlsTransport::ClassInit()
MS_TRACE(); {
MS_TRACE();
#if 0
// Generate a X509 certificate and private key (unless PEM files are provided).
if (Settings::configuration.dtlsCertificateFile.empty() ||
Settings::configuration.dtlsPrivateKeyFile.empty()) {
GenerateCertificateAndPrivateKey();
} else {
ReadCertificateAndPrivateKeyFromFiles();
}
#else
GenerateCertificateAndPrivateKey();
#endif
// Create a global SSL_CTX. // Generate a X509 certificate and private key (unless PEM files are provided).
CreateSslCtx(); if (true /*
Settings::configuration.dtlsCertificateFile.empty() ||
Settings::configuration.dtlsPrivateKeyFile.empty()*/)
{
GenerateCertificateAndPrivateKey();
}
else
{
ReadCertificateAndPrivateKeyFromFiles();
}
// Generate certificate fingerprints. // Create a global SSL_CTX.
GenerateFingerprints(); CreateSslCtx();
}
void DtlsTransport::ClassDestroy() { // Generate certificate fingerprints.
MS_TRACE(); GenerateFingerprints();
}
if (DtlsTransport::privateKey) EVP_PKEY_free(DtlsTransport::privateKey); void DtlsTransport::ClassDestroy()
if (DtlsTransport::certificate) X509_free(DtlsTransport::certificate); {
if (DtlsTransport::sslCtx) SSL_CTX_free(DtlsTransport::sslCtx); MS_TRACE();
}
void DtlsTransport::GenerateCertificateAndPrivateKey() { if (DtlsTransport::privateKey)
MS_TRACE(); EVP_PKEY_free(DtlsTransport::privateKey);
if (DtlsTransport::certificate)
X509_free(DtlsTransport::certificate);
if (DtlsTransport::sslCtx)
SSL_CTX_free(DtlsTransport::sslCtx);
}
int ret{0}; void DtlsTransport::GenerateCertificateAndPrivateKey()
EC_KEY* ecKey{nullptr}; {
X509_NAME* certName{nullptr}; MS_TRACE();
std::string subject = std::string("mediasoup") + std::to_string(rand() % 999999 + 100000);
// std::string("mediasoup") + std::to_string(Utils::Crypto::GetRandomUInt(100000, 999999));
// Create key with curve. int ret{ 0 };
ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); EC_KEY* ecKey{ nullptr };
X509_NAME* certName{ nullptr };
std::string subject =
std::string("mediasoup") + std::to_string(rand() % 999999 + 100000);
if (!ecKey) { // Create key with curve.
LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
goto error; if (!ecKey)
} {
LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed");
EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); goto error;
}
// NOTE: This can take some time. EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE);
ret = EC_KEY_generate_key(ecKey);
if (ret == 0) { // NOTE: This can take some time.
LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); ret = EC_KEY_generate_key(ecKey);
goto error; if (ret == 0)
} {
LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed");
// Create a private key object. goto error;
DtlsTransport::privateKey = EVP_PKEY_new(); }
if (!DtlsTransport::privateKey) { // Create a private key object.
LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); DtlsTransport::privateKey = EVP_PKEY_new();
goto error; if (!DtlsTransport::privateKey)
} {
LOG_OPENSSL_ERROR("EVP_PKEY_new() failed");
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) goto error;
ret = EVP_PKEY_assign_EC_KEY(DtlsTransport::privateKey, ecKey); }
if (ret == 0) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); ret = EVP_PKEY_assign_EC_KEY(DtlsTransport::privateKey, ecKey);
goto error; if (ret == 0)
} {
LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed");
// The EC key now belongs to the private key, so don't clean it up separately. goto error;
ecKey = nullptr; }
// Create the X509 certificate. // The EC key now belongs to the private key, so don't clean it up separately.
DtlsTransport::certificate = X509_new(); ecKey = nullptr;
if (!DtlsTransport::certificate) { // Create the X509 certificate.
LOG_OPENSSL_ERROR("X509_new() failed"); DtlsTransport::certificate = X509_new();
goto error; if (!DtlsTransport::certificate)
} {
LOG_OPENSSL_ERROR("X509_new() failed");
// Set version 3 (note that 0 means version 1). goto error;
X509_set_version(DtlsTransport::certificate, 2); }
// Set serial number (avoid default 0). // Set version 3 (note that 0 means version 1).
// ASN1_INTEGER_set(X509_get_serialNumber(DtlsTransport::certificate), X509_set_version(DtlsTransport::certificate, 2);
// static_cast<uint64_t>(Utils::Crypto::GetRandomUInt(1000000, 9999999)));
ASN1_INTEGER_set(X509_get_serialNumber(DtlsTransport::certificate),
static_cast<uint64_t>(rand() % 999999 + 100000));
// Set valid period. // Set serial number (avoid default 0).
X509_gmtime_adj(X509_get_notBefore(DtlsTransport::certificate), -315360000); // -10 years. ASN1_INTEGER_set(
X509_gmtime_adj(X509_get_notAfter(DtlsTransport::certificate), 315360000); // 10 years. X509_get_serialNumber(DtlsTransport::certificate),
static_cast<uint64_t>(rand() % 999999 + 100000));
// Set the public key for the certificate using the key. // Set valid period.
ret = X509_set_pubkey(DtlsTransport::certificate, DtlsTransport::privateKey); X509_gmtime_adj(X509_get_notBefore(DtlsTransport::certificate), -315360000); // -10 years.
X509_gmtime_adj(X509_get_notAfter(DtlsTransport::certificate), 315360000); // 10 years.
if (ret == 0) { // Set the public key for the certificate using the key.
LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); ret = X509_set_pubkey(DtlsTransport::certificate, DtlsTransport::privateKey);
goto error; if (ret == 0)
} {
LOG_OPENSSL_ERROR("X509_set_pubkey() failed");
// Set certificate fields. goto error;
certName = X509_get_subject_name(DtlsTransport::certificate); }
if (!certName) { // Set certificate fields.
LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); certName = X509_get_subject_name(DtlsTransport::certificate);
goto error; if (!certName)
} {
LOG_OPENSSL_ERROR("X509_get_subject_name() failed");
X509_NAME_add_entry_by_txt(certName, "O", MBSTRING_ASC, goto error;
reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0); }
X509_NAME_add_entry_by_txt(certName, "CN", MBSTRING_ASC,
reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
// It is self-signed so set the issuer name to be the same as the subject. X509_NAME_add_entry_by_txt(
ret = X509_set_issuer_name(DtlsTransport::certificate, certName); certName, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
X509_NAME_add_entry_by_txt(
certName, "CN", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
if (ret == 0) { // It is self-signed so set the issuer name to be the same as the subject.
LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); ret = X509_set_issuer_name(DtlsTransport::certificate, certName);
goto error; if (ret == 0)
} {
LOG_OPENSSL_ERROR("X509_set_issuer_name() failed");
// Sign the certificate with its own private key. goto error;
ret = X509_sign(DtlsTransport::certificate, DtlsTransport::privateKey, EVP_sha1()); }
if (ret == 0) { // Sign the certificate with its own private key.
LOG_OPENSSL_ERROR("X509_sign() failed"); ret = X509_sign(DtlsTransport::certificate, DtlsTransport::privateKey, EVP_sha1());
goto error; if (ret == 0)
} {
LOG_OPENSSL_ERROR("X509_sign() failed");
return; goto error;
}
error: return;
if (ecKey) EC_KEY_free(ecKey); error:
if (DtlsTransport::privateKey) if (ecKey)
EVP_PKEY_free(DtlsTransport::privateKey); // NOTE: This also frees the EC key. EC_KEY_free(ecKey);
if (DtlsTransport::certificate) X509_free(DtlsTransport::certificate); if (DtlsTransport::privateKey)
EVP_PKEY_free(DtlsTransport::privateKey); // NOTE: This also frees the EC key.
MS_THROW_ERROR("DTLS certificate and private key generation failed"); if (DtlsTransport::certificate)
} X509_free(DtlsTransport::certificate);
MS_THROW_ERROR("DTLS certificate and private key generation failed");
}
void DtlsTransport::ReadCertificateAndPrivateKeyFromFiles() { void DtlsTransport::ReadCertificateAndPrivateKeyFromFiles()
{
#if 0 #if 0
MS_TRACE(); MS_TRACE();
FILE* file{nullptr}; FILE* file{ nullptr };
file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r");
if (!file) { if (!file)
MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); {
MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno));
goto error; goto error;
} }
DtlsTransport::certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); DtlsTransport::certificate = PEM_read_X509(file, nullptr, nullptr, nullptr);
if (!DtlsTransport::certificate) { if (!DtlsTransport::certificate)
LOG_OPENSSL_ERROR("PEM_read_X509() failed"); {
LOG_OPENSSL_ERROR("PEM_read_X509() failed");
goto error; goto error;
} }
fclose(file); fclose(file);
file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r");
if (!file) { if (!file)
MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); {
MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno));
goto error; goto error;
} }
DtlsTransport::privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); DtlsTransport::privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr);
if (!DtlsTransport::privateKey) { if (!DtlsTransport::privateKey)
LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); {
LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed");
goto error; goto error;
} }
fclose(file); fclose(file);
return; return;
error: error:
MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); MS_THROW_ERROR("error reading DTLS certificate and private key PEM files");
#endif #endif
} }
void DtlsTransport::CreateSslCtx() {
MS_TRACE();
std::string dtlsSrtpCryptoSuites;
int ret;
/* Set the global DTLS context. */
// Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0).
DtlsTransport::sslCtx = SSL_CTX_new(DTLS_method());
if (!DtlsTransport::sslCtx) {
LOG_OPENSSL_ERROR("SSL_CTX_new() failed");
goto error;
}
ret = SSL_CTX_use_certificate(DtlsTransport::sslCtx, DtlsTransport::certificate);
if (ret == 0) {
LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed");
goto error;
}
ret = SSL_CTX_use_PrivateKey(DtlsTransport::sslCtx, DtlsTransport::privateKey);
if (ret == 0) { void DtlsTransport::CreateSslCtx()
LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); {
MS_TRACE();
goto error;
}
ret = SSL_CTX_check_private_key(DtlsTransport::sslCtx);
if (ret == 0) {
LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed");
goto error;
}
// Set options.
SSL_CTX_set_options(DtlsTransport::sslCtx, SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET |
SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_QUERY_MTU);
// Don't use sessions cache.
SSL_CTX_set_session_cache_mode(DtlsTransport::sslCtx, SSL_SESS_CACHE_OFF);
// Read always as much into the buffer as possible.
// NOTE: This is the default for DTLS, but a bug in non latest OpenSSL
// versions makes this call required.
SSL_CTX_set_read_ahead(DtlsTransport::sslCtx, 1);
SSL_CTX_set_verify_depth(DtlsTransport::sslCtx, 4);
// Require certificate from peer.
SSL_CTX_set_verify(DtlsTransport::sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
onSslCertificateVerify);
// Set SSL info callback.
SSL_CTX_set_info_callback(DtlsTransport::sslCtx, onSslInfo);
// Set ciphers.
ret = SSL_CTX_set_cipher_list(DtlsTransport::sslCtx,
"DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK");
if (ret == 0) {
LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed");
goto error;
}
// Enable ECDH ciphers.
// DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters
// NOTE: https://code.google.com/p/chromium/issues/detail?id=406458
// NOTE: https://bugs.ruby-lang.org/issues/12324
// For OpenSSL >= 1.0.2.
SSL_CTX_set_ecdh_auto(DtlsTransport::sslCtx, 1);
// Set the "use_srtp" DTLS extension.
for (auto it = DtlsTransport::srtpCryptoSuites.begin();
it != DtlsTransport::srtpCryptoSuites.end(); ++it) {
if (it != DtlsTransport::srtpCryptoSuites.begin()) dtlsSrtpCryptoSuites += ":";
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it);
dtlsSrtpCryptoSuites += cryptoSuiteEntry->name;
}
MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s",
dtlsSrtpCryptoSuites.c_str());
// NOTE: This function returns 0 on success.
ret = SSL_CTX_set_tlsext_use_srtp(DtlsTransport::sslCtx, dtlsSrtpCryptoSuites.c_str());
if (ret != 0) {
MS_ERROR("SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'",
dtlsSrtpCryptoSuites.c_str());
LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed");
goto error;
}
return;
error:
if (DtlsTransport::sslCtx) {
SSL_CTX_free(DtlsTransport::sslCtx);
DtlsTransport::sslCtx = nullptr;
}
MS_THROW_ERROR("SSL context creation failed");
}
void DtlsTransport::GenerateFingerprints() {
MS_TRACE();
for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) {
const std::string& algorithmString = kv.first;
FingerprintAlgorithm algorithm = kv.second;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{0};
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
switch (algorithm) {
case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224();
break;
case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256();
break;
case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384();
break;
case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512();
break;
default:
MS_THROW_ERROR("unknown algorithm");
}
ret = X509_digest(DtlsTransport::certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0) {
MS_ERROR("X509_digest() failed");
MS_THROW_ERROR("Fingerprints generation failed");
}
// Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{0}; i < size; ++i) {
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint);
// Store it in the vector.
DtlsTransport::Fingerprint fingerprint;
fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString);
fingerprint.value = hexFingerprint;
DtlsTransport::localFingerprints.push_back(fingerprint);
}
}
/* Instance methods. */
DtlsTransport::DtlsTransport(Listener* listener) : listener(listener) {
MS_TRACE();
/* Set SSL. */
this->ssl = SSL_new(DtlsTransport::sslCtx);
if (!this->ssl) {
LOG_OPENSSL_ERROR("SSL_new() failed");
goto error;
}
// Set this as custom data.
SSL_set_ex_data(this->ssl, 0, static_cast<void*>(this));
this->sslBioFromNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioFromNetwork) {
LOG_OPENSSL_ERROR("BIO_new() failed");
SSL_free(this->ssl);
goto error;
}
this->sslBioToNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioToNetwork) {
LOG_OPENSSL_ERROR("BIO_new() failed");
BIO_free(this->sslBioFromNetwork);
SSL_free(this->ssl);
goto error;
}
SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork);
// Set the MTU so that we don't send packets that are too large with no fragmentation.
SSL_set_mtu(this->ssl, DtlsMtu);
DTLS_set_link_mtu(this->ssl, DtlsMtu);
// Set callback handler for setting DTLS timer interval.
DTLS_set_timer_cb(this->ssl, onSslDtlsTimer);
// Set the DTLS timer. std::string dtlsSrtpCryptoSuites;
// this->timer = new Timer(this); int ret;
return; /* Set the global DTLS context. */
error: // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0).
DtlsTransport::sslCtx = SSL_CTX_new(DTLS_method());
// NOTE: At this point SSL_set_bio() was not called so we must free BIOs as if (!DtlsTransport::sslCtx)
// well. {
if (this->sslBioFromNetwork) BIO_free(this->sslBioFromNetwork); LOG_OPENSSL_ERROR("SSL_CTX_new() failed");
if (this->sslBioToNetwork) BIO_free(this->sslBioToNetwork); goto error;
}
if (this->ssl) SSL_free(this->ssl); ret = SSL_CTX_use_certificate(DtlsTransport::sslCtx, DtlsTransport::certificate);
// NOTE: If this is not catched by the caller the program will abort, but if (ret == 0)
// this should never happen. {
MS_THROW_ERROR("DtlsTransport instance creation failed"); LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed");
}
DtlsTransport::~DtlsTransport() { goto error;
MS_TRACE(); }
if (IsRunning()) { ret = SSL_CTX_use_PrivateKey(DtlsTransport::sslCtx, DtlsTransport::privateKey);
// Send close alert to the peer.
SSL_shutdown(this->ssl);
SendPendingOutgoingDtlsData();
}
if (this->ssl) { if (ret == 0)
SSL_free(this->ssl); {
LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed");
this->ssl = nullptr; goto error;
this->sslBioFromNetwork = nullptr; }
this->sslBioToNetwork = nullptr;
}
// Close the DTLS timer. ret = SSL_CTX_check_private_key(DtlsTransport::sslCtx);
// delete this->timer;
}
void DtlsTransport::Dump() const { if (ret == 0)
MS_TRACE(); {
LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed");
std::string state{"new"};
std::string role{"none "};
switch (this->state) {
case DtlsState::CONNECTING:
state = "connecting";
break;
case DtlsState::CONNECTED:
state = "connected";
break;
case DtlsState::FAILED:
state = "failed";
break;
case DtlsState::CLOSED:
state = "closed";
break;
default:;
}
switch (this->localRole) {
case Role::AUTO:
role = "auto";
break;
case Role::SERVER:
role = "server";
break;
case Role::CLIENT:
role = "client";
break;
default:;
}
MS_DUMP("<DtlsTransport>");
MS_DUMP(" state : %s", state.c_str());
MS_DUMP(" role : %s", role.c_str());
MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no");
MS_DUMP("</DtlsTransport>");
}
void DtlsTransport::Run(Role localRole) { goto error;
MS_TRACE(); }
MS_ASSERT(localRole == Role::CLIENT || localRole == Role::SERVER, // Set options.
"local DTLS role must be 'client' or 'server'"); SSL_CTX_set_options(
DtlsTransport::sslCtx,
SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE |
SSL_OP_NO_QUERY_MTU);
Role previousLocalRole = this->localRole; // Don't use sessions cache.
SSL_CTX_set_session_cache_mode(DtlsTransport::sslCtx, SSL_SESS_CACHE_OFF);
if (localRole == previousLocalRole) { // Read always as much into the buffer as possible.
MS_ERROR("same local DTLS role provided, doing nothing"); // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL
// versions makes this call required.
SSL_CTX_set_read_ahead(DtlsTransport::sslCtx, 1);
return; SSL_CTX_set_verify_depth(DtlsTransport::sslCtx, 4);
}
// If the previous local DTLS role was 'client' or 'server' do reset. // Require certificate from peer.
if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) { SSL_CTX_set_verify(
MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); DtlsTransport::sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify);
Reset(); // Set SSL info callback.
} SSL_CTX_set_info_callback(DtlsTransport::sslCtx, [](const SSL* ssl, int where, int ret){
static_cast<RTC::DtlsTransport*>(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret);
});
// Set ciphers.
ret = SSL_CTX_set_cipher_list(
DtlsTransport::sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK");
// Update local role. if (ret == 0)
this->localRole = localRole; {
LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed");
// Set state and notify the listener. goto error;
this->state = DtlsState::CONNECTING; }
this->listener->OnDtlsTransportConnecting(this);
switch (this->localRole) { // Enable ECDH ciphers.
case Role::CLIENT: { // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters
MS_DEBUG_TAG(dtls, "running [role:client]"); // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458
// NOTE: https://bugs.ruby-lang.org/issues/12324
SSL_set_connect_state(this->ssl); // For OpenSSL >= 1.0.2.
SSL_do_handshake(this->ssl); SSL_CTX_set_ecdh_auto(DtlsTransport::sslCtx, 1);
SendPendingOutgoingDtlsData();
SetTimeout();
break; // Set the "use_srtp" DTLS extension.
} for (auto it = DtlsTransport::srtpCryptoSuites.begin();
it != DtlsTransport::srtpCryptoSuites.end();
++it)
{
if (it != DtlsTransport::srtpCryptoSuites.begin())
dtlsSrtpCryptoSuites += ":";
case Role::SERVER: { SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it);
MS_DEBUG_TAG(dtls, "running [role:server]"); dtlsSrtpCryptoSuites += cryptoSuiteEntry->name;
}
SSL_set_accept_state(this->ssl); MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str());
SSL_do_handshake(this->ssl);
break; // NOTE: This function returns 0 on success.
} ret = SSL_CTX_set_tlsext_use_srtp(DtlsTransport::sslCtx, dtlsSrtpCryptoSuites.c_str());
default: { if (ret != 0)
MS_ABORT("invalid local DTLS role"); {
} MS_ERROR(
} "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str());
} LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed");
bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) { goto error;
MS_TRACE(); }
MS_ASSERT(fingerprint.algorithm != FingerprintAlgorithm::NONE, return;
"no fingerprint algorithm provided");
this->remoteFingerprint = fingerprint; error:
// The remote fingerpring may have been set after DTLS handshake was done, if (DtlsTransport::sslCtx)
// so we may need to process it now. {
if (this->handshakeDone && this->state != DtlsState::CONNECTED) { SSL_CTX_free(DtlsTransport::sslCtx);
MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); DtlsTransport::sslCtx = nullptr;
}
return ProcessHandshake(); MS_THROW_ERROR("SSL context creation failed");
} }
return true; void DtlsTransport::GenerateFingerprints()
} {
MS_TRACE();
for (auto& kv : DtlsTransport::string2FingerprintAlgorithm)
{
const std::string& algorithmString = kv.first;
FingerprintAlgorithm algorithm = kv.second;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
switch (algorithm)
{
case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224();
break;
case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256();
break;
case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384();
break;
case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512();
break;
default:
MS_THROW_ERROR("unknown algorithm");
}
ret = X509_digest(DtlsTransport::certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0)
{
MS_ERROR("X509_digest() failed");
MS_THROW_ERROR("Fingerprints generation failed");
}
// Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i)
{
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint);
// Store it in the vector.
DtlsTransport::Fingerprint fingerprint;
fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString);
fingerprint.value = hexFingerprint;
DtlsTransport::localFingerprints.push_back(fingerprint);
}
}
/* Instance methods. */
DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener)
{
MS_TRACE();
void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) { /* Set SSL. */
MS_TRACE();
int written; this->ssl = SSL_new(DtlsTransport::sslCtx);
int read;
if (!IsRunning()) { if (!this->ssl)
MS_ERROR("cannot process data while not running"); {
LOG_OPENSSL_ERROR("SSL_new() failed");
return; goto error;
} }
// Write the received DTLS data into the sslBioFromNetwork. // Set this as custom data.
written = SSL_set_ex_data(this->ssl, 0, static_cast<void*>(this));
BIO_write(this->sslBioFromNetwork, static_cast<const void*>(data), static_cast<int>(len));
if (written != static_cast<int>(len)) { this->sslBioFromNetwork = BIO_new(BIO_s_mem());
MS_WARN_TAG(dtls, "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)",
static_cast<size_t>(written), len);
}
// Must call SSL_read() to process received DTLS data. if (!this->sslBioFromNetwork)
read = SSL_read(this->ssl, static_cast<void*>(DtlsTransport::sslReadBuffer), SslReadBufferSize); {
LOG_OPENSSL_ERROR("BIO_new() failed");
// Send data if it's ready. SSL_free(this->ssl);
SendPendingOutgoingDtlsData();
// Check SSL status and return if it is bad/closed. goto error;
if (!CheckStatus(read)) return; }
// Set/update the DTLS timeout. this->sslBioToNetwork = BIO_new(BIO_s_mem());
if (!SetTimeout()) return;
// Application data received. Notify to the listener. if (!this->sslBioToNetwork)
if (read > 0) { {
// It is allowed to receive DTLS data even before validating remote fingerprint. LOG_OPENSSL_ERROR("BIO_new() failed");
if (!this->handshakeDone) {
MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done");
return; BIO_free(this->sslBioFromNetwork);
} SSL_free(this->ssl);
// Notify the listener. goto error;
this->listener->OnDtlsTransportApplicationDataReceived( }
this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast<size_t>(read));
}
}
void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) { SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork);
MS_TRACE();
// We cannot send data to the peer if its remote fingerprint is not validated. // Set the MTU so that we don't send packets that are too large with no fragmentation.
if (this->state != DtlsState::CONNECTED) { SSL_set_mtu(this->ssl, DtlsMtu);
MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); DTLS_set_link_mtu(this->ssl, DtlsMtu);
return; // Set callback handler for setting DTLS timer interval.
} DTLS_set_timer_cb(this->ssl, onSslDtlsTimer);
if (len == 0) { return;
MS_WARN_TAG(dtls, "ignoring 0 length data");
return; error:
}
int written; // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as
// well.
if (this->sslBioFromNetwork)
BIO_free(this->sslBioFromNetwork);
written = SSL_write(this->ssl, static_cast<const void*>(data), static_cast<int>(len)); if (this->sslBioToNetwork)
BIO_free(this->sslBioToNetwork);
if (written < 0) { if (this->ssl)
LOG_OPENSSL_ERROR("SSL_write() failed"); SSL_free(this->ssl);
if (!CheckStatus(written)) return; // NOTE: If this is not catched by the caller the program will abort, but
} else if (written != static_cast<int>(len)) { // this should never happen.
MS_WARN_TAG(dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", MS_THROW_ERROR("DtlsTransport instance creation failed");
written, len); }
}
// Send data. DtlsTransport::~DtlsTransport()
SendPendingOutgoingDtlsData(); {
} MS_TRACE();
void DtlsTransport::Reset() { if (IsRunning())
MS_TRACE(); {
// Send close alert to the peer.
SSL_shutdown(this->ssl);
SendPendingOutgoingDtlsData();
}
int ret; if (this->ssl)
{
SSL_free(this->ssl);
if (!IsRunning()) return; this->ssl = nullptr;
this->sslBioFromNetwork = nullptr;
this->sslBioToNetwork = nullptr;
}
MS_WARN_TAG(dtls, "resetting DTLS transport"); // Close the DTLS timer.
this->timer = nullptr;
}
// Stop the DTLS timer. void DtlsTransport::Dump() const
// this->timer->Stop(); {
MS_TRACE();
std::string state{ "new" };
std::string role{ "none " };
switch (this->state)
{
case DtlsState::CONNECTING:
state = "connecting";
break;
case DtlsState::CONNECTED:
state = "connected";
break;
case DtlsState::FAILED:
state = "failed";
break;
case DtlsState::CLOSED:
state = "closed";
break;
default:;
}
switch (this->localRole)
{
case Role::AUTO:
role = "auto";
break;
case Role::SERVER:
role = "server";
break;
case Role::CLIENT:
role = "client";
break;
default:;
}
MS_DUMP("<DtlsTransport>");
MS_DUMP(" state : %s", state.c_str());
MS_DUMP(" role : %s", role.c_str());
MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no");
MS_DUMP("</DtlsTransport>");
}
void DtlsTransport::Run(Role localRole)
{
MS_TRACE();
// We need to reset the SSL instance so we need to "shutdown" it, but we MS_ASSERT(
// don't want to send a Close Alert to the peer, so just don't call localRole == Role::CLIENT || localRole == Role::SERVER,
// SendPendingOutgoingDTLSData(). "local DTLS role must be 'client' or 'server'");
SSL_shutdown(this->ssl);
this->localRole = Role::NONE; Role previousLocalRole = this->localRole;
this->state = DtlsState::NEW;
this->handshakeDone = false;
this->handshakeDoneNow = false;
// Reset SSL status. if (localRole == previousLocalRole)
// NOTE: For this to properly work, SSL_shutdown() must be called before. {
// NOTE: This may fail if not enough DTLS handshake data has been received, MS_ERROR("same local DTLS role provided, doing nothing");
// but we don't care so just clear the error queue.
ret = SSL_clear(this->ssl);
if (ret == 0) ERR_clear_error(); return;
} }
inline bool DtlsTransport::CheckStatus(int returnCode) { // If the previous local DTLS role was 'client' or 'server' do reset.
MS_TRACE(); if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER)
{
MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change");
int err; Reset();
bool wasHandshakeDone = this->handshakeDone; }
err = SSL_get_error(this->ssl, returnCode); // Update local role.
this->localRole = localRole;
switch (err) { // Set state and notify the listener.
case SSL_ERROR_NONE: this->state = DtlsState::CONNECTING;
break; this->listener->OnDtlsTransportConnecting(this);
case SSL_ERROR_SSL: switch (this->localRole)
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); {
break; case Role::CLIENT:
{
MS_DEBUG_TAG(dtls, "running [role:client]");
case SSL_ERROR_WANT_READ: SSL_set_connect_state(this->ssl);
break; SSL_do_handshake(this->ssl);
SendPendingOutgoingDtlsData();
SetTimeout();
case SSL_ERROR_WANT_WRITE: break;
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); }
break;
case SSL_ERROR_WANT_X509_LOOKUP: case Role::SERVER:
MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); {
break; MS_DEBUG_TAG(dtls, "running [role:server]");
case SSL_ERROR_SYSCALL: SSL_set_accept_state(this->ssl);
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); SSL_do_handshake(this->ssl);
break;
case SSL_ERROR_ZERO_RETURN: break;
break; }
case SSL_ERROR_WANT_CONNECT: default:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); {
break; MS_ABORT("invalid local DTLS role");
}
}
}
case SSL_ERROR_WANT_ACCEPT: bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint)
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); {
break; MS_TRACE();
default: MS_ASSERT(
MS_WARN_TAG(dtls, "SSL status: unknown error"); fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided");
}
// Check if the handshake (or re-handshake) has been done right now. this->remoteFingerprint = fingerprint;
if (this->handshakeDoneNow) {
this->handshakeDoneNow = false;
this->handshakeDone = true;
// Stop the timer. // The remote fingerpring may have been set after DTLS handshake was done,
// this->timer->Stop(); // so we may need to process it now.
if (this->handshakeDone && this->state != DtlsState::CONNECTED)
{
MS_DEBUG_TAG(dtls, "handshake already done, processing it right now");
// Process the handshake just once (ignore if DTLS renegotiation). return ProcessHandshake();
// if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) }
// return ProcessHandshake();
if (!wasHandshakeDone) {
return ProcessHandshake();
}
return true; return true;
} }
// Check if the peer sent close alert or a fatal error happened.
else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL ||
err == SSL_ERROR_SYSCALL) {
if (this->state == DtlsState::CONNECTED) {
MS_DEBUG_TAG(dtls, "disconnected");
Reset(); void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len)
{
MS_TRACE();
int written;
int read;
if (!IsRunning())
{
MS_ERROR("cannot process data while not running");
return;
}
// Write the received DTLS data into the sslBioFromNetwork.
written =
BIO_write(this->sslBioFromNetwork, static_cast<const void*>(data), static_cast<int>(len));
if (written != static_cast<int>(len))
{
MS_WARN_TAG(
dtls,
"OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)",
static_cast<size_t>(written),
len);
}
// Must call SSL_read() to process received DTLS data.
read = SSL_read(this->ssl, static_cast<void*>(DtlsTransport::sslReadBuffer), SslReadBufferSize);
// Send data if it's ready.
SendPendingOutgoingDtlsData();
// Check SSL status and return if it is bad/closed.
if (!CheckStatus(read))
return;
// Set/update the DTLS timeout.
if (!SetTimeout())
return;
// Application data received. Notify to the listener.
if (read > 0)
{
// It is allowed to receive DTLS data even before validating remote fingerprint.
if (!this->handshakeDone)
{
MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done");
return;
}
// Notify the listener.
this->listener->OnDtlsTransportApplicationDataReceived(
this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast<size_t>(read));
}
}
void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len)
{
MS_TRACE();
// Set state and notify the listener. // We cannot send data to the peer if its remote fingerprint is not validated.
this->state = DtlsState::CLOSED; if (this->state != DtlsState::CONNECTED)
this->listener->OnDtlsTransportClosed(this); {
} else { MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected");
MS_WARN_TAG(dtls, "connection failed");
Reset(); return;
}
// Set state and notify the listener. if (len == 0)
this->state = DtlsState::FAILED; {
this->listener->OnDtlsTransportFailed(this); MS_WARN_TAG(dtls, "ignoring 0 length data");
}
return false; return;
} else { }
return true;
}
}
inline void DtlsTransport::SendPendingOutgoingDtlsData() { int written;
MS_TRACE();
if (BIO_eof(this->sslBioToNetwork)) return; written = SSL_write(this->ssl, static_cast<const void*>(data), static_cast<int>(len));
int64_t read; if (written < 0)
char* data{nullptr}; {
LOG_OPENSSL_ERROR("SSL_write() failed");
read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT if (!CheckStatus(written))
return;
}
else if (written != static_cast<int>(len))
{
MS_WARN_TAG(
dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len);
}
if (read <= 0) return; // Send data.
SendPendingOutgoingDtlsData();
}
MS_DEBUG_DEV("%ld bytes of DTLS data ready to sent to the peer", read); void DtlsTransport::Reset()
{
MS_TRACE();
// Notify the listener. int ret;
this->listener->OnDtlsTransportSendData(this, reinterpret_cast<uint8_t*>(data),
static_cast<size_t>(read));
// Clear the BIO buffer. if (!IsRunning())
// NOTE: the (void) avoids the -Wunused-value warning. return;
(void)BIO_reset(this->sslBioToNetwork);
}
inline bool DtlsTransport::SetTimeout() { MS_WARN_TAG(dtls, "resetting DTLS transport");
MS_TRACE();
MS_ASSERT(this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, // Stop the DTLS timer.
"invalid DTLS state"); this->timer = nullptr;
int64_t ret; // We need to reset the SSL instance so we need to "shutdown" it, but we
uv_timeval_t dtlsTimeout{0, 0}; // don't want to send a Close Alert to the peer, so just don't call
uint64_t timeoutMs; // SendPendingOutgoingDTLSData().
SSL_shutdown(this->ssl);
// NOTE: If ret == 0 then ignore the value in dtlsTimeout. this->localRole = Role::NONE;
// NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. this->state = DtlsState::NEW;
ret = DTLSv1_get_timeout(this->ssl, static_cast<void*>(&dtlsTimeout)); // NOLINT this->handshakeDone = false;
this->handshakeDoneNow = false;
if (ret == 0) return true; // Reset SSL status.
// NOTE: For this to properly work, SSL_shutdown() must be called before.
// NOTE: This may fail if not enough DTLS handshake data has been received,
// but we don't care so just clear the error queue.
ret = SSL_clear(this->ssl);
timeoutMs = (dtlsTimeout.tv_sec * static_cast<uint64_t>(1000)) + (dtlsTimeout.tv_usec / 1000); if (ret == 0)
ERR_clear_error();
}
if (timeoutMs == 0) { inline bool DtlsTransport::CheckStatus(int returnCode)
return true; {
} else if (timeoutMs < 30000) { MS_TRACE();
MS_DEBUG_DEV("DTLS timer set in %lu ms", timeoutMs);
int err;
bool wasHandshakeDone = this->handshakeDone;
err = SSL_get_error(this->ssl, returnCode);
switch (err)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_SSL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL");
break;
case SSL_ERROR_WANT_READ:
break;
case SSL_ERROR_WANT_WRITE:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE");
break;
case SSL_ERROR_WANT_X509_LOOKUP:
MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP");
break;
case SSL_ERROR_SYSCALL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL");
break;
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_WANT_CONNECT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT");
break;
case SSL_ERROR_WANT_ACCEPT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT");
break;
default:
MS_WARN_TAG(dtls, "SSL status: unknown error");
}
// Check if the handshake (or re-handshake) has been done right now.
if (this->handshakeDoneNow)
{
this->handshakeDoneNow = false;
this->handshakeDone = true;
// Stop the timer.
this->timer = nullptr;
// Process the handshake just once (ignore if DTLS renegotiation).
if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE)
return ProcessHandshake();
return true;
}
// Check if the peer sent close alert or a fatal error happened.
else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL)
{
if (this->state == DtlsState::CONNECTED)
{
MS_DEBUG_TAG(dtls, "disconnected");
Reset();
// Set state and notify the listener.
this->state = DtlsState::CLOSED;
this->listener->OnDtlsTransportClosed(this);
}
else
{
MS_WARN_TAG(dtls, "connection failed");
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
}
return false;
}
else
{
return true;
}
}
inline void DtlsTransport::SendPendingOutgoingDtlsData()
{
MS_TRACE();
// this->timer->Start(timeoutMs); if (BIO_eof(this->sslBioToNetwork))
return;
return true; int64_t read;
} char* data{ nullptr };
// NOTE: Don't start the timer again if the timeout is greater than 30 seconds.
else {
MS_WARN_TAG(dtls, "DTLS timeout too high (%lu ms), resetting DLTS", timeoutMs);
Reset(); read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT
// Set state and notify the listener. if (read <= 0)
this->state = DtlsState::FAILED; return;
this->listener->OnDtlsTransportFailed(this);
return false; MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read);
}
}
inline bool DtlsTransport::ProcessHandshake() { // Notify the listener.
MS_TRACE(); this->listener->OnDtlsTransportSendData(
this, reinterpret_cast<uint8_t*>(data), static_cast<size_t>(read));
MS_ASSERT(this->handshakeDone, "handshake not done yet"); // Clear the BIO buffer.
// MS_ASSERT(this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, // NOTE: the (void) avoids the -Wunused-value warning.
// "remote fingerprint not set"); (void)BIO_reset(this->sslBioToNetwork);
}
// Validate the remote fingerprint. inline bool DtlsTransport::SetTimeout()
// if (!CheckRemoteFingerprint()) { {
// Reset(); MS_TRACE();
MS_ASSERT(
this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED,
"invalid DTLS state");
int64_t ret;
struct timeval dtlsTimeout{ 0, 0 };
uint64_t timeoutMs;
// NOTE: If ret == 0 then ignore the value in dtlsTimeout.
// NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev.
ret = DTLSv1_get_timeout(this->ssl, static_cast<void*>(&dtlsTimeout)); // NOLINT
if (ret == 0)
return true;
timeoutMs = (dtlsTimeout.tv_sec * static_cast<uint64_t>(1000)) + (dtlsTimeout.tv_usec / 1000);
if (timeoutMs == 0)
{
return true;
}
else if (timeoutMs < 30000)
{
MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs);
weak_ptr<DtlsTransport> weak_self = shared_from_this();
this->timer = std::make_shared<Timer>(timeoutMs / 1000.0f, [weak_self](){
auto strong_self = weak_self.lock();
if(strong_self){
strong_self->OnTimer();
}
return true;
}, this->poller);
return true;
}
// NOTE: Don't start the timer again if the timeout is greater than 30 seconds.
else
{
MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs);
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
return false;
}
}
inline bool DtlsTransport::ProcessHandshake()
{
MS_TRACE();
// // Set state and notify the listener. MS_ASSERT(this->handshakeDone, "handshake not done yet");
// this->state = DtlsState::FAILED; MS_ASSERT(
// this->listener->OnDtlsTransportFailed(this); this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
// return false; // Validate the remote fingerprint.
// } if (!CheckRemoteFingerprint())
{
Reset();
// Get the negotiated SRTP crypto suite. // Set state and notify the listener.
RTC::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
if (srtpCryptoSuite != RTC::CryptoSuite::NONE) { return false;
// Extract the SRTP keys (will notify the listener with them). }
ExtractSrtpKeys(srtpCryptoSuite);
return true; // Get the negotiated SRTP crypto suite.
} RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite();
// NOTE: We assume that "use_srtp" DTLS extension is required even if if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE)
// there is no audio/video. {
MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); // Extract the SRTP keys (will notify the listener with them).
ExtractSrtpKeys(srtpCryptoSuite);
Reset(); return true;
}
// Set state and notify the listener. // NOTE: We assume that "use_srtp" DTLS extension is required even if
this->state = DtlsState::FAILED; // there is no audio/video.
this->listener->OnDtlsTransportFailed(this); MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated");
return false; Reset();
}
inline bool DtlsTransport::CheckRemoteFingerprint() { // Set state and notify the listener.
MS_TRACE(); this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
MS_ASSERT(this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, return false;
"remote fingerprint not set"); }
X509* certificate; inline bool DtlsTransport::CheckRemoteFingerprint()
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; {
unsigned int size{0}; MS_TRACE();
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
certificate = SSL_get_peer_certificate(this->ssl); MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
if (!certificate) { X509* certificate;
MS_WARN_TAG(dtls, "no certificate was provided by the peer"); uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
return false; certificate = SSL_get_peer_certificate(this->ssl);
}
switch (this->remoteFingerprint.algorithm) { if (!certificate)
case FingerprintAlgorithm::SHA1: {
hashFunction = EVP_sha1(); MS_WARN_TAG(dtls, "no certificate was provided by the peer");
break;
case FingerprintAlgorithm::SHA224: return false;
hashFunction = EVP_sha224(); }
break;
case FingerprintAlgorithm::SHA256: switch (this->remoteFingerprint.algorithm)
hashFunction = EVP_sha256(); {
break; case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
case FingerprintAlgorithm::SHA384: case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha384(); hashFunction = EVP_sha224();
break; break;
case FingerprintAlgorithm::SHA512: case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha512(); hashFunction = EVP_sha256();
break; break;
default: case FingerprintAlgorithm::SHA384:
MS_ABORT("unknown algorithm"); hashFunction = EVP_sha384();
} break;
// Compare the remote fingerprint with the value given via signaling. case FingerprintAlgorithm::SHA512:
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); hashFunction = EVP_sha512();
break;
if (ret == 0) { default:
MS_ERROR("X509_digest() failed"); MS_ABORT("unknown algorithm");
}
X509_free(certificate); // Compare the remote fingerprint with the value given via signaling.
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
return false; if (ret == 0)
} {
MS_ERROR("X509_digest() failed");
// Convert to hexadecimal format in uppercase with colons. X509_free(certificate);
for (unsigned int i{0}; i < size; ++i) {
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
if (this->remoteFingerprint.value != hexFingerprint) { return false;
MS_WARN_TAG(dtls, }
"fingerprint in the remote certificate (%s) does not match the announced one (%s)",
hexFingerprint, this->remoteFingerprint.value.c_str());
X509_free(certificate); // Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i)
{
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
return false; if (this->remoteFingerprint.value != hexFingerprint)
} {
MS_WARN_TAG(
dtls,
"fingerprint in the remote certificate (%s) does not match the announced one (%s)",
hexFingerprint,
this->remoteFingerprint.value.c_str());
MS_DEBUG_TAG(dtls, "valid remote fingerprint"); //todo 先屏蔽检查客户端签名
#if 0
X509_free(certificate);
// Get the remote certificate in PEM format. return false;
#endif
}
BIO* bio = BIO_new(BIO_s_mem()); MS_DEBUG_TAG(dtls, "valid remote fingerprint");
// Ensure the underlying BUF_MEM structure is also freed. // Get the remote certificate in PEM format.
// NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since
// BIO_set_close() always returns 1.
(void)BIO_set_close(bio, BIO_CLOSE);
ret = PEM_write_bio_X509(bio, certificate); BIO* bio = BIO_new(BIO_s_mem());
if (ret != 1) { // Ensure the underlying BUF_MEM structure is also freed.
LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since
// BIO_set_close() always returns 1.
(void)BIO_set_close(bio, BIO_CLOSE);
X509_free(certificate); ret = PEM_write_bio_X509(bio, certificate);
BIO_free(bio);
return false; if (ret != 1)
} {
LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed");
BUF_MEM* mem; X509_free(certificate);
BIO_free(bio);
BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] return false;
}
if (!mem || !mem->data || mem->length == 0u) { BUF_MEM* mem;
LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed");
X509_free(certificate); BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast]
BIO_free(bio);
return false; if (!mem || !mem->data || mem->length == 0u)
} {
LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed");
this->remoteCert = std::string(mem->data, mem->length); X509_free(certificate);
BIO_free(bio);
X509_free(certificate); return false;
BIO_free(bio); }
return true; this->remoteCert = std::string(mem->data, mem->length);
}
inline void DtlsTransport::ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite) { X509_free(certificate);
MS_TRACE(); BIO_free(bio);
size_t srtpKeyLength{0};
size_t srtpSaltLength{0};
size_t srtpMasterLength{0};
switch (srtpCryptoSuite) {
case RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_80:
case RTC::CryptoSuite::AES_CM_128_HMAC_SHA1_32: {
srtpKeyLength = SrtpMasterKeyLength;
srtpSaltLength = SrtpMasterSaltLength;
srtpMasterLength = SrtpMasterLength;
break;
}
case RTC::CryptoSuite::AEAD_AES_256_GCM: {
srtpKeyLength = SrtpAesGcm256MasterKeyLength;
srtpSaltLength = SrtpAesGcm256MasterSaltLength;
srtpMasterLength = SrtpAesGcm256MasterLength;
break;
}
case RTC::CryptoSuite::AEAD_AES_128_GCM: {
srtpKeyLength = SrtpAesGcm128MasterKeyLength;
srtpSaltLength = SrtpAesGcm128MasterSaltLength;
srtpMasterLength = SrtpAesGcm128MasterLength;
break;
}
default: {
MS_ABORT("unknown SRTP crypto suite");
}
}
auto* srtpMaterial = new uint8_t[srtpMasterLength * 2];
uint8_t* srtpLocalKey{nullptr};
uint8_t* srtpLocalSalt{nullptr};
uint8_t* srtpRemoteKey{nullptr};
uint8_t* srtpRemoteSalt{nullptr};
auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength];
auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength];
int ret;
ret = SSL_export_keying_material(this->ssl, srtpMaterial, srtpMasterLength * 2,
"EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0);
MS_ASSERT(ret != 0, "SSL_export_keying_material() failed");
switch (this->localRole) {
case Role::SERVER: {
srtpRemoteKey = srtpMaterial;
srtpLocalKey = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteSalt + srtpSaltLength;
break;
}
case Role::CLIENT: {
srtpLocalKey = srtpMaterial;
srtpRemoteKey = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalSalt + srtpSaltLength;
break;
}
default: {
MS_ABORT("no DTLS role set");
}
}
// Create the SRTP local master key.
std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength);
std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength);
// Create the SRTP remote master key.
std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength);
std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength);
// Set state and notify the listener.
this->state = DtlsState::CONNECTED;
this->listener->OnDtlsTransportConnected(this, srtpCryptoSuite, srtpLocalMasterKey,
srtpMasterLength, srtpRemoteMasterKey, srtpMasterLength,
this->remoteCert);
delete[] srtpMaterial;
delete[] srtpLocalMasterKey;
delete[] srtpRemoteMasterKey;
}
inline RTC::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() { return true;
MS_TRACE(); }
RTC::CryptoSuite negotiatedSrtpCryptoSuite = RTC::CryptoSuite::NONE; inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite)
{
MS_TRACE();
size_t srtpKeyLength{ 0 };
size_t srtpSaltLength{ 0 };
size_t srtpMasterLength{ 0 };
switch (srtpCryptoSuite)
{
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80:
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32:
{
srtpKeyLength = SrtpMasterKeyLength;
srtpSaltLength = SrtpMasterSaltLength;
srtpMasterLength = SrtpMasterLength;
break;
}
case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM:
{
srtpKeyLength = SrtpAesGcm256MasterKeyLength;
srtpSaltLength = SrtpAesGcm256MasterSaltLength;
srtpMasterLength = SrtpAesGcm256MasterLength;
break;
}
case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM:
{
srtpKeyLength = SrtpAesGcm128MasterKeyLength;
srtpSaltLength = SrtpAesGcm128MasterSaltLength;
srtpMasterLength = SrtpAesGcm128MasterLength;
break;
}
default:
{
MS_ABORT("unknown SRTP crypto suite");
}
}
auto* srtpMaterial = new uint8_t[srtpMasterLength * 2];
uint8_t* srtpLocalKey{ nullptr };
uint8_t* srtpLocalSalt{ nullptr };
uint8_t* srtpRemoteKey{ nullptr };
uint8_t* srtpRemoteSalt{ nullptr };
auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength];
auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength];
int ret;
ret = SSL_export_keying_material(
this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0);
MS_ASSERT(ret != 0, "SSL_export_keying_material() failed");
switch (this->localRole)
{
case Role::SERVER:
{
srtpRemoteKey = srtpMaterial;
srtpLocalKey = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteSalt + srtpSaltLength;
break;
}
case Role::CLIENT:
{
srtpLocalKey = srtpMaterial;
srtpRemoteKey = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalSalt + srtpSaltLength;
break;
}
default:
{
MS_ABORT("no DTLS role set");
}
}
// Create the SRTP local master key.
std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength);
std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength);
// Create the SRTP remote master key.
std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength);
std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength);
// Set state and notify the listener.
this->state = DtlsState::CONNECTED;
this->listener->OnDtlsTransportConnected(
this,
srtpCryptoSuite,
srtpLocalMasterKey,
srtpMasterLength,
srtpRemoteMasterKey,
srtpMasterLength,
this->remoteCert);
delete[] srtpMaterial;
delete[] srtpLocalMasterKey;
delete[] srtpRemoteMasterKey;
}
inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite()
{
MS_TRACE();
// Ensure that the SRTP crypto suite has been negotiated. RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE;
// NOTE: This is a OpenSSL type.
SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl);
if (!sslSrtpCryptoSuite) return negotiatedSrtpCryptoSuite; // Ensure that the SRTP crypto suite has been negotiated.
// NOTE: This is a OpenSSL type.
SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl);
// Get the negotiated SRTP crypto suite. if (!sslSrtpCryptoSuite)
for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) { return negotiatedSrtpCryptoSuite;
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite);
if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) { // Get the negotiated SRTP crypto suite.
MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites)
{
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite);
negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0)
} {
} MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name);
MS_ASSERT(negotiatedSrtpCryptoSuite != RTC::CryptoSuite::NONE, negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite;
"chosen SRTP crypto suite is not an available one"); }
}
return negotiatedSrtpCryptoSuite; MS_ASSERT(
} negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE,
"chosen SRTP crypto suite is not an available one");
inline void DtlsTransport::OnSslInfo(int where, int ret) { return negotiatedSrtpCryptoSuite;
MS_TRACE(); }
int w = where & -SSL_ST_MASK;
const char* role;
if ((w & SSL_ST_CONNECT) != 0)
role = "client";
else if ((w & SSL_ST_ACCEPT) != 0)
role = "server";
else
role = "undefined";
if ((where & SSL_CB_LOOP) != 0) {
MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl));
} else if ((where & SSL_CB_ALERT) != 0) {
const char* alertType;
switch (*SSL_alert_type_string(ret)) {
case 'W':
alertType = "warning";
break;
case 'F':
alertType = "fatal";
break;
default:
alertType = "undefined";
}
if ((where & SSL_CB_READ) != 0) {
MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
} else if ((where & SSL_CB_WRITE) != 0) {
MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
} else {
MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
} else if ((where & SSL_CB_EXIT) != 0) {
if (ret == 0)
MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl));
else if (ret < 0)
MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl));
} else if ((where & SSL_CB_HANDSHAKE_START) != 0) {
MS_DEBUG_TAG(dtls, "DTLS handshake start");
} else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) {
MS_DEBUG_TAG(dtls, "DTLS handshake done");
this->handshakeDoneNow = true;
}
// NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon
// receipt of a close alert does not work (the flag is set after this callback).
}
inline void DtlsTransport::OnTimer() { inline void DtlsTransport::OnSslInfo(int where, int ret)
MS_TRACE(); {
MS_TRACE();
int w = where & -SSL_ST_MASK;
const char* role;
if ((w & SSL_ST_CONNECT) != 0)
role = "client";
else if ((w & SSL_ST_ACCEPT) != 0)
role = "server";
else
role = "undefined";
if ((where & SSL_CB_LOOP) != 0)
{
MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl));
}
else if ((where & SSL_CB_ALERT) != 0)
{
const char* alertType;
switch (*SSL_alert_type_string(ret))
{
case 'W':
alertType = "warning";
break;
case 'F':
alertType = "fatal";
break;
default:
alertType = "undefined";
}
if ((where & SSL_CB_READ) != 0)
{
MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
else if ((where & SSL_CB_WRITE) != 0)
{
MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
else
{
MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
}
else if ((where & SSL_CB_EXIT) != 0)
{
if (ret == 0)
MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl));
else if (ret < 0)
MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl));
}
else if ((where & SSL_CB_HANDSHAKE_START) != 0)
{
MS_DEBUG_TAG(dtls, "DTLS handshake start");
}
else if ((where & SSL_CB_HANDSHAKE_DONE) != 0)
{
MS_DEBUG_TAG(dtls, "DTLS handshake done");
this->handshakeDoneNow = true;
}
// NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon
// receipt of a close alert does not work (the flag is set after this callback).
}
inline void DtlsTransport::OnTimer()
{
MS_TRACE();
// Workaround for https://github.com/openssl/openssl/issues/7998. // Workaround for https://github.com/openssl/openssl/issues/7998.
if (this->handshakeDone) { if (this->handshakeDone)
MS_DEBUG_DEV("handshake is done so return"); {
MS_DEBUG_DEV("handshake is done so return");
return; return;
} }
DTLSv1_handle_timeout(this->ssl); DTLSv1_handle_timeout(this->ssl);
// If required, send DTLS data. // If required, send DTLS data.
SendPendingOutgoingDtlsData(); SendPendingOutgoingDtlsData();
// Set the DTLS timer again. // Set the DTLS timer again.
SetTimeout(); SetTimeout();
} }
} // namespace RTC } // namespace RTC
#ifndef MS_RTC_DTLS_TRANSPORT_HPP #ifndef MS_RTC_DTLS_TRANSPORT_HPP
#define MS_RTC_DTLS_TRANSPORT_HPP #define MS_RTC_DTLS_TRANSPORT_HPP
#include "srtp_session.h"
#include <openssl/bio.h> #include <openssl/bio.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/x509.h> #include <openssl/x509.h>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "Poller/Timer.h"
namespace RTC { #include "Poller/EventPoller.h"
enum class CryptoSuite { using namespace toolkit;
NONE = 0,
AES_CM_128_HMAC_SHA1_80 = 1, namespace RTC
AES_CM_128_HMAC_SHA1_32, {
AEAD_AES_256_GCM, class DtlsTransport : public std::enable_shared_from_this<DtlsTransport>
AEAD_AES_128_GCM {
}; public:
class DtlsTransport { enum class DtlsState
public: {
enum class DtlsState { NEW = 1, CONNECTING, CONNECTED, FAILED, CLOSED }; NEW = 1,
CONNECTING,
public: CONNECTED,
enum class Role { NONE = 0, AUTO = 1, CLIENT, SERVER }; FAILED,
CLOSED
public: };
enum class FingerprintAlgorithm { NONE = 0, SHA1 = 1, SHA224, SHA256, SHA384, SHA512 };
public:
public: enum class Role
struct Fingerprint { {
FingerprintAlgorithm algorithm{FingerprintAlgorithm::NONE}; NONE = 0,
std::string value; AUTO = 1,
}; CLIENT,
SERVER
private: };
struct SrtpCryptoSuiteMapEntry {
RTC::CryptoSuite cryptoSuite; public:
const char* name; enum class FingerprintAlgorithm
}; {
NONE = 0,
public: SHA1 = 1,
class Listener { SHA224,
public: SHA256,
// DTLS is in the process of negotiating a secure connection. Incoming SHA384,
// media can flow through. SHA512
// NOTE: The caller MUST NOT call any method during this callback. };
virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0;
// DTLS has completed negotiation of a secure connection (including DTLS-SRTP public:
// and remote fingerprint verification). Outgoing media can now flow through. struct Fingerprint
// NOTE: The caller MUST NOT call any method during this callback. {
virtual void OnDtlsTransportConnected(const RTC::DtlsTransport* dtlsTransport, FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE };
RTC::CryptoSuite srtpCryptoSuite, uint8_t* srtpLocalKey, std::string value;
size_t srtpLocalKeyLen, uint8_t* srtpRemoteKey, };
size_t srtpRemoteKeyLen, std::string& remoteCert) = 0;
// The DTLS connection has been closed as the result of an error (such as a private:
// DTLS alert or a failure to validate the remote fingerprint). struct SrtpCryptoSuiteMapEntry
virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; {
// The DTLS connection has been closed due to receipt of a close_notify alert. RTC::SrtpSession::CryptoSuite cryptoSuite;
virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; const char* name;
// Need to send DTLS data to the peer. };
virtual void OnDtlsTransportSendData(const RTC::DtlsTransport* dtlsTransport,
const uint8_t* data, size_t len) = 0; public:
// DTLS application data received. class Listener
virtual void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport* dtlsTransport, {
const uint8_t* data, size_t len) = 0; public:
}; // DTLS is in the process of negotiating a secure connection. Incoming
// media can flow through.
public: // NOTE: The caller MUST NOT call any method during this callback.
static void ClassInit(); virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0;
static void ClassDestroy(); // DTLS has completed negotiation of a secure connection (including DTLS-SRTP
static Role StringToRole(const std::string& role) { // and remote fingerprint verification). Outgoing media can now flow through.
auto it = DtlsTransport::string2Role.find(role); // NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnected(
if (it != DtlsTransport::string2Role.end()) const RTC::DtlsTransport* dtlsTransport,
return it->second; RTC::SrtpSession::CryptoSuite srtpCryptoSuite,
else uint8_t* srtpLocalKey,
return DtlsTransport::Role::NONE; size_t srtpLocalKeyLen,
} uint8_t* srtpRemoteKey,
static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) { size_t srtpRemoteKeyLen,
auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); std::string& remoteCert) = 0;
// The DTLS connection has been closed as the result of an error (such as a
if (it != DtlsTransport::string2FingerprintAlgorithm.end()) // DTLS alert or a failure to validate the remote fingerprint).
return it->second; virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0;
else // The DTLS connection has been closed due to receipt of a close_notify alert.
return DtlsTransport::FingerprintAlgorithm::NONE; virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0;
} // Need to send DTLS data to the peer.
static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) { virtual void OnDtlsTransportSendData(
auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
// DTLS application data received.
return it->second; virtual void OnDtlsTransportApplicationDataReceived(
} const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
static bool IsDtls(const uint8_t* data, size_t len) { };
// clang-format off
public:
static void ClassInit();
static void ClassDestroy();
static Role StringToRole(const std::string& role)
{
auto it = DtlsTransport::string2Role.find(role);
if (it != DtlsTransport::string2Role.end())
return it->second;
else
return DtlsTransport::Role::NONE;
}
static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint)
{
auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint);
if (it != DtlsTransport::string2FingerprintAlgorithm.end())
return it->second;
else
return DtlsTransport::FingerprintAlgorithm::NONE;
}
static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint)
{
auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint);
return it->second;
}
static bool IsDtls(const uint8_t* data, size_t len)
{
// clang-format off
return ( return (
// Minimum DTLS record length is 13 bytes. // Minimum DTLS record length is 13 bytes.
(len >= 13) && (len >= 13) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] > 19 && data[0] < 64) (data[0] > 19 && data[0] < 64)
); );
// clang-format on // clang-format on
} }
private: private:
static void GenerateCertificateAndPrivateKey(); static void GenerateCertificateAndPrivateKey();
static void ReadCertificateAndPrivateKeyFromFiles(); static void ReadCertificateAndPrivateKeyFromFiles();
static void CreateSslCtx(); static void CreateSslCtx();
static void GenerateFingerprints(); static void GenerateFingerprints();
private: private:
static X509* certificate; static X509* certificate;
static EVP_PKEY* privateKey; static EVP_PKEY* privateKey;
static SSL_CTX* sslCtx; static SSL_CTX* sslCtx;
static uint8_t sslReadBuffer[]; static uint8_t sslReadBuffer[];
static std::map<std::string, Role> string2Role; static std::map<std::string, Role> string2Role;
static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm; static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm;
static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String; static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String;
static std::vector<Fingerprint> localFingerprints; static std::vector<Fingerprint> localFingerprints;
static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites; static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites;
public: public:
explicit DtlsTransport(Listener* listener); DtlsTransport(EventPoller::Ptr poller, Listener* listener);
~DtlsTransport(); ~DtlsTransport();
public: public:
void Dump() const; void Dump() const;
void Run(Role localRole); void Run(Role localRole);
std::vector<Fingerprint>& GetLocalFingerprints() const { std::vector<Fingerprint>& GetLocalFingerprints() const
return DtlsTransport::localFingerprints; {
} return DtlsTransport::localFingerprints;
bool SetRemoteFingerprint(Fingerprint fingerprint); }
void ProcessDtlsData(const uint8_t* data, size_t len); bool SetRemoteFingerprint(Fingerprint fingerprint);
DtlsState GetState() const { return this->state; } void ProcessDtlsData(const uint8_t* data, size_t len);
Role GetLocalRole() const { return this->localRole; } DtlsState GetState() const
void SendApplicationData(const uint8_t* data, size_t len); {
return this->state;
private: }
bool IsRunning() const { Role GetLocalRole() const
switch (this->state) { {
case DtlsState::NEW: return this->localRole;
return false; }
case DtlsState::CONNECTING: void SendApplicationData(const uint8_t* data, size_t len);
case DtlsState::CONNECTED:
return true; private:
case DtlsState::FAILED: bool IsRunning() const
case DtlsState::CLOSED: {
return false; switch (this->state)
} {
case DtlsState::NEW:
// Make GCC 4.9 happy. return false;
return false; case DtlsState::CONNECTING:
} case DtlsState::CONNECTED:
void Reset(); return true;
bool CheckStatus(int returnCode); case DtlsState::FAILED:
void SendPendingOutgoingDtlsData(); case DtlsState::CLOSED:
bool SetTimeout(); return false;
bool ProcessHandshake(); }
bool CheckRemoteFingerprint();
void ExtractSrtpKeys(RTC::CryptoSuite srtpCryptoSuite); // Make GCC 4.9 happy.
RTC::CryptoSuite GetNegotiatedSrtpCryptoSuite(); return false;
}
/* Callbacks fired by OpenSSL events. */ void Reset();
public: bool CheckStatus(int returnCode);
void OnSslInfo(int where, int ret); void SendPendingOutgoingDtlsData();
bool SetTimeout();
/* Pure virtual methods inherited from Timer::Listener. */ bool ProcessHandshake();
public: bool CheckRemoteFingerprint();
void OnTimer(); void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite);
RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite();
private:
// Passed by argument. private:
Listener* listener{nullptr}; void OnSslInfo(int where, int ret);
// Allocated by this. void OnTimer();
SSL* ssl{nullptr};
BIO* sslBioFromNetwork{nullptr}; // The BIO from which ssl reads. private:
BIO* sslBioToNetwork{nullptr}; // The BIO in which ssl writes. EventPoller::Ptr poller;
// Others. // Passed by argument.
DtlsState state{DtlsState::NEW}; Listener* listener{ nullptr };
Role localRole{Role::NONE}; // Allocated by this.
Fingerprint remoteFingerprint; SSL* ssl{ nullptr };
bool handshakeDone{false}; BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads.
bool handshakeDoneNow{false}; BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes.
std::string remoteCert; Timer::Ptr timer;
}; // Others.
} // namespace RTC DtlsState state{ DtlsState::NEW };
Role localRole{ Role::NONE };
Fingerprint remoteFingerprint;
bool handshakeDone{ false };
bool handshakeDoneNow{ false };
std::string remoteCert;
};
} // namespace RTC
#endif #endif
...@@ -2,268 +2,287 @@ ...@@ -2,268 +2,287 @@
// #define MS_LOG_DEV_LEVEL 3 // #define MS_LOG_DEV_LEVEL 3
#include "srtp_session.h" #include "srtp_session.h"
#include <cstring> // std::memset(), std::memcpy()
#include <cstring> // std::memset(), std::memcpy()
#include <iostream>
#include "logger.h" #include "logger.h"
namespace RTC { namespace RTC
/* Static. */ {
/* Static. */
static constexpr size_t EncryptBufferSize{65536};
static uint8_t EncryptBuffer[EncryptBufferSize]; static constexpr size_t EncryptBufferSize{ 65536 };
static uint8_t EncryptBuffer[EncryptBufferSize];
/* Class methods. */
std::vector<const char *> DepLibSRTP::errors = {
std::vector<const char *> DepLibSRTP::errors = { // From 0 (srtp_err_status_ok) to 24 (srtp_err_status_pfkey_err).
// From 0 (srtp_err_status_ok) to 24 (srtp_err_status_pfkey_err). "success (srtp_err_status_ok)",
"success (srtp_err_status_ok)", "unspecified failure (srtp_err_status_fail)",
"unspecified failure (srtp_err_status_fail)", "unsupported parameter (srtp_err_status_bad_param)",
"unsupported parameter (srtp_err_status_bad_param)", "couldn't allocate memory (srtp_err_status_alloc_fail)",
"couldn't allocate memory (srtp_err_status_alloc_fail)", "couldn't deallocate memory (srtp_err_status_dealloc_fail)",
"couldn't deallocate memory (srtp_err_status_dealloc_fail)", "couldn't initialize (srtp_err_status_init_fail)",
"couldn't initialize (srtp_err_status_init_fail)", "can’t process as much data as requested (srtp_err_status_terminus)",
"can’t process as much data as requested (srtp_err_status_terminus)", "authentication failure (srtp_err_status_auth_fail)",
"authentication failure (srtp_err_status_auth_fail)", "cipher failure (srtp_err_status_cipher_fail)",
"cipher failure (srtp_err_status_cipher_fail)", "replay check failed (bad index) (srtp_err_status_replay_fail)",
"replay check failed (bad index) (srtp_err_status_replay_fail)", "replay check failed (index too old) (srtp_err_status_replay_old)",
"replay check failed (index too old) (srtp_err_status_replay_old)", "algorithm failed test routine (srtp_err_status_algo_fail)",
"algorithm failed test routine (srtp_err_status_algo_fail)", "unsupported operation (srtp_err_status_no_such_op)",
"unsupported operation (srtp_err_status_no_such_op)", "no appropriate context found (srtp_err_status_no_ctx)",
"no appropriate context found (srtp_err_status_no_ctx)", "unable to perform desired validation (srtp_err_status_cant_check)",
"unable to perform desired validation (srtp_err_status_cant_check)", "can’t use key any more (srtp_err_status_key_expired)",
"can’t use key any more (srtp_err_status_key_expired)", "error in use of socket (srtp_err_status_socket_err)",
"error in use of socket (srtp_err_status_socket_err)", "error in use POSIX signals (srtp_err_status_signal_err)",
"error in use POSIX signals (srtp_err_status_signal_err)", "nonce check failed (srtp_err_status_nonce_bad)",
"nonce check failed (srtp_err_status_nonce_bad)", "couldn’t read data (srtp_err_status_read_fail)",
"couldn’t read data (srtp_err_status_read_fail)", "couldn’t write data (srtp_err_status_write_fail)",
"couldn’t write data (srtp_err_status_write_fail)", "error parsing data (srtp_err_status_parse_err)",
"error parsing data (srtp_err_status_parse_err)", "error encoding data (srtp_err_status_encode_err)",
"error encoding data (srtp_err_status_encode_err)", "error while using semaphores (srtp_err_status_semaphore_err)",
"error while using semaphores (srtp_err_status_semaphore_err)", "error while using pfkey (srtp_err_status_pfkey_err)"};
"error while using pfkey (srtp_err_status_pfkey_err)"};
// clang-format on // clang-format on
/* Static methods. */ /* Static methods. */
void DepLibSRTP::ClassInit() { void DepLibSRTP::ClassInit() {
MS_TRACE(); MS_TRACE();
MS_DEBUG_TAG(info, "libsrtp version: \"%s\"", srtp_get_version_string());
srtp_err_status_t err = srtp_init();
if (DepLibSRTP::IsError(err)) MS_DEBUG_TAG(info, "libsrtp version: \"%s\"", srtp_get_version_string());
MS_THROW_ERROR("srtp_init() failed: %s", DepLibSRTP::GetErrorString(err));
}
void DepLibSRTP::ClassDestroy() { srtp_err_status_t err = srtp_init();
MS_TRACE();
srtp_shutdown(); if (DepLibSRTP::IsError(err))
} MS_THROW_ERROR("srtp_init() failed: %s", DepLibSRTP::GetErrorString(err));
void SrtpSession::ClassInit() {
// Set libsrtp event handler.
srtp_err_status_t err =
srtp_install_event_handler(static_cast<srtp_event_handler_func_t *>(OnSrtpEvent));
if (DepLibSRTP::IsError(err)) {
MS_THROW_ERROR("srtp_install_event_handler() failed: %s", DepLibSRTP::GetErrorString(err));
std::cout << "srtp_install_event_handler() failed :" << DepLibSRTP::GetErrorString(err);
} }
}
void SrtpSession::OnSrtpEvent(srtp_event_data_t *data) {
MS_TRACE();
switch (data->event) { void DepLibSRTP::ClassDestroy() {
case event_ssrc_collision: MS_TRACE();
MS_WARN_TAG(srtp, "SSRC collision occurred");
break;
case event_key_soft_limit: srtp_shutdown();
MS_WARN_TAG(srtp, "stream reached the soft key usage limit and will expire soon");
break;
case event_key_hard_limit:
MS_WARN_TAG(srtp, "stream reached the hard key usage limit and has expired");
break;
case event_packet_index_limit:
MS_WARN_TAG(srtp, "stream reached the hard packet limit (2^48 packets)");
break;
} }
}
/* Instance methods. */ /* Class methods. */
void SrtpSession::ClassInit()
{
// Set libsrtp event handler.
srtp_err_status_t err =
srtp_install_event_handler(static_cast<srtp_event_handler_func_t*>(OnSrtpEvent));
if (DepLibSRTP::IsError(err))
{
MS_THROW_ERROR("srtp_install_event_handler() failed: %s", DepLibSRTP::GetErrorString(err));
}
}
void SrtpSession::OnSrtpEvent(srtp_event_data_t* data)
{
MS_TRACE();
switch (data->event)
{
case event_ssrc_collision:
MS_WARN_TAG(srtp, "SSRC collision occurred");
break;
case event_key_soft_limit:
MS_WARN_TAG(srtp, "stream reached the soft key usage limit and will expire soon");
break;
case event_key_hard_limit:
MS_WARN_TAG(srtp, "stream reached the hard key usage limit and has expired");
break;
case event_packet_index_limit:
MS_WARN_TAG(srtp, "stream reached the hard packet limit (2^48 packets)");
break;
}
}
/* Instance methods. */
SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t* key, size_t keyLen)
{
MS_TRACE();
srtp_policy_t policy; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Set all policy fields to 0.
std::memset(&policy, 0, sizeof(srtp_policy_t));
switch (cryptoSuite)
{
case CryptoSuite::AES_CM_128_HMAC_SHA1_80:
{
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp);
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp);
break;
}
case CryptoSuite::AES_CM_128_HMAC_SHA1_32:
{
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp);
// NOTE: Must be 80 for RTCP.
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp);
break;
}
case CryptoSuite::AEAD_AES_256_GCM:
{
srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp);
srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp);
SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen) { break;
MS_TRACE(); }
srtp_policy_t policy;// NOLINT(cppcoreguidelines-pro-type-member-init) case CryptoSuite::AEAD_AES_128_GCM:
{
srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp);
srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp);
// Set all policy fields to 0. break;
std::memset(&policy, 0, sizeof(srtp_policy_t)); }
switch (cryptoSuite) { default:
case CryptoSuite::AES_CM_128_HMAC_SHA1_80: { {
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtp); MS_ABORT("unknown SRTP crypto suite");
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); }
}
break; MS_ASSERT(
} (int)keyLen == policy.rtp.cipher_key_len,
"given keyLen does not match policy.rtp.cipher_keyLen");
case CryptoSuite::AES_CM_128_HMAC_SHA1_32: { switch (type)
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_32(&policy.rtp); {
// NOTE: Must be 80 for RTCP. case Type::INBOUND:
srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&policy.rtcp); policy.ssrc.type = ssrc_any_inbound;
break;
break; case Type::OUTBOUND:
} policy.ssrc.type = ssrc_any_outbound;
break;
}
case CryptoSuite::AEAD_AES_256_GCM: { policy.ssrc.value = 0;
srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtp); policy.key = key;
srtp_crypto_policy_set_aes_gcm_256_16_auth(&policy.rtcp); // Required for sending RTP retransmission without RTX.
policy.allow_repeat_tx = 1;
policy.window_size = 1024;
policy.next = nullptr;
break; // Set the SRTP session.
} srtp_err_status_t err = srtp_create(&this->session, &policy);
case CryptoSuite::AEAD_AES_128_GCM: { if (DepLibSRTP::IsError(err))
srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtp); MS_THROW_ERROR("srtp_create() failed: %s", DepLibSRTP::GetErrorString(err));
srtp_crypto_policy_set_aes_gcm_128_16_auth(&policy.rtcp); }
break; SrtpSession::~SrtpSession()
} {
MS_TRACE();
default: { if (this->session != nullptr)
MS_ABORT("unknown SRTP crypto suite"); {
} srtp_err_status_t err = srtp_dealloc(this->session);
}
MS_ASSERT((int) keyLen == policy.rtp.cipher_key_len, if (DepLibSRTP::IsError(err))
"given keyLen does not match policy.rtp.cipher_keyLen"); MS_ABORT("srtp_dealloc() failed: %s", DepLibSRTP::GetErrorString(err));
}
}
switch (type) { bool SrtpSession::EncryptRtp(const uint8_t** data, size_t* len)
case Type::INBOUND: {
policy.ssrc.type = ssrc_any_inbound; MS_TRACE();
break;
case Type::OUTBOUND: // Ensure that the resulting SRTP packet fits into the encrypt buffer.
policy.ssrc.type = ssrc_any_outbound; if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize)
break; {
} MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len);
policy.ssrc.value = 0; return false;
policy.key = key; }
// Required for sending RTP retransmission without RTX.
policy.allow_repeat_tx = 1;
policy.window_size = 1024;
policy.next = nullptr;
// Set the SRTP session.
srtp_err_status_t err = srtp_create(&this->session, &policy);
if (DepLibSRTP::IsError(err)) {
is_init = false;
MS_THROW_ERROR("srtp_create() failed: %s", DepLibSRTP::GetErrorString(err));
} else {
is_init = true;
}
}
SrtpSession::~SrtpSession() { std::memcpy(EncryptBuffer, *data, *len);
MS_TRACE();
if (this->session != nullptr) { srtp_err_status_t err =
srtp_err_status_t err = srtp_dealloc(this->session); srtp_protect(this->session, static_cast<void*>(EncryptBuffer), reinterpret_cast<int*>(len));
if (DepLibSRTP::IsError(err)) if (DepLibSRTP::IsError(err))
MS_ABORT("srtp_dealloc() failed: %s", DepLibSRTP::GetErrorString(err)); {
} MS_WARN_TAG(srtp, "srtp_protect() failed: %s", DepLibSRTP::GetErrorString(err));
}
bool SrtpSession::EncryptRtp(const uint8_t **data, size_t *len) {
MS_TRACE();
if (!is_init) {
return false;
}
// Ensure that the resulting SRTP packet fits into the encrypt buffer.
if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) {
MS_WARN_TAG(srtp, "cannot encrypt RTP packet, size too big (%zu bytes)", *len);
return false;
}
std::memcpy(EncryptBuffer, *data, *len);
srtp_err_status_t err =
srtp_protect(this->session, static_cast<void *>(EncryptBuffer), reinterpret_cast<int *>(len));
if (DepLibSRTP::IsError(err)) { return false;
MS_WARN_TAG(srtp, "srtp_protect() failed: %s", DepLibSRTP::GetErrorString(err)); }
return false; // Update the given data pointer.
} *data = (const uint8_t*)EncryptBuffer;
// Update the given data pointer.
*data = (const uint8_t *) EncryptBuffer;
return true; return true;
} }
bool SrtpSession::DecryptSrtp(uint8_t *data, size_t *len) { bool SrtpSession::DecryptSrtp(uint8_t* data, size_t* len)
MS_TRACE(); {
MS_TRACE();
srtp_err_status_t err = srtp_err_status_t err =
srtp_unprotect(this->session, static_cast<void *>(data), reinterpret_cast<int *>(len)); srtp_unprotect(this->session, static_cast<void*>(data), reinterpret_cast<int*>(len));
if (DepLibSRTP::IsError(err)) { if (DepLibSRTP::IsError(err))
MS_DEBUG_TAG(srtp, "srtp_unprotect() failed: %s", DepLibSRTP::GetErrorString(err)); {
MS_DEBUG_TAG(srtp, "srtp_unprotect() failed: %s", DepLibSRTP::GetErrorString(err));
return false; return false;
} }
return true; return true;
} }
bool SrtpSession::EncryptRtcp(const uint8_t **data, size_t *len) { bool SrtpSession::EncryptRtcp(const uint8_t** data, size_t* len)
MS_TRACE(); {
MS_TRACE();
// Ensure that the resulting SRTCP packet fits into the encrypt buffer. // Ensure that the resulting SRTCP packet fits into the encrypt buffer.
if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize) { if (*len + SRTP_MAX_TRAILER_LEN > EncryptBufferSize)
MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len); {
MS_WARN_TAG(srtp, "cannot encrypt RTCP packet, size too big (%zu bytes)", *len);
return false; return false;
} }
std::memcpy(EncryptBuffer, *data, *len); std::memcpy(EncryptBuffer, *data, *len);
srtp_err_status_t err = srtp_protect_rtcp(this->session, static_cast<void *>(EncryptBuffer), srtp_err_status_t err = srtp_protect_rtcp(
reinterpret_cast<int *>(len)); this->session, static_cast<void*>(EncryptBuffer), reinterpret_cast<int*>(len));
if (DepLibSRTP::IsError(err)) { if (DepLibSRTP::IsError(err))
MS_WARN_TAG(srtp, "srtp_protect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); {
MS_WARN_TAG(srtp, "srtp_protect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err));
return false; return false;
} }
// Update the given data pointer. // Update the given data pointer.
*data = (const uint8_t *) EncryptBuffer; *data = (const uint8_t*)EncryptBuffer;
return true; return true;
} }
bool SrtpSession::DecryptSrtcp(uint8_t *data, size_t *len) { bool SrtpSession::DecryptSrtcp(uint8_t* data, size_t* len)
MS_TRACE(); {
MS_TRACE();
srtp_err_status_t err = srtp_err_status_t err =
srtp_unprotect_rtcp(this->session, static_cast<void *>(data), reinterpret_cast<int *>(len)); srtp_unprotect_rtcp(this->session, static_cast<void*>(data), reinterpret_cast<int*>(len));
if (DepLibSRTP::IsError(err)) { if (DepLibSRTP::IsError(err))
MS_DEBUG_TAG(srtp, "srtp_unprotect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err)); {
MS_DEBUG_TAG(srtp, "srtp_unprotect_rtcp() failed: %s", DepLibSRTP::GetErrorString(err));
return false; return false;
} }
return true; return true;
} }
}// namespace RTC } // namespace RTC
#ifndef MS_RTC_SRTP_SESSION_HPP #ifndef MS_RTC_SRTP_SESSION_HPP
#define MS_RTC_SRTP_SESSION_HPP #define MS_RTC_SRTP_SESSION_HPP
#include "rtc_dtls_transport.h"
#include "utils.h" #include "utils.h"
#include <srtp2/srtp.h> #include <srtp2/srtp.h>
#include <vector> #include <vector>
namespace RTC { namespace RTC
{
class DepLibSRTP { class DepLibSRTP {
public: public:
static void ClassInit(); static void ClassInit();
static void ClassDestroy(); static void ClassDestroy();
static bool IsError(srtp_err_status_t code) { return (code != srtp_err_status_ok); } static bool IsError(srtp_err_status_t code) { return (code != srtp_err_status_ok); }
static const char *GetErrorString(srtp_err_status_t code) { static const char *GetErrorString(srtp_err_status_t code) {
// This throws out_of_range if the given index is not in the vector. // This throws out_of_range if the given index is not in the vector.
return DepLibSRTP::errors.at(code); return DepLibSRTP::errors.at(code);
} }
private: private:
static std::vector<const char *> errors; static std::vector<const char *> errors;
}; };
class SrtpSession { class SrtpSession
public: {
public: public:
enum class Type { INBOUND = 1, OUTBOUND }; enum class CryptoSuite
{
public: NONE = 0,
static void ClassInit(); AES_CM_128_HMAC_SHA1_80 = 1,
AES_CM_128_HMAC_SHA1_32,
private: AEAD_AES_256_GCM,
static void OnSrtpEvent(srtp_event_data_t *data); AEAD_AES_128_GCM
};
public:
SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen); public:
~SrtpSession(); enum class Type
{
public: INBOUND = 1,
bool EncryptRtp(const uint8_t **data, size_t *len); OUTBOUND
bool DecryptSrtp(uint8_t *data, size_t *len); };
bool EncryptRtcp(const uint8_t **data, size_t *len);
bool DecryptSrtcp(uint8_t *data, size_t *len); public:
void RemoveStream(uint32_t ssrc) { srtp_remove_stream(this->session, uint32_t{htonl(ssrc)}); } static void ClassInit();
private: private:
bool is_init = false; static void OnSrtpEvent(srtp_event_data_t* data);
// Allocated by this.
srtp_t session{nullptr}; public:
}; SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t* key, size_t keyLen);
}// namespace RTC ~SrtpSession();
public:
bool EncryptRtp(const uint8_t** data, size_t* len);
bool DecryptSrtp(uint8_t* data, size_t* len);
bool EncryptRtcp(const uint8_t** data, size_t* len);
bool DecryptSrtcp(uint8_t* data, size_t* len);
void RemoveStream(uint32_t ssrc)
{
srtp_remove_stream(this->session, uint32_t{ htonl(ssrc) });
}
private:
// Allocated by this.
srtp_t session{ nullptr };
};
} // namespace RTC
#endif #endif
...@@ -6,10 +6,79 @@ ...@@ -6,10 +6,79 @@
#include <cstdio> // std::snprintf() #include <cstdio> // std::snprintf()
#include <cstring> // std::memcmp(), std::memcpy() #include <cstring> // std::memcmp(), std::memcpy()
#include "utils.h"
namespace RTC { namespace RTC {
static const uint32_t crc32Table[] =
{
0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3,
0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91,
0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7,
0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5,
0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b,
0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59,
0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f,
0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d,
0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433,
0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01,
0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65,
0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb,
0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9,
0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f,
0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad,
0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683,
0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1,
0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7,
0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5,
0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b,
0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f,
0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d,
0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713,
0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21,
0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777,
0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45,
0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db,
0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9,
0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf,
0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d
};
inline uint32_t GetCRC32(const uint8_t *data, size_t size) {
uint32_t crc{0xFFFFFFFF};
const uint8_t *p = data;
while (size--) {
crc = crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8);
}
return crc ^ ~0U;
}
static std::string openssl_HMACsha1(const void *key, size_t key_len, const void *data, size_t data_len){
std::string str;
str.resize(20);
unsigned int out_len;
#if defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L)
//openssl 1.1.0新增api,老版本api作废
HMAC_CTX *ctx = HMAC_CTX_new();
HMAC_CTX_reset(ctx);
HMAC_Init_ex(ctx, key, (int)key_len, EVP_sha1(), NULL);
HMAC_Update(ctx, (unsigned char*)data, data_len);
HMAC_Final(ctx, (unsigned char *)str.data(), &out_len);
HMAC_CTX_reset(ctx);
HMAC_CTX_free(ctx);
#else
HMAC_CTX ctx;
HMAC_CTX_init(&ctx);
HMAC_Init_ex(&ctx, key, key_len, EVP_sha1(), NULL);
HMAC_Update(&ctx, (unsigned char*)data, data_len);
HMAC_Final(&ctx, (unsigned char *)str.data(), &out_len);
HMAC_CTX_cleanup(&ctx);
#endif //defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L)
return str;
}
/* Class variables. */ /* Class variables. */
const uint8_t StunPacket::kMagicCookie[] = {0x21, 0x12, 0xA4, 0x42}; const uint8_t StunPacket::kMagicCookie[] = {0x21, 0x12, 0xA4, 0x42};
...@@ -258,7 +327,7 @@ StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) { ...@@ -258,7 +327,7 @@ StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) {
if (hasFingerprint) { if (hasFingerprint) {
// Compute the CRC32 of the received packet up to (but excluding) the // Compute the CRC32 of the received packet up to (but excluding) the
// FINGERPRINT attribute and XOR it with 0x5354554e. // FINGERPRINT attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = Utils::Crypto::GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e;
// Compare with the FINGERPRINT value in the packet. // Compare with the FINGERPRINT value in the packet.
if (fingerprint != computedFingerprint) { if (fingerprint != computedFingerprint) {
...@@ -290,79 +359,6 @@ StunPacket::~StunPacket() { ...@@ -290,79 +359,6 @@ StunPacket::~StunPacket() {
// MS_TRACE(); // MS_TRACE();
} }
void StunPacket::Dump() const {
// MS_TRACE();
// MS_DUMP("<StunPacket>");
std::string klass;
switch (this->klass) {
case Class::REQUEST:
klass = "Request";
break;
case Class::INDICATION:
klass = "Indication";
break;
case Class::SUCCESS_RESPONSE:
klass = "SuccessResponse";
break;
case Class::ERROR_RESPONSE:
klass = "ErrorResponse";
break;
}
if (this->method == Method::BINDING) {
// MS_DUMP(" Binding %s", klass.c_str());
} else {
// This prints the unknown method number. Example: TURN Allocate => 0x003.
// MS_DUMP(" %s with unknown method %#.3x", klass.c_str(),
// static_cast<uint16_t>(this->method));
}
// MS_DUMP(" size: %zu bytes", this->size);
static char transactionId[25];
for (int i{0}; i < 12; ++i) {
// NOTE: n must be 3 because snprintf adds a \0 after printed chars.
std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]);
}
// MS_DUMP(" transactionId: %s", transactionId);
if (this->errorCode != 0u)
// MS_DUMP(" errorCode: %" PRIu16, this->errorCode);
if (!this->username.empty())
// MS_DUMP(" username: %s", this->username.c_str());
if (this->priority != 0u)
// MS_DUMP(" priority: %" PRIu32, this->priority);
if (this->iceControlling != 0u)
// MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling);
if (this->iceControlled != 0u)
// MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled);
if (this->hasUseCandidate)
// MS_DUMP(" useCandidate");
if (this->xorMappedAddress != nullptr) {
int family;
uint16_t port;
std::string ip;
Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port);
// MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port);
}
if (this->messageIntegrity != nullptr) {
static char messageIntegrity[41];
for (int i{0}; i < 20; ++i) {
std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]);
}
// MS_DUMP(" messageIntegrity: %s", messageIntegrity);
}
if (this->hasFingerprint) {
}
// MS_DUMP(" has fingerprint");
// MS_DUMP("</StunPacket>");
}
StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& localUsername, StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& localUsername,
const std::string& localPassword) { const std::string& localPassword) {
// MS_TRACE(); // MS_TRACE();
...@@ -402,13 +398,13 @@ StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& lo ...@@ -402,13 +398,13 @@ StunPacket::Authentication StunPacket::CheckAuthentication(const std::string& lo
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8)); Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules.
const uint8_t* computedMessageIntegrity = Utils::Crypto::GetHmacShA1( auto computedMessageIntegrity = openssl_HMACsha1(
localPassword, this->data, (this->messageIntegrity - 4) - this->data); localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data);
Authentication result; Authentication result;
// Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet.
if (std::memcmp(this->messageIntegrity, computedMessageIntegrity, 20) == 0) if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0)
result = Authentication::OK; result = Authentication::OK;
else else
result = Authentication::UNAUTHORIZED; result = Authentication::UNAUTHORIZED;
...@@ -670,12 +666,11 @@ void StunPacket::Serialize(uint8_t* buffer) { ...@@ -670,12 +666,11 @@ void StunPacket::Serialize(uint8_t* buffer) {
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8)); Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules.
const uint8_t* computedMessageIntegrity = auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos);
Utils::Crypto::GetHmacShA1(this->password, buffer, pos);
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::MESSAGE_INTEGRITY)); Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::MESSAGE_INTEGRITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 20); Utils::Byte::Set2Bytes(buffer, pos + 2, 20);
std::memcpy(buffer + pos + 4, computedMessageIntegrity, 20); std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size());
// Update the pointer. // Update the pointer.
this->messageIntegrity = buffer + pos + 4; this->messageIntegrity = buffer + pos + 4;
...@@ -692,7 +687,7 @@ void StunPacket::Serialize(uint8_t* buffer) { ...@@ -692,7 +687,7 @@ void StunPacket::Serialize(uint8_t* buffer) {
if (addFingerprint) { if (addFingerprint) {
// Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT
// attribute and XOR it with 0x5354554e. // attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = Utils::Crypto::GetCRC32(buffer, pos) ^ 0x5354554e; uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e;
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT)); Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4); Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
......
#define MS_CLASS "Utils::Crypto"
// #define MS_LOG_DEV
#include "utils.h"
#include "openssl/sha.h"
namespace Utils {
/* Static variables. */
uint32_t Crypto::seed;
HMAC_CTX *Crypto::hmacSha1Ctx{nullptr};
uint8_t Crypto::hmacSha1Buffer[20];// SHA-1 result is 20 bytes long.
// clang-format off
const uint32_t Crypto::crc32Table[] =
{
0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3,
0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91,
0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7,
0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5,
0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b,
0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59,
0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f,
0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d,
0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433,
0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01,
0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65,
0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb,
0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9,
0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f,
0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad,
0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683,
0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1,
0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7,
0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5,
0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b,
0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f,
0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d,
0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713,
0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21,
0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777,
0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45,
0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db,
0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9,
0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf,
0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d
};
// clang-format on
/* Static methods. */
void Crypto::ClassInit() {
// MS_TRACE();
// Init the vrypto seed with a random number taken from the address
// of the seed variable itself (which is random).
Crypto::seed = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(std::addressof(Crypto::seed)));
// Create an OpenSSL HMAC_CTX context for HMAC SHA1 calculation.
// Crypto::hmacSha1Ctx = HMAC_CTX_new();
if (Crypto::hmacSha1Ctx == nullptr) {
Crypto::hmacSha1Ctx = HMAC_CTX_new();
}
}
void Crypto::ClassDestroy() {
// MS_TRACE();
if (Crypto::hmacSha1Ctx != nullptr) {
HMAC_CTX_free(Crypto::hmacSha1Ctx);
}
}
const uint8_t *Crypto::GetHmacShA1(const std::string &key, const uint8_t *data, size_t len) {
// MS_TRACE();
size_t ret;
ret = HMAC_Init_ex(Crypto::hmacSha1Ctx, key.c_str(), key.length(), EVP_sha1(), nullptr);
// MS_ASSERT(ret == 1, "OpenSSL HMAC_Init_ex() failed with key '%s'", key.c_str());
ret = HMAC_Update(Crypto::hmacSha1Ctx, data, static_cast<int>(len));
/*
MS_ASSERT(
ret == 1,
"OpenSSL HMAC_Update() failed with key '%s' and data length %zu bytes",
key.c_str(),
len);
*/
uint32_t resultLen;
ret = HMAC_Final(Crypto::hmacSha1Ctx, (uint8_t *) Crypto::hmacSha1Buffer, &resultLen);
/*
MS_ASSERT(
ret == 1, "OpenSSL HMAC_Final() failed with key '%s' and data length %zu bytes", key.c_str(),
len); MS_ASSERT(resultLen == 20, "OpenSSL HMAC_Final() resultLen is %u instead of 20", resultLen);
*/
return Crypto::hmacSha1Buffer;
}
}// namespace Utils
namespace Utils {
static std::string inet_ntoa(struct in_addr in) {
char buf[20];
unsigned char *p = (unsigned char *) &(in);
snprintf(buf, sizeof(buf), "%u.%u.%u.%u", p[0], p[1], p[2], p[3]);
return buf;
}
void IP::GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip, uint16_t &port) {
char ipBuffer[INET6_ADDRSTRLEN + 1];
switch (addr->sa_family) {
case AF_INET: {
ip = Utils::inet_ntoa(reinterpret_cast<const struct sockaddr_in *>(addr)->sin_addr);
port = static_cast<uint16_t>(ntohs(reinterpret_cast<const struct sockaddr_in *>(addr)->sin_port));
break;
}
case AF_INET6: {
port = static_cast<uint16_t>(ntohs(reinterpret_cast<const struct sockaddr_in6 *>(addr)->sin6_port));
break;
}
default: {
// MS_ABORT("unknown network family: %d", static_cast<int>(addr->sa_family));
}
}
family = addr->sa_family;
ip.assign(ipBuffer);
}
}// namespace Utils
\ No newline at end of file
...@@ -30,76 +30,6 @@ ...@@ -30,76 +30,6 @@
#include <string> #include <string>
namespace Utils { namespace Utils {
class IP {
public:
static int GetFamily(const char *ip, size_t ipLen);
static int GetFamily(const std::string &ip);
static void GetAddressInfo(const struct sockaddr *addr, int &family, std::string &ip,
uint16_t &port);
static bool CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2);
static struct sockaddr_storage CopyAddress(const struct sockaddr *addr);
static void NormalizeIp(std::string &ip);
};
/* Inline static methods. */
inline int IP::GetFamily(const std::string &ip) { return GetFamily(ip.c_str(), ip.size()); }
inline bool IP::CompareAddresses(const struct sockaddr *addr1, const struct sockaddr *addr2) {
// Compare family.
if (addr1->sa_family != addr2->sa_family ||
(addr1->sa_family != AF_INET && addr1->sa_family != AF_INET6)) {
return false;
}
// Compare port.
if (reinterpret_cast<const struct sockaddr_in *>(addr1)->sin_port !=
reinterpret_cast<const struct sockaddr_in *>(addr2)->sin_port) {
return false;
}
// Compare IP.
switch (addr1->sa_family) {
case AF_INET: {
return (reinterpret_cast<const struct sockaddr_in *>(addr1)->sin_addr.s_addr ==
reinterpret_cast<const struct sockaddr_in *>(addr2)->sin_addr.s_addr);
}
case AF_INET6: {
return (std::memcmp(
std::addressof(reinterpret_cast<const struct sockaddr_in6 *>(addr1)->sin6_addr),
std::addressof(reinterpret_cast<const struct sockaddr_in6 *>(addr2)->sin6_addr),
16) == 0
? true
: false);
}
default: {
return false;
}
}
}
inline struct sockaddr_storage IP::CopyAddress(const struct sockaddr *addr) {
struct sockaddr_storage copiedAddr;
switch (addr->sa_family) {
case AF_INET:
std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in));
break;
case AF_INET6:
std::memcpy(std::addressof(copiedAddr), addr, sizeof(struct sockaddr_in6));
break;
}
return copiedAddr;
}
class File {
public:
static void CheckFile(const char *file);
};
class Byte { class Byte {
public: public:
...@@ -181,138 +111,6 @@ inline uint16_t Byte::PadTo4Bytes(uint16_t size) { ...@@ -181,138 +111,6 @@ inline uint16_t Byte::PadTo4Bytes(uint16_t size) {
return size; return size;
} }
inline uint32_t Byte::PadTo4Bytes(uint32_t size) {
// If size is not multiple of 32 bits then pad it.
if (size & 0x03)
return (size & 0xFFFFFFFC) + 4;
else
return size;
}
class Bits {
public:
static size_t CountSetBits(const uint16_t mask);
};
/* Inline static methods. */
class Crypto {
public:
static void ClassInit();
static void ClassDestroy();
static uint32_t GetRandomUInt(uint32_t min, uint32_t max);
static const std::string GetRandomString(size_t len);
static uint32_t GetCRC32(const uint8_t *data, size_t size);
static const uint8_t *GetHmacShA1(const std::string &key, const uint8_t *data, size_t len);
private:
static uint32_t seed;
static HMAC_CTX *hmacSha1Ctx;
static uint8_t hmacSha1Buffer[];
static const uint32_t crc32Table[256];
};
/* Inline static methods. */
inline uint32_t Crypto::GetRandomUInt(uint32_t min, uint32_t max) {
// NOTE: This is the original, but produces very small values.
// Crypto::seed = (214013 * Crypto::seed) + 2531011;
// return (((Crypto::seed>>16)&0x7FFF) % (max - min + 1)) + min;
// This seems to produce better results.
Crypto::seed = uint32_t{((214013 * Crypto::seed) + 2531011)};
return (((Crypto::seed >> 4) & 0x7FFF7FFF) % (max - min + 1)) + min;
}
inline const std::string Crypto::GetRandomString(size_t len) {
static char buffer[64];
static const char chars[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'};
if (len > 64) len = 64;
for (size_t i{0}; i < len; ++i) {
buffer[i] = chars[GetRandomUInt(0, sizeof(chars) - 1)];
}
return std::string(buffer, len);
}
inline uint32_t Crypto::GetCRC32(const uint8_t *data, size_t size) {
uint32_t crc{0xFFFFFFFF};
const uint8_t *p = data;
while (size--) {
crc = Crypto::crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8);
}
return crc ^ ~0U;
}
class String {
public:
static void ToLowerCase(std::string &str);
};
inline void String::ToLowerCase(std::string &str) {
std::transform(str.begin(), str.end(), str.begin(), ::tolower);
}
class Time {
// Seconds from Jan 1, 1900 to Jan 1, 1970.
static constexpr uint32_t UnixNtpOffset{0x83AA7E80};
// NTP fractional unit.
static constexpr uint64_t NtpFractionalUnit{1LL << 32};
public:
struct Ntp {
uint32_t seconds;
uint32_t fractions;
};
static Time::Ntp TimeMs2Ntp(uint64_t ms);
static uint64_t Ntp2TimeMs(Time::Ntp ntp);
static bool IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp);
static uint32_t LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2);
};
inline Time::Ntp Time::TimeMs2Ntp(uint64_t ms) {
Time::Ntp ntp;// NOLINT(cppcoreguidelines-pro-type-member-init)
ntp.seconds = uint32_t(ms / 1000);
ntp.fractions =
static_cast<uint32_t>((static_cast<double>(ms % 1000) / 1000) * NtpFractionalUnit);
return ntp;
}
inline uint64_t Time::Ntp2TimeMs(Time::Ntp ntp) {
// clang-format off
return (
static_cast<uint64_t>(ntp.seconds) * 1000 +
static_cast<uint64_t>(std::round((static_cast<double>(ntp.fractions) * 1000) / NtpFractionalUnit))
);
// clang-format on
}
inline bool Time::IsNewerTimestamp(uint32_t timestamp, uint32_t prevTimestamp) {
// Distinguish between elements that are exactly 0x80000000 apart.
// If t1>t2 and |t1-t2| = 0x80000000: IsNewer(t1,t2)=true,
// IsNewer(t2,t1)=false
// rather than having IsNewer(t1,t2) = IsNewer(t2,t1) = false.
if (static_cast<uint32_t>(timestamp - prevTimestamp) == 0x80000000)
return timestamp > prevTimestamp;
return timestamp != prevTimestamp &&
static_cast<uint32_t>(timestamp - prevTimestamp) < 0x80000000;
}
inline uint32_t Time::LatestTimestamp(uint32_t timestamp1, uint32_t timestamp2) {
return IsNewerTimestamp(timestamp1, timestamp2) ? timestamp1 : timestamp2;
}
}// namespace Utils }// namespace Utils
#endif #endif
...@@ -4,31 +4,81 @@ ...@@ -4,31 +4,81 @@
WebRtcTransport::WebRtcTransport() { WebRtcTransport::WebRtcTransport() {
static onceToken token([](){ static onceToken token([](){
Utils::Crypto::ClassInit();
RTC::DtlsTransport::ClassInit(); RTC::DtlsTransport::ClassInit();
RTC::DepLibSRTP::ClassInit(); RTC::DepLibSRTP::ClassInit();
RTC::SrtpSession::ClassInit(); RTC::SrtpSession::ClassInit();
}); });
ice_server_ = std::make_shared<IceServer>(Utils::Crypto::GetRandomString(4), Utils::Crypto::GetRandomString(24)); dtls_transport_ = std::make_shared<RTC::DtlsTransport>(EventPollerPool::Instance().getFirstPoller(), this);
ice_server_->SetIceServerCompletedCB([this]() { ice_server_ = std::make_shared<RTC::IceServer>(this, makeRandStr(4), makeRandStr(24));
this->OnIceServerCompleted();
});
ice_server_->SetSendCB([this](char *buf, size_t len, struct sockaddr_in *remote_address) {
this->WritePacket(buf, len, remote_address);
});
// todo dtls服务器或客户端模式
dtls_transport_ = std::make_shared<DtlsTransport>(true);
dtls_transport_->SetHandshakeCompletedCB([this](std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) {
this->OnDtlsCompleted(client_key, server_key, srtp_crypto_suite);
});
dtls_transport_->SetOutPutCB([this](char *buf, size_t len) { this->WritePacket(buf, len); });
} }
WebRtcTransport::~WebRtcTransport() {} WebRtcTransport::~WebRtcTransport() {}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void WebRtcTransport::OnIceServerSendStunPacket(const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) {
onWrite((char *)packet->GetData(), packet->GetSize(), (struct sockaddr_in *)tuple);
}
void WebRtcTransport::OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) {
InfoL;
}
void WebRtcTransport::OnIceServerConnected(const RTC::IceServer *iceServer) {
InfoL;
dtls_transport_->Run(RTC::DtlsTransport::Role::SERVER);
}
void WebRtcTransport::OnIceServerCompleted(const RTC::IceServer *iceServer) {
InfoL;
}
void WebRtcTransport::OnIceServerDisconnected(const RTC::IceServer *iceServer) {
InfoL;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void WebRtcTransport::OnDtlsTransportConnected(
const RTC::DtlsTransport *dtlsTransport,
RTC::SrtpSession::CryptoSuite srtpCryptoSuite,
uint8_t *srtpLocalKey,
size_t srtpLocalKeyLen,
uint8_t *srtpRemoteKey,
size_t srtpRemoteKeyLen,
std::string &remoteCert) {
InfoL;
srtp_session_ = std::make_shared<RTC::SrtpSession>(RTC::SrtpSession::Type::OUTBOUND, srtpCryptoSuite, srtpLocalKey, srtpLocalKeyLen);
onDtlsConnected();
}
void WebRtcTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) {
onWrite((char *)data, len);
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
void WebRtcTransport::onWrite(const char *buf, size_t len){
auto tuple = ice_server_->GetSelectedTuple();
assert(tuple);
onWrite(buf, len, (struct sockaddr_in *)tuple);
}
std::string WebRtcTransport::GetLocalSdp() { std::string WebRtcTransport::GetLocalSdp() {
RTC::DtlsTransport::Fingerprint remote_fingerprint;
remote_fingerprint.algorithm = RTC::DtlsTransport::GetFingerprintAlgorithm("sha-256");
remote_fingerprint.value = "";
dtls_transport_->SetRemoteFingerprint(remote_fingerprint);
string finger_print_sha256;
auto finger_prints = dtls_transport_->GetLocalFingerprints();
for (size_t i = 0; i < finger_prints.size(); i++) {
if (finger_prints[i].algorithm == RTC::DtlsTransport::FingerprintAlgorithm::SHA256) {
finger_print_sha256 = finger_prints[i].value;
}
}
char sdp[1024 * 10] = {0}; char sdp[1024 * 10] = {0};
auto ssrc = getSSRC(); auto ssrc = getSSRC();
auto ip = getIP(); auto ip = getIP();
...@@ -60,22 +110,10 @@ std::string WebRtcTransport::GetLocalSdp() { ...@@ -60,22 +110,10 @@ std::string WebRtcTransport::GetLocalSdp() {
"a=candidate:%s 1 udp %u %s %u typ %s\r\n", "a=candidate:%s 1 udp %u %s %u typ %s\r\n",
ip.c_str(), port, pt, ip.c_str(), ip.c_str(), port, pt, ip.c_str(),
ice_server_->GetUsernameFragment().c_str(),ice_server_->GetPassword().c_str(), ice_server_->GetUsernameFragment().c_str(),ice_server_->GetPassword().c_str(),
dtls_transport_->GetMyFingerprint().c_str(), pt, ssrc, ssrc, ssrc, ssrc, "4", ssrc, ip.c_str(), port, "host"); finger_print_sha256.c_str(), pt, ssrc, ssrc, ssrc, ssrc, "4", ssrc, ip.c_str(), port, "host");
return sdp; return sdp;
} }
void WebRtcTransport::OnIceServerCompleted() {
InfoL;
dtls_transport_->Start();
onIceConnected();
}
void WebRtcTransport::OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite) {
InfoL << client_key << " " << server_key << " " << (int)srtp_crypto_suite;
srtp_session_ = std::make_shared<RTC::SrtpSession>(RTC::SrtpSession::Type::OUTBOUND, srtp_crypto_suite, (uint8_t *) client_key.c_str(), client_key.size());
onDtlsCompleted();
}
bool is_dtls(char *buf) { bool is_dtls(char *buf) {
return ((*buf > 19) && (*buf < 64)); return ((*buf > 19) && (*buf < 64));
} }
...@@ -90,25 +128,23 @@ bool is_rtcp(char *buf) { ...@@ -90,25 +128,23 @@ bool is_rtcp(char *buf) {
return ((header->pt >= 64) && (header->pt < 96)); return ((header->pt >= 64) && (header->pt < 96));
} }
void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address) { void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, RTC::TransportTuple *tuple) {
if (RTC::StunPacket::IsStun((const uint8_t *) buf, len)) { if (RTC::StunPacket::IsStun((const uint8_t *) buf, len)) {
InfoL << "stun:" << hexdump(buf, len);
RTC::StunPacket *packet = RTC::StunPacket::Parse((const uint8_t *) buf, len); RTC::StunPacket *packet = RTC::StunPacket::Parse((const uint8_t *) buf, len);
if (packet == nullptr) { if (packet == nullptr) {
WarnL << "parse stun error" << std::endl; WarnL << "parse stun error" << std::endl;
return; return;
} }
ice_server_->ProcessStunPacket(packet, remote_address); ice_server_->ProcessStunPacket(packet, tuple);
return; return;
} }
if (DtlsTransport::IsDtlsPacket(buf, len)) { if (is_dtls(buf)) {
InfoL << "dtls:" << hexdump(buf, len); dtls_transport_->ProcessDtlsData((uint8_t *)buf, len);
dtls_transport_->InputData(buf, len);
return; return;
} }
if (is_rtp(buf)) { if (is_rtp(buf)) {
RtpHeader *header = (RtpHeader *) buf; RtpHeader *header = (RtpHeader *) buf;
InfoL << "rtp:" << header->dumpString(len); // InfoL << "rtp:" << header->dumpString(len);
return; return;
} }
if (is_rtcp(buf)) { if (is_rtcp(buf)) {
...@@ -118,10 +154,6 @@ void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_i ...@@ -118,10 +154,6 @@ void WebRtcTransport::OnInputDataPacket(char *buf, size_t len, struct sockaddr_i
} }
} }
void WebRtcTransport::WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address) {
onWrite(buf, len, remote_address ? remote_address : (ice_server_ ? ice_server_->GetSelectAddr() : nullptr));
}
void WebRtcTransport::WritRtpPacket(char *buf, size_t len) { void WebRtcTransport::WritRtpPacket(char *buf, size_t len) {
const uint8_t *p = (uint8_t *) buf; const uint8_t *p = (uint8_t *) buf;
bool ret = false; bool ret = false;
...@@ -129,7 +161,7 @@ void WebRtcTransport::WritRtpPacket(char *buf, size_t len) { ...@@ -129,7 +161,7 @@ void WebRtcTransport::WritRtpPacket(char *buf, size_t len) {
ret = srtp_session_->EncryptRtp(&p, &len); ret = srtp_session_->EncryptRtp(&p, &len);
} }
if (ret) { if (ret) {
onWrite((char *) p, len, ice_server_->GetSelectAddr()); onWrite((char *) p, len);
} }
} }
...@@ -139,8 +171,8 @@ WebRtcTransportImp::WebRtcTransportImp(const EventPoller::Ptr &poller) { ...@@ -139,8 +171,8 @@ WebRtcTransportImp::WebRtcTransportImp(const EventPoller::Ptr &poller) {
_socket = Socket::createSocket(poller, false); _socket = Socket::createSocket(poller, false);
//随机端口,绑定全部网卡 //随机端口,绑定全部网卡
_socket->bindUdpSock(0); _socket->bindUdpSock(0);
_socket->setOnRead([this](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len){ _socket->setOnRead([this](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) mutable {
OnInputDataPacket(buf->data(), buf->size(), (struct sockaddr_in*)addr); OnInputDataPacket(buf->data(), buf->size(), addr);
}); });
} }
...@@ -149,7 +181,7 @@ void WebRtcTransportImp::attach(const RtspMediaSource::Ptr &src) { ...@@ -149,7 +181,7 @@ void WebRtcTransportImp::attach(const RtspMediaSource::Ptr &src) {
_src = src; _src = src;
} }
void WebRtcTransportImp::onDtlsCompleted() { void WebRtcTransportImp::onDtlsConnected() {
_reader = _src->getRing()->attach(_socket->getPoller(), true); _reader = _src->getRing()->attach(_socket->getPoller(), true);
weak_ptr<WebRtcTransportImp> weak_self = shared_from_this(); weak_ptr<WebRtcTransportImp> weak_self = shared_from_this();
_reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt){ _reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt){
...@@ -167,14 +199,9 @@ void WebRtcTransportImp::onDtlsCompleted() { ...@@ -167,14 +199,9 @@ void WebRtcTransportImp::onDtlsCompleted() {
}); });
} }
void WebRtcTransportImp::onIceConnected(){
}
void WebRtcTransportImp::onWrite(const char *buf, size_t len, struct sockaddr_in *dst) { void WebRtcTransportImp::onWrite(const char *buf, size_t len, struct sockaddr_in *dst) {
auto ptr = BufferRaw::create(); auto ptr = BufferRaw::create();
ptr->assign(buf, len); ptr->assign(buf, len);
// InfoL << len << " " << SockUtil::inet_ntoa(dst->sin_addr) << " " << ntohs(dst->sin_port);
_socket->send(ptr, (struct sockaddr *)(dst), sizeof(struct sockaddr)); _socket->send(ptr, (struct sockaddr *)(dst), sizeof(struct sockaddr));
} }
...@@ -201,15 +228,5 @@ std::string WebRtcTransportImp::getIP() const { ...@@ -201,15 +228,5 @@ std::string WebRtcTransportImp::getIP() const {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
INSTANCE_IMP(WebRtcManager)
WebRtcManager::WebRtcManager() {
}
WebRtcManager::~WebRtcManager() {
}
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "dtls_transport.h" #include "rtc_dtls_transport.h"
#include "ice_server.h" #include "ice_server.h"
#include "srtp_session.h" #include "srtp_session.h"
#include "stun_packet.h" #include "stun_packet.h"
class WebRtcTransport { class WebRtcTransport : public RTC::DtlsTransport::Listener, public RTC::IceServer::Listener {
public: public:
using Ptr = std::shared_ptr<WebRtcTransport>; using Ptr = std::shared_ptr<WebRtcTransport>;
WebRtcTransport(); WebRtcTransport();
...@@ -22,7 +22,7 @@ public: ...@@ -22,7 +22,7 @@ public:
/// \param buf /// \param buf
/// \param len /// \param len
/// \param remote_address /// \param remote_address
void OnInputDataPacket(char *buf, size_t len, struct sockaddr_in *remote_address); void OnInputDataPacket(char *buf, size_t len, RTC::TransportTuple *tuple);
/// 发送rtp /// 发送rtp
/// \param buf /// \param buf
...@@ -30,6 +30,31 @@ public: ...@@ -30,6 +30,31 @@ public:
void WritRtpPacket(char *buf, size_t len); void WritRtpPacket(char *buf, size_t len);
protected: protected:
// dtls相关的回调
void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override {};
void OnDtlsTransportConnected(
const RTC::DtlsTransport *dtlsTransport,
RTC::SrtpSession::CryptoSuite srtpCryptoSuite,
uint8_t *srtpLocalKey,
size_t srtpLocalKeyLen,
uint8_t *srtpRemoteKey,
size_t srtpRemoteKeyLen,
std::string &remoteCert) override;
void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override {};
void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override {};
void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override;
void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override {};
protected:
//ice相关的回调
void OnIceServerSendStunPacket(const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) override;
void OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) override;
void OnIceServerConnected(const RTC::IceServer *iceServer) override;
void OnIceServerCompleted(const RTC::IceServer *iceServer) override;
void OnIceServerDisconnected(const RTC::IceServer *iceServer) override;
protected:
/// 输出udp数据 /// 输出udp数据
/// \param buf /// \param buf
/// \param len /// \param len
...@@ -39,17 +64,14 @@ protected: ...@@ -39,17 +64,14 @@ protected:
virtual uint16_t getPort() const = 0; virtual uint16_t getPort() const = 0;
virtual std::string getIP() const = 0; virtual std::string getIP() const = 0;
virtual int getPayloadType() const = 0; virtual int getPayloadType() const = 0;
virtual void onIceConnected() = 0; virtual void onDtlsConnected() = 0;
virtual void onDtlsCompleted() = 0;
private: private:
void OnIceServerCompleted(); void onWrite(const char *buf, size_t len);
void OnDtlsCompleted(std::string client_key, std::string server_key, RTC::CryptoSuite srtp_crypto_suite);
void WritePacket(char *buf, size_t len, struct sockaddr_in *remote_address = nullptr);
private: private:
IceServer::Ptr ice_server_; std::shared_ptr<RTC::IceServer> ice_server_;
DtlsTransport::Ptr dtls_transport_; std::shared_ptr<RTC::DtlsTransport> dtls_transport_;
std::shared_ptr<RTC::SrtpSession> srtp_session_; std::shared_ptr<RTC::SrtpSession> srtp_session_;
}; };
...@@ -74,8 +96,7 @@ protected: ...@@ -74,8 +96,7 @@ protected:
uint32_t getSSRC() const override; uint32_t getSSRC() const override;
uint16_t getPort() const override; uint16_t getPort() const override;
std::string getIP() const override; std::string getIP() const override;
void onIceConnected() override; void onDtlsConnected() override;
void onDtlsCompleted() override;
private: private:
Socket::Ptr _socket; Socket::Ptr _socket;
...@@ -83,16 +104,6 @@ private: ...@@ -83,16 +104,6 @@ private:
RtspMediaSource::RingType::RingReader::Ptr _reader; RtspMediaSource::RingType::RingReader::Ptr _reader;
}; };
class WebRtcManager : public std::enable_shared_from_this<WebRtcManager> {
public:
~WebRtcManager();
static WebRtcManager& Instance();
private:
WebRtcManager();
};
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
<video id="vid2" autoplay></video> <video id="vid2" autoplay></video>
<br> <br>
<p>ip_address</p> <p>ip_address</p>
<input id="input1" type="text" name="ip_address" value="http://172.26.10.29:20080/webrtc?app=live&stream=test"> <input id="input1" type="text" name="ip_address" value="https://rp.zlmediakit.com:20443/webrtc?app=live&stream=test">
<br> <br>
<button id="btn1">Call</button> <button id="btn1">Call</button>
<button id="btn3">Hang Up</button> <button id="btn3">Hang Up</button>
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论