Commit 56d6eb0f by ziyue

批量替换tab为4个空格

parent a5c3db4e
...@@ -269,9 +269,9 @@ API_EXPORT int API_CALL mk_media_input_aac(mk_media ctx, const void *data, int l ...@@ -269,9 +269,9 @@ API_EXPORT int API_CALL mk_media_input_aac(mk_media ctx, const void *data, int l
} }
API_EXPORT int API_CALL mk_media_input_pcm(mk_media ctx, void *data , int len, uint64_t pts){ API_EXPORT int API_CALL mk_media_input_pcm(mk_media ctx, void *data , int len, uint64_t pts){
assert(ctx && data && len > 0); assert(ctx && data && len > 0);
MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx; MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx;
return (*obj)->getChannel()->inputPCM((char*)data, len, pts); return (*obj)->getChannel()->inputPCM((char*)data, len, pts);
} }
API_EXPORT int API_CALL mk_media_input_audio(mk_media ctx, const void* data, int len, uint64_t dts){ API_EXPORT int API_CALL mk_media_input_audio(mk_media ctx, const void* data, int len, uint64_t dts){
......
...@@ -33,1453 +33,1453 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ...@@ -33,1453 +33,1453 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
using namespace std; using namespace std;
#define LOG_OPENSSL_ERROR(desc) \ #define LOG_OPENSSL_ERROR(desc) \
do \ do \
{ \ { \
if (ERR_peek_error() == 0) \ if (ERR_peek_error() == 0) \
MS_ERROR("OpenSSL error [desc:'%s']", desc); \ MS_ERROR("OpenSSL error [desc:'%s']", desc); \
else \ else \
{ \ { \
int64_t err; \ int64_t err; \
while ((err = ERR_get_error()) != 0) \ while ((err = ERR_get_error()) != 0) \
{ \ { \
MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \
} \ } \
ERR_clear_error(); \ ERR_clear_error(); \
} \ } \
} while (false) } while (false)
/* Static methods for OpenSSL callbacks. */ /* Static methods for OpenSSL callbacks. */
inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/)
{ {
MS_TRACE(); MS_TRACE();
// Always valid since DTLS certificates are self-signed. // Always valid since DTLS certificates are self-signed.
return 1; return 1;
} }
inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs)
{ {
if (timerUs == 0) if (timerUs == 0)
return 100000; return 100000;
else if (timerUs >= 4000000) else if (timerUs >= 4000000)
return 4000000; return 4000000;
else else
return 2 * timerUs; return 2 * timerUs;
} }
namespace RTC namespace RTC
{ {
/* Static. */ /* Static. */
// clang-format off // clang-format off
static constexpr int DtlsMtu{ 1350 }; static constexpr int DtlsMtu{ 1350 };
// AES-HMAC: http://tools.ietf.org/html/rfc3711 // AES-HMAC: http://tools.ietf.org/html/rfc3711
static constexpr size_t SrtpMasterKeyLength{ 16 }; static constexpr size_t SrtpMasterKeyLength{ 16 };
static constexpr size_t SrtpMasterSaltLength{ 14 }; static constexpr size_t SrtpMasterSaltLength{ 14 };
static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength };
// AES-GCM: http://tools.ietf.org/html/rfc7714 // AES-GCM: http://tools.ietf.org/html/rfc7714
static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 };
static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength };
static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 };
static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength };
// clang-format on // clang-format on
/* Class variables. */ /* Class variables. */
// clang-format off // clang-format off
std::map<std::string, DtlsTransport::FingerprintAlgorithm> DtlsTransport::string2FingerprintAlgorithm = std::map<std::string, DtlsTransport::FingerprintAlgorithm> DtlsTransport::string2FingerprintAlgorithm =
{ {
{ "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 },
{ "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 },
{ "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 },
{ "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 },
{ "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 }
}; };
std::map<DtlsTransport::FingerprintAlgorithm, std::string> DtlsTransport::fingerprintAlgorithm2String = std::map<DtlsTransport::FingerprintAlgorithm, std::string> DtlsTransport::fingerprintAlgorithm2String =
{ {
{ DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" },
{ DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" },
{ DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" },
{ DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" },
{ DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" }
}; };
std::map<std::string, DtlsTransport::Role> DtlsTransport::string2Role = std::map<std::string, DtlsTransport::Role> DtlsTransport::string2Role =
{ {
{ "auto", DtlsTransport::Role::AUTO }, { "auto", DtlsTransport::Role::AUTO },
{ "client", DtlsTransport::Role::CLIENT }, { "client", DtlsTransport::Role::CLIENT },
{ "server", DtlsTransport::Role::SERVER } { "server", DtlsTransport::Role::SERVER }
}; };
std::vector<DtlsTransport::SrtpCryptoSuiteMapEntry> DtlsTransport::srtpCryptoSuites = std::vector<DtlsTransport::SrtpCryptoSuiteMapEntry> DtlsTransport::srtpCryptoSuites =
{ {
{ RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" },
{ RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" },
{ RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" },
{ RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" }
}; };
// clang-format on // clang-format on
INSTANCE_IMP(DtlsTransport::DtlsEnvironment); INSTANCE_IMP(DtlsTransport::DtlsEnvironment);
/* Class methods. */ /* Class methods. */
DtlsTransport::DtlsEnvironment::DtlsEnvironment() DtlsTransport::DtlsEnvironment::DtlsEnvironment()
{ {
MS_TRACE(); MS_TRACE();
// Generate a X509 certificate and private key (unless PEM files are provided). // Generate a X509 certificate and private key (unless PEM files are provided).
if (true /* if (true /*
Settings::configuration.dtlsCertificateFile.empty() || Settings::configuration.dtlsCertificateFile.empty() ||
Settings::configuration.dtlsPrivateKeyFile.empty()*/) Settings::configuration.dtlsPrivateKeyFile.empty()*/)
{ {
GenerateCertificateAndPrivateKey(); GenerateCertificateAndPrivateKey();
} }
else else
{ {
ReadCertificateAndPrivateKeyFromFiles(); ReadCertificateAndPrivateKeyFromFiles();
} }
// Create a global SSL_CTX. // Create a global SSL_CTX.
CreateSslCtx(); CreateSslCtx();
// Generate certificate fingerprints. // Generate certificate fingerprints.
GenerateFingerprints(); GenerateFingerprints();
} }
DtlsTransport::DtlsEnvironment::~DtlsEnvironment() DtlsTransport::DtlsEnvironment::~DtlsEnvironment()
{ {
MS_TRACE(); MS_TRACE();
if (privateKey) if (privateKey)
EVP_PKEY_free(privateKey); EVP_PKEY_free(privateKey);
if (certificate) if (certificate)
X509_free(certificate); X509_free(certificate);
if (sslCtx) if (sslCtx)
SSL_CTX_free(sslCtx); SSL_CTX_free(sslCtx);
} }
void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey()
{ {
MS_TRACE(); MS_TRACE();
int ret{ 0 }; int ret{ 0 };
EC_KEY* ecKey{ nullptr }; EC_KEY* ecKey{ nullptr };
X509_NAME* certName{ nullptr }; X509_NAME* certName{ nullptr };
std::string subject = std::string subject =
std::string("mediasoup") + to_string(rand() % 999999 + 100000); std::string("mediasoup") + to_string(rand() % 999999 + 100000);
// Create key with curve. // Create key with curve.
ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
if (!ecKey) if (!ecKey)
{ {
LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed");
goto error; goto error;
} }
EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE);
// NOTE: This can take some time. // NOTE: This can take some time.
ret = EC_KEY_generate_key(ecKey); ret = EC_KEY_generate_key(ecKey);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed");
goto error; goto error;
} }
// Create a private key object. // Create a private key object.
privateKey = EVP_PKEY_new(); privateKey = EVP_PKEY_new();
if (!privateKey) if (!privateKey)
{ {
LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); LOG_OPENSSL_ERROR("EVP_PKEY_new() failed");
goto error; goto error;
} }
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed");
goto error; goto error;
} }
// The EC key now belongs to the private key, so don't clean it up separately. // The EC key now belongs to the private key, so don't clean it up separately.
ecKey = nullptr; ecKey = nullptr;
// Create the X509 certificate. // Create the X509 certificate.
certificate = X509_new(); certificate = X509_new();
if (!certificate) if (!certificate)
{ {
LOG_OPENSSL_ERROR("X509_new() failed"); LOG_OPENSSL_ERROR("X509_new() failed");
goto error; goto error;
} }
// Set version 3 (note that 0 means version 1). // Set version 3 (note that 0 means version 1).
X509_set_version(certificate, 2); X509_set_version(certificate, 2);
// Set serial number (avoid default 0). // Set serial number (avoid default 0).
ASN1_INTEGER_set( ASN1_INTEGER_set(
X509_get_serialNumber(certificate), X509_get_serialNumber(certificate),
static_cast<uint64_t>(rand() % 999999 + 100000)); static_cast<uint64_t>(rand() % 999999 + 100000));
// Set valid period. // Set valid period.
X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years.
X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years.
// Set the public key for the certificate using the key. // Set the public key for the certificate using the key.
ret = X509_set_pubkey(certificate, privateKey); ret = X509_set_pubkey(certificate, privateKey);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); LOG_OPENSSL_ERROR("X509_set_pubkey() failed");
goto error; goto error;
} }
// Set certificate fields. // Set certificate fields.
certName = X509_get_subject_name(certificate); certName = X509_get_subject_name(certificate);
if (!certName) if (!certName)
{ {
LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); LOG_OPENSSL_ERROR("X509_get_subject_name() failed");
goto error; goto error;
} }
X509_NAME_add_entry_by_txt( X509_NAME_add_entry_by_txt(
certName, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0); certName, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
X509_NAME_add_entry_by_txt( X509_NAME_add_entry_by_txt(
certName, "CN", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0); certName, "CN", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
// It is self-signed so set the issuer name to be the same as the subject. // It is self-signed so set the issuer name to be the same as the subject.
ret = X509_set_issuer_name(certificate, certName); ret = X509_set_issuer_name(certificate, certName);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); LOG_OPENSSL_ERROR("X509_set_issuer_name() failed");
goto error; goto error;
} }
// Sign the certificate with its own private key. // Sign the certificate with its own private key.
ret = X509_sign(certificate, privateKey, EVP_sha1()); ret = X509_sign(certificate, privateKey, EVP_sha1());
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("X509_sign() failed"); LOG_OPENSSL_ERROR("X509_sign() failed");
goto error; goto error;
} }
return; return;
error: error:
if (ecKey) if (ecKey)
EC_KEY_free(ecKey); EC_KEY_free(ecKey);
if (privateKey) if (privateKey)
EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key.
if (certificate) if (certificate)
X509_free(certificate); X509_free(certificate);
MS_THROW_ERROR("DTLS certificate and private key generation failed"); MS_THROW_ERROR("DTLS certificate and private key generation failed");
} }
void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles() void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles()
{ {
#if 0 #if 0
MS_TRACE(); MS_TRACE();
FILE* file{ nullptr }; FILE* file{ nullptr };
file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r"); file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r");
if (!file) if (!file)
{ {
MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno)); MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno));
goto error; goto error;
} }
certificate = PEM_read_X509(file, nullptr, nullptr, nullptr); certificate = PEM_read_X509(file, nullptr, nullptr, nullptr);
if (!certificate) if (!certificate)
{ {
LOG_OPENSSL_ERROR("PEM_read_X509() failed"); LOG_OPENSSL_ERROR("PEM_read_X509() failed");
goto error; goto error;
} }
fclose(file); fclose(file);
file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r"); file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r");
if (!file) if (!file)
{ {
MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno)); MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno));
goto error; goto error;
} }
privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr); privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr);
if (!privateKey) if (!privateKey)
{ {
LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed"); LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed");
goto error; goto error;
} }
fclose(file); fclose(file);
return; return;
error: error:
MS_THROW_ERROR("error reading DTLS certificate and private key PEM files"); MS_THROW_ERROR("error reading DTLS certificate and private key PEM files");
#endif #endif
} }
void DtlsTransport::DtlsEnvironment::CreateSslCtx() void DtlsTransport::DtlsEnvironment::CreateSslCtx()
{ {
MS_TRACE(); MS_TRACE();
std::string dtlsSrtpCryptoSuites; std::string dtlsSrtpCryptoSuites;
int ret; int ret;
/* Set the global DTLS context. */ /* Set the global DTLS context. */
// Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0).
sslCtx = SSL_CTX_new(DTLS_method()); sslCtx = SSL_CTX_new(DTLS_method());
if (!sslCtx) if (!sslCtx)
{ {
LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); LOG_OPENSSL_ERROR("SSL_CTX_new() failed");
goto error; goto error;
} }
ret = SSL_CTX_use_certificate(sslCtx, certificate); ret = SSL_CTX_use_certificate(sslCtx, certificate);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed");
goto error; goto error;
} }
ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed");
goto error; goto error;
} }
ret = SSL_CTX_check_private_key(sslCtx); ret = SSL_CTX_check_private_key(sslCtx);
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed");
goto error; goto error;
} }
// Set options. // Set options.
SSL_CTX_set_options( SSL_CTX_set_options(
sslCtx, sslCtx,
SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE |
SSL_OP_NO_QUERY_MTU); SSL_OP_NO_QUERY_MTU);
// Don't use sessions cache. // Don't use sessions cache.
SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF);
// Read always as much into the buffer as possible. // Read always as much into the buffer as possible.
// NOTE: This is the default for DTLS, but a bug in non latest OpenSSL // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL
// versions makes this call required. // versions makes this call required.
SSL_CTX_set_read_ahead(sslCtx, 1); SSL_CTX_set_read_ahead(sslCtx, 1);
SSL_CTX_set_verify_depth(sslCtx, 4); SSL_CTX_set_verify_depth(sslCtx, 4);
// Require certificate from peer. // Require certificate from peer.
SSL_CTX_set_verify( SSL_CTX_set_verify(
sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify);
// Set SSL info callback. // Set SSL info callback.
SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){
static_cast<RTC::DtlsTransport*>(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); static_cast<RTC::DtlsTransport*>(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret);
}); });
// Set ciphers. // Set ciphers.
ret = SSL_CTX_set_cipher_list( ret = SSL_CTX_set_cipher_list(
sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK");
if (ret == 0) if (ret == 0)
{ {
LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed");
goto error; goto error;
} }
// Enable ECDH ciphers. // Enable ECDH ciphers.
// DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters
// NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458
// NOTE: https://bugs.ruby-lang.org/issues/12324 // NOTE: https://bugs.ruby-lang.org/issues/12324
// For OpenSSL >= 1.0.2. // For OpenSSL >= 1.0.2.
SSL_CTX_set_ecdh_auto(sslCtx, 1); SSL_CTX_set_ecdh_auto(sslCtx, 1);
// Set the "use_srtp" DTLS extension. // Set the "use_srtp" DTLS extension.
for (auto it = DtlsTransport::srtpCryptoSuites.begin(); for (auto it = DtlsTransport::srtpCryptoSuites.begin();
it != DtlsTransport::srtpCryptoSuites.end(); it != DtlsTransport::srtpCryptoSuites.end();
++it) ++it)
{ {
if (it != DtlsTransport::srtpCryptoSuites.begin()) if (it != DtlsTransport::srtpCryptoSuites.begin())
dtlsSrtpCryptoSuites += ":"; dtlsSrtpCryptoSuites += ":";
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it);
dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; dtlsSrtpCryptoSuites += cryptoSuiteEntry->name;
} }
MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str());
// NOTE: This function returns 0 on success. // NOTE: This function returns 0 on success.
ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str());
if (ret != 0) if (ret != 0)
{ {
MS_ERROR( MS_ERROR(
"SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str());
LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed");
goto error; goto error;
} }
return; return;
error: error:
if (sslCtx) if (sslCtx)
{ {
SSL_CTX_free(sslCtx); SSL_CTX_free(sslCtx);
sslCtx = nullptr; sslCtx = nullptr;
} }
MS_THROW_ERROR("SSL context creation failed"); MS_THROW_ERROR("SSL context creation failed");
} }
void DtlsTransport::DtlsEnvironment::GenerateFingerprints() void DtlsTransport::DtlsEnvironment::GenerateFingerprints()
{ {
MS_TRACE(); MS_TRACE();
for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) for (auto& kv : DtlsTransport::string2FingerprintAlgorithm)
{ {
const std::string& algorithmString = kv.first; const std::string& algorithmString = kv.first;
FingerprintAlgorithm algorithm = kv.second; FingerprintAlgorithm algorithm = kv.second;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 }; unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction; const EVP_MD* hashFunction;
int ret; int ret;
switch (algorithm) switch (algorithm)
{ {
case FingerprintAlgorithm::SHA1: case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1(); hashFunction = EVP_sha1();
break; break;
case FingerprintAlgorithm::SHA224: case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224(); hashFunction = EVP_sha224();
break; break;
case FingerprintAlgorithm::SHA256: case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256(); hashFunction = EVP_sha256();
break; break;
case FingerprintAlgorithm::SHA384: case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384(); hashFunction = EVP_sha384();
break; break;
case FingerprintAlgorithm::SHA512: case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512(); hashFunction = EVP_sha512();
break; break;
default: default:
MS_THROW_ERROR("unknown algorithm"); MS_THROW_ERROR("unknown algorithm");
} }
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0) if (ret == 0)
{ {
MS_ERROR("X509_digest() failed"); MS_ERROR("X509_digest() failed");
MS_THROW_ERROR("Fingerprints generation failed"); MS_THROW_ERROR("Fingerprints generation failed");
} }
// Convert to hexadecimal format in uppercase with colons. // Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i) for (unsigned int i{ 0 }; i < size; ++i)
{ {
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
} }
hexFingerprint[(size * 3) - 1] = '\0'; hexFingerprint[(size * 3) - 1] = '\0';
MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint);
// Store it in the vector. // Store it in the vector.
DtlsTransport::Fingerprint fingerprint; DtlsTransport::Fingerprint fingerprint;
fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString);
fingerprint.value = hexFingerprint; fingerprint.value = hexFingerprint;
localFingerprints.push_back(fingerprint); localFingerprints.push_back(fingerprint);
} }
} }
/* Instance methods. */ /* Instance methods. */
DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener)
{ {
MS_TRACE(); MS_TRACE();
env = DtlsEnvironment::Instance().shared_from_this(); env = DtlsEnvironment::Instance().shared_from_this();
/* Set SSL. */ /* Set SSL. */
this->ssl = SSL_new(env->sslCtx); this->ssl = SSL_new(env->sslCtx);
if (!this->ssl) if (!this->ssl)
{ {
LOG_OPENSSL_ERROR("SSL_new() failed"); LOG_OPENSSL_ERROR("SSL_new() failed");
goto error; goto error;
} }
// Set this as custom data. // Set this as custom data.
SSL_set_ex_data(this->ssl, 0, static_cast<void*>(this)); SSL_set_ex_data(this->ssl, 0, static_cast<void*>(this));
this->sslBioFromNetwork = BIO_new(BIO_s_mem()); this->sslBioFromNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioFromNetwork) if (!this->sslBioFromNetwork)
{ {
LOG_OPENSSL_ERROR("BIO_new() failed"); LOG_OPENSSL_ERROR("BIO_new() failed");
SSL_free(this->ssl); SSL_free(this->ssl);
goto error; goto error;
} }
this->sslBioToNetwork = BIO_new(BIO_s_mem()); this->sslBioToNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioToNetwork) if (!this->sslBioToNetwork)
{ {
LOG_OPENSSL_ERROR("BIO_new() failed"); LOG_OPENSSL_ERROR("BIO_new() failed");
BIO_free(this->sslBioFromNetwork); BIO_free(this->sslBioFromNetwork);
SSL_free(this->ssl); SSL_free(this->ssl);
goto error; goto error;
} }
SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork);
// Set the MTU so that we don't send packets that are too large with no fragmentation. // Set the MTU so that we don't send packets that are too large with no fragmentation.
SSL_set_mtu(this->ssl, DtlsMtu); SSL_set_mtu(this->ssl, DtlsMtu);
DTLS_set_link_mtu(this->ssl, DtlsMtu); DTLS_set_link_mtu(this->ssl, DtlsMtu);
// Set callback handler for setting DTLS timer interval. // Set callback handler for setting DTLS timer interval.
DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); DTLS_set_timer_cb(this->ssl, onSslDtlsTimer);
return; return;
error: error:
// NOTE: At this point SSL_set_bio() was not called so we must free BIOs as // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as
// well. // well.
if (this->sslBioFromNetwork) if (this->sslBioFromNetwork)
BIO_free(this->sslBioFromNetwork); BIO_free(this->sslBioFromNetwork);
if (this->sslBioToNetwork) if (this->sslBioToNetwork)
BIO_free(this->sslBioToNetwork); BIO_free(this->sslBioToNetwork);
if (this->ssl) if (this->ssl)
SSL_free(this->ssl); SSL_free(this->ssl);
// NOTE: If this is not catched by the caller the program will abort, but // NOTE: If this is not catched by the caller the program will abort, but
// this should never happen. // this should never happen.
MS_THROW_ERROR("DtlsTransport instance creation failed"); MS_THROW_ERROR("DtlsTransport instance creation failed");
} }
DtlsTransport::~DtlsTransport() DtlsTransport::~DtlsTransport()
{ {
MS_TRACE(); MS_TRACE();
if (IsRunning()) if (IsRunning())
{ {
// Send close alert to the peer. // Send close alert to the peer.
SSL_shutdown(this->ssl); SSL_shutdown(this->ssl);
SendPendingOutgoingDtlsData(); SendPendingOutgoingDtlsData();
} }
if (this->ssl) if (this->ssl)
{ {
SSL_free(this->ssl); SSL_free(this->ssl);
this->ssl = nullptr; this->ssl = nullptr;
this->sslBioFromNetwork = nullptr; this->sslBioFromNetwork = nullptr;
this->sslBioToNetwork = nullptr; this->sslBioToNetwork = nullptr;
} }
// Close the DTLS timer. // Close the DTLS timer.
this->timer = nullptr; this->timer = nullptr;
} }
void DtlsTransport::Dump() const void DtlsTransport::Dump() const
{ {
MS_TRACE(); MS_TRACE();
std::string state{ "new" }; std::string state{ "new" };
std::string role{ "none " }; std::string role{ "none " };
switch (this->state) switch (this->state)
{ {
case DtlsState::CONNECTING: case DtlsState::CONNECTING:
state = "connecting"; state = "connecting";
break; break;
case DtlsState::CONNECTED: case DtlsState::CONNECTED:
state = "connected"; state = "connected";
break; break;
case DtlsState::FAILED: case DtlsState::FAILED:
state = "failed"; state = "failed";
break; break;
case DtlsState::CLOSED: case DtlsState::CLOSED:
state = "closed"; state = "closed";
break; break;
default:; default:;
} }
switch (this->localRole) switch (this->localRole)
{ {
case Role::AUTO: case Role::AUTO:
role = "auto"; role = "auto";
break; break;
case Role::SERVER: case Role::SERVER:
role = "server"; role = "server";
break; break;
case Role::CLIENT: case Role::CLIENT:
role = "client"; role = "client";
break; break;
default:; default:;
} }
MS_DUMP("<DtlsTransport>"); MS_DUMP("<DtlsTransport>");
MS_DUMP(" state : %s", state.c_str()); MS_DUMP(" state : %s", state.c_str());
MS_DUMP(" role : %s", role.c_str()); MS_DUMP(" role : %s", role.c_str());
MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no");
MS_DUMP("</DtlsTransport>"); MS_DUMP("</DtlsTransport>");
} }
void DtlsTransport::Run(Role localRole) void DtlsTransport::Run(Role localRole)
{ {
MS_TRACE(); MS_TRACE();
MS_ASSERT( MS_ASSERT(
localRole == Role::CLIENT || localRole == Role::SERVER, localRole == Role::CLIENT || localRole == Role::SERVER,
"local DTLS role must be 'client' or 'server'"); "local DTLS role must be 'client' or 'server'");
Role previousLocalRole = this->localRole; Role previousLocalRole = this->localRole;
if (localRole == previousLocalRole) if (localRole == previousLocalRole)
{ {
MS_ERROR("same local DTLS role provided, doing nothing"); MS_ERROR("same local DTLS role provided, doing nothing");
return; return;
} }
// If the previous local DTLS role was 'client' or 'server' do reset. // If the previous local DTLS role was 'client' or 'server' do reset.
if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER)
{ {
MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change");
Reset(); Reset();
} }
// Update local role. // Update local role.
this->localRole = localRole; this->localRole = localRole;
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::CONNECTING; this->state = DtlsState::CONNECTING;
this->listener->OnDtlsTransportConnecting(this); this->listener->OnDtlsTransportConnecting(this);
switch (this->localRole) switch (this->localRole)
{ {
case Role::CLIENT: case Role::CLIENT:
{ {
MS_DEBUG_TAG(dtls, "running [role:client]"); MS_DEBUG_TAG(dtls, "running [role:client]");
SSL_set_connect_state(this->ssl); SSL_set_connect_state(this->ssl);
SSL_do_handshake(this->ssl); SSL_do_handshake(this->ssl);
SendPendingOutgoingDtlsData(); SendPendingOutgoingDtlsData();
SetTimeout(); SetTimeout();
break; break;
} }
case Role::SERVER: case Role::SERVER:
{ {
MS_DEBUG_TAG(dtls, "running [role:server]"); MS_DEBUG_TAG(dtls, "running [role:server]");
SSL_set_accept_state(this->ssl); SSL_set_accept_state(this->ssl);
SSL_do_handshake(this->ssl); SSL_do_handshake(this->ssl);
break; break;
} }
default: default:
{ {
MS_ABORT("invalid local DTLS role"); MS_ABORT("invalid local DTLS role");
} }
} }
} }
bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint)
{ {
MS_TRACE(); MS_TRACE();
MS_ASSERT( MS_ASSERT(
fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided");
this->remoteFingerprint = fingerprint; this->remoteFingerprint = fingerprint;
// The remote fingerpring may have been set after DTLS handshake was done, // The remote fingerpring may have been set after DTLS handshake was done,
// so we may need to process it now. // so we may need to process it now.
if (this->handshakeDone && this->state != DtlsState::CONNECTED) if (this->handshakeDone && this->state != DtlsState::CONNECTED)
{ {
MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); MS_DEBUG_TAG(dtls, "handshake already done, processing it right now");
return ProcessHandshake(); return ProcessHandshake();
} }
return true; return true;
} }
void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len)
{ {
MS_TRACE(); MS_TRACE();
int written; int written;
int read; int read;
if (!IsRunning()) if (!IsRunning())
{ {
MS_ERROR("cannot process data while not running"); MS_ERROR("cannot process data while not running");
return; return;
} }
// Write the received DTLS data into the sslBioFromNetwork. // Write the received DTLS data into the sslBioFromNetwork.
written = written =
BIO_write(this->sslBioFromNetwork, static_cast<const void*>(data), static_cast<int>(len)); BIO_write(this->sslBioFromNetwork, static_cast<const void*>(data), static_cast<int>(len));
if (written != static_cast<int>(len)) if (written != static_cast<int>(len))
{ {
MS_WARN_TAG( MS_WARN_TAG(
dtls, dtls,
"OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)",
static_cast<size_t>(written), static_cast<size_t>(written),
len); len);
} }
// Must call SSL_read() to process received DTLS data. // Must call SSL_read() to process received DTLS data.
read = SSL_read(this->ssl, static_cast<void*>(DtlsTransport::sslReadBuffer), SslReadBufferSize); read = SSL_read(this->ssl, static_cast<void*>(DtlsTransport::sslReadBuffer), SslReadBufferSize);
// Send data if it's ready. // Send data if it's ready.
SendPendingOutgoingDtlsData(); SendPendingOutgoingDtlsData();
// Check SSL status and return if it is bad/closed. // Check SSL status and return if it is bad/closed.
if (!CheckStatus(read)) if (!CheckStatus(read))
return; return;
// Set/update the DTLS timeout. // Set/update the DTLS timeout.
if (!SetTimeout()) if (!SetTimeout())
return; return;
// Application data received. Notify to the listener. // Application data received. Notify to the listener.
if (read > 0) if (read > 0)
{ {
// It is allowed to receive DTLS data even before validating remote fingerprint. // It is allowed to receive DTLS data even before validating remote fingerprint.
if (!this->handshakeDone) if (!this->handshakeDone)
{ {
MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done");
return; return;
} }
// Notify the listener. // Notify the listener.
this->listener->OnDtlsTransportApplicationDataReceived( this->listener->OnDtlsTransportApplicationDataReceived(
this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast<size_t>(read)); this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast<size_t>(read));
} }
} }
void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len)
{ {
MS_TRACE(); MS_TRACE();
// We cannot send data to the peer if its remote fingerprint is not validated. // We cannot send data to the peer if its remote fingerprint is not validated.
if (this->state != DtlsState::CONNECTED) if (this->state != DtlsState::CONNECTED)
{ {
MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected");
return; return;
} }
if (len == 0) if (len == 0)
{ {
MS_WARN_TAG(dtls, "ignoring 0 length data"); MS_WARN_TAG(dtls, "ignoring 0 length data");
return; return;
} }
int written; int written;
written = SSL_write(this->ssl, static_cast<const void*>(data), static_cast<int>(len)); written = SSL_write(this->ssl, static_cast<const void*>(data), static_cast<int>(len));
if (written < 0) if (written < 0)
{ {
LOG_OPENSSL_ERROR("SSL_write() failed"); LOG_OPENSSL_ERROR("SSL_write() failed");
if (!CheckStatus(written)) if (!CheckStatus(written))
return; return;
} }
else if (written != static_cast<int>(len)) else if (written != static_cast<int>(len))
{ {
MS_WARN_TAG( MS_WARN_TAG(
dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len);
} }
// Send data. // Send data.
SendPendingOutgoingDtlsData(); SendPendingOutgoingDtlsData();
} }
void DtlsTransport::Reset() void DtlsTransport::Reset()
{ {
MS_TRACE(); MS_TRACE();
int ret; int ret;
if (!IsRunning()) if (!IsRunning())
return; return;
MS_WARN_TAG(dtls, "resetting DTLS transport"); MS_WARN_TAG(dtls, "resetting DTLS transport");
// Stop the DTLS timer. // Stop the DTLS timer.
this->timer = nullptr; this->timer = nullptr;
// We need to reset the SSL instance so we need to "shutdown" it, but we // We need to reset the SSL instance so we need to "shutdown" it, but we
// don't want to send a Close Alert to the peer, so just don't call // don't want to send a Close Alert to the peer, so just don't call
// SendPendingOutgoingDTLSData(). // SendPendingOutgoingDTLSData().
SSL_shutdown(this->ssl); SSL_shutdown(this->ssl);
this->localRole = Role::NONE; this->localRole = Role::NONE;
this->state = DtlsState::NEW; this->state = DtlsState::NEW;
this->handshakeDone = false; this->handshakeDone = false;
this->handshakeDoneNow = false; this->handshakeDoneNow = false;
// Reset SSL status. // Reset SSL status.
// NOTE: For this to properly work, SSL_shutdown() must be called before. // NOTE: For this to properly work, SSL_shutdown() must be called before.
// NOTE: This may fail if not enough DTLS handshake data has been received, // NOTE: This may fail if not enough DTLS handshake data has been received,
// but we don't care so just clear the error queue. // but we don't care so just clear the error queue.
ret = SSL_clear(this->ssl); ret = SSL_clear(this->ssl);
if (ret == 0) if (ret == 0)
ERR_clear_error(); ERR_clear_error();
} }
inline bool DtlsTransport::CheckStatus(int returnCode) inline bool DtlsTransport::CheckStatus(int returnCode)
{ {
MS_TRACE(); MS_TRACE();
int err; int err;
bool wasHandshakeDone = this->handshakeDone; bool wasHandshakeDone = this->handshakeDone;
err = SSL_get_error(this->ssl, returnCode); err = SSL_get_error(this->ssl, returnCode);
switch (err) switch (err)
{ {
case SSL_ERROR_NONE: case SSL_ERROR_NONE:
break; break;
case SSL_ERROR_SSL: case SSL_ERROR_SSL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL");
break; break;
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
break; break;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE");
break; break;
case SSL_ERROR_WANT_X509_LOOKUP: case SSL_ERROR_WANT_X509_LOOKUP:
MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP");
break; break;
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL");
break; break;
case SSL_ERROR_ZERO_RETURN: case SSL_ERROR_ZERO_RETURN:
break; break;
case SSL_ERROR_WANT_CONNECT: case SSL_ERROR_WANT_CONNECT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT");
break; break;
case SSL_ERROR_WANT_ACCEPT: case SSL_ERROR_WANT_ACCEPT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT");
break; break;
default: default:
MS_WARN_TAG(dtls, "SSL status: unknown error"); MS_WARN_TAG(dtls, "SSL status: unknown error");
} }
// Check if the handshake (or re-handshake) has been done right now. // Check if the handshake (or re-handshake) has been done right now.
if (this->handshakeDoneNow) if (this->handshakeDoneNow)
{ {
this->handshakeDoneNow = false; this->handshakeDoneNow = false;
this->handshakeDone = true; this->handshakeDone = true;
// Stop the timer. // Stop the timer.
this->timer = nullptr; this->timer = nullptr;
// Process the handshake just once (ignore if DTLS renegotiation). // Process the handshake just once (ignore if DTLS renegotiation).
if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE)
return ProcessHandshake(); return ProcessHandshake();
return true; return true;
} }
// Check if the peer sent close alert or a fatal error happened. // Check if the peer sent close alert or a fatal error happened.
else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL)
{ {
if (this->state == DtlsState::CONNECTED) if (this->state == DtlsState::CONNECTED)
{ {
MS_DEBUG_TAG(dtls, "disconnected"); MS_DEBUG_TAG(dtls, "disconnected");
Reset(); Reset();
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::CLOSED; this->state = DtlsState::CLOSED;
this->listener->OnDtlsTransportClosed(this); this->listener->OnDtlsTransportClosed(this);
} }
else else
{ {
MS_WARN_TAG(dtls, "connection failed"); MS_WARN_TAG(dtls, "connection failed");
Reset(); Reset();
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::FAILED; this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this); this->listener->OnDtlsTransportFailed(this);
} }
return false; return false;
} }
else else
{ {
return true; return true;
} }
} }
inline void DtlsTransport::SendPendingOutgoingDtlsData() inline void DtlsTransport::SendPendingOutgoingDtlsData()
{ {
MS_TRACE(); MS_TRACE();
if (BIO_eof(this->sslBioToNetwork)) if (BIO_eof(this->sslBioToNetwork))
return; return;
int64_t read; int64_t read;
char* data{ nullptr }; char* data{ nullptr };
read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT
if (read <= 0) if (read <= 0)
return; return;
MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read);
// Notify the listener. // Notify the listener.
this->listener->OnDtlsTransportSendData( this->listener->OnDtlsTransportSendData(
this, reinterpret_cast<uint8_t*>(data), static_cast<size_t>(read)); this, reinterpret_cast<uint8_t*>(data), static_cast<size_t>(read));
// Clear the BIO buffer. // Clear the BIO buffer.
// NOTE: the (void) avoids the -Wunused-value warning. // NOTE: the (void) avoids the -Wunused-value warning.
(void)BIO_reset(this->sslBioToNetwork); (void)BIO_reset(this->sslBioToNetwork);
} }
inline bool DtlsTransport::SetTimeout() inline bool DtlsTransport::SetTimeout()
{ {
MS_TRACE(); MS_TRACE();
MS_ASSERT( MS_ASSERT(
this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED,
"invalid DTLS state"); "invalid DTLS state");
int64_t ret; int64_t ret;
struct timeval dtlsTimeout{ 0, 0 }; struct timeval dtlsTimeout{ 0, 0 };
uint64_t timeoutMs; uint64_t timeoutMs;
// NOTE: If ret == 0 then ignore the value in dtlsTimeout. // NOTE: If ret == 0 then ignore the value in dtlsTimeout.
// NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev.
ret = DTLSv1_get_timeout(this->ssl, static_cast<void*>(&dtlsTimeout)); // NOLINT ret = DTLSv1_get_timeout(this->ssl, static_cast<void*>(&dtlsTimeout)); // NOLINT
if (ret == 0) if (ret == 0)
return true; return true;
timeoutMs = (dtlsTimeout.tv_sec * static_cast<uint64_t>(1000)) + (dtlsTimeout.tv_usec / 1000); timeoutMs = (dtlsTimeout.tv_sec * static_cast<uint64_t>(1000)) + (dtlsTimeout.tv_usec / 1000);
if (timeoutMs == 0) if (timeoutMs == 0)
{ {
return true; return true;
} }
else if (timeoutMs < 30000) else if (timeoutMs < 30000)
{ {
MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs);
weak_ptr<DtlsTransport> weak_self = shared_from_this(); weak_ptr<DtlsTransport> weak_self = shared_from_this();
this->timer = std::make_shared<Timer>(timeoutMs / 1000.0f, [weak_self](){ this->timer = std::make_shared<Timer>(timeoutMs / 1000.0f, [weak_self](){
auto strong_self = weak_self.lock(); auto strong_self = weak_self.lock();
if(strong_self){ if(strong_self){
strong_self->OnTimer(); strong_self->OnTimer();
} }
return true; return true;
}, this->poller); }, this->poller);
return true; return true;
} }
// NOTE: Don't start the timer again if the timeout is greater than 30 seconds. // NOTE: Don't start the timer again if the timeout is greater than 30 seconds.
else else
{ {
MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs);
Reset(); Reset();
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::FAILED; this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this); this->listener->OnDtlsTransportFailed(this);
return false; return false;
} }
} }
inline bool DtlsTransport::ProcessHandshake() inline bool DtlsTransport::ProcessHandshake()
{ {
MS_TRACE(); MS_TRACE();
MS_ASSERT(this->handshakeDone, "handshake not done yet"); MS_ASSERT(this->handshakeDone, "handshake not done yet");
MS_ASSERT( MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
// Validate the remote fingerprint. // Validate the remote fingerprint.
if (!CheckRemoteFingerprint()) if (!CheckRemoteFingerprint())
{ {
Reset(); Reset();
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::FAILED; this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this); this->listener->OnDtlsTransportFailed(this);
return false; return false;
} }
// Get the negotiated SRTP crypto suite. // Get the negotiated SRTP crypto suite.
RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite();
if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE)
{ {
// Extract the SRTP keys (will notify the listener with them). // Extract the SRTP keys (will notify the listener with them).
ExtractSrtpKeys(srtpCryptoSuite); ExtractSrtpKeys(srtpCryptoSuite);
return true; return true;
} }
// NOTE: We assume that "use_srtp" DTLS extension is required even if // NOTE: We assume that "use_srtp" DTLS extension is required even if
// there is no audio/video. // there is no audio/video.
MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated");
Reset(); Reset();
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::FAILED; this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this); this->listener->OnDtlsTransportFailed(this);
return false; return false;
} }
inline bool DtlsTransport::CheckRemoteFingerprint() inline bool DtlsTransport::CheckRemoteFingerprint()
{ {
MS_TRACE(); MS_TRACE();
MS_ASSERT( MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
X509* certificate; X509* certificate;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 }; unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction; const EVP_MD* hashFunction;
int ret; int ret;
certificate = SSL_get_peer_certificate(this->ssl); certificate = SSL_get_peer_certificate(this->ssl);
if (!certificate) if (!certificate)
{ {
MS_WARN_TAG(dtls, "no certificate was provided by the peer"); MS_WARN_TAG(dtls, "no certificate was provided by the peer");
return false; return false;
} }
switch (this->remoteFingerprint.algorithm) switch (this->remoteFingerprint.algorithm)
{ {
case FingerprintAlgorithm::SHA1: case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1(); hashFunction = EVP_sha1();
break; break;
case FingerprintAlgorithm::SHA224: case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224(); hashFunction = EVP_sha224();
break; break;
case FingerprintAlgorithm::SHA256: case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256(); hashFunction = EVP_sha256();
break; break;
case FingerprintAlgorithm::SHA384: case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384(); hashFunction = EVP_sha384();
break; break;
case FingerprintAlgorithm::SHA512: case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512(); hashFunction = EVP_sha512();
break; break;
default: default:
MS_ABORT("unknown algorithm"); MS_ABORT("unknown algorithm");
} }
// Compare the remote fingerprint with the value given via signaling. // Compare the remote fingerprint with the value given via signaling.
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0) if (ret == 0)
{ {
MS_ERROR("X509_digest() failed"); MS_ERROR("X509_digest() failed");
X509_free(certificate); X509_free(certificate);
return false; return false;
} }
// Convert to hexadecimal format in uppercase with colons. // Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i) for (unsigned int i{ 0 }; i < size; ++i)
{ {
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
} }
hexFingerprint[(size * 3) - 1] = '\0'; hexFingerprint[(size * 3) - 1] = '\0';
if (this->remoteFingerprint.value != hexFingerprint) if (this->remoteFingerprint.value != hexFingerprint)
{ {
MS_WARN_TAG( MS_WARN_TAG(
dtls, dtls,
"fingerprint in the remote certificate (%s) does not match the announced one (%s)", "fingerprint in the remote certificate (%s) does not match the announced one (%s)",
hexFingerprint, hexFingerprint,
this->remoteFingerprint.value.c_str()); this->remoteFingerprint.value.c_str());
X509_free(certificate); X509_free(certificate);
return false; return false;
} }
MS_DEBUG_TAG(dtls, "valid remote fingerprint"); MS_DEBUG_TAG(dtls, "valid remote fingerprint");
// Get the remote certificate in PEM format. // Get the remote certificate in PEM format.
BIO* bio = BIO_new(BIO_s_mem()); BIO* bio = BIO_new(BIO_s_mem());
// Ensure the underlying BUF_MEM structure is also freed. // Ensure the underlying BUF_MEM structure is also freed.
// NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since
// BIO_set_close() always returns 1. // BIO_set_close() always returns 1.
(void)BIO_set_close(bio, BIO_CLOSE); (void)BIO_set_close(bio, BIO_CLOSE);
ret = PEM_write_bio_X509(bio, certificate); ret = PEM_write_bio_X509(bio, certificate);
if (ret != 1) if (ret != 1)
{ {
LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed");
X509_free(certificate); X509_free(certificate);
BIO_free(bio); BIO_free(bio);
return false; return false;
} }
BUF_MEM* mem; BUF_MEM* mem;
BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast]
if (!mem || !mem->data || mem->length == 0u) if (!mem || !mem->data || mem->length == 0u)
{ {
LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed");
X509_free(certificate); X509_free(certificate);
BIO_free(bio); BIO_free(bio);
return false; return false;
} }
this->remoteCert = std::string(mem->data, mem->length); this->remoteCert = std::string(mem->data, mem->length);
X509_free(certificate); X509_free(certificate);
BIO_free(bio); BIO_free(bio);
return true; return true;
} }
inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite)
{ {
MS_TRACE(); MS_TRACE();
size_t srtpKeyLength{ 0 }; size_t srtpKeyLength{ 0 };
size_t srtpSaltLength{ 0 }; size_t srtpSaltLength{ 0 };
size_t srtpMasterLength{ 0 }; size_t srtpMasterLength{ 0 };
switch (srtpCryptoSuite) switch (srtpCryptoSuite)
{ {
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80:
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32:
{ {
srtpKeyLength = SrtpMasterKeyLength; srtpKeyLength = SrtpMasterKeyLength;
srtpSaltLength = SrtpMasterSaltLength; srtpSaltLength = SrtpMasterSaltLength;
srtpMasterLength = SrtpMasterLength; srtpMasterLength = SrtpMasterLength;
break; break;
} }
case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM:
{ {
srtpKeyLength = SrtpAesGcm256MasterKeyLength; srtpKeyLength = SrtpAesGcm256MasterKeyLength;
srtpSaltLength = SrtpAesGcm256MasterSaltLength; srtpSaltLength = SrtpAesGcm256MasterSaltLength;
srtpMasterLength = SrtpAesGcm256MasterLength; srtpMasterLength = SrtpAesGcm256MasterLength;
break; break;
} }
case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM:
{ {
srtpKeyLength = SrtpAesGcm128MasterKeyLength; srtpKeyLength = SrtpAesGcm128MasterKeyLength;
srtpSaltLength = SrtpAesGcm128MasterSaltLength; srtpSaltLength = SrtpAesGcm128MasterSaltLength;
srtpMasterLength = SrtpAesGcm128MasterLength; srtpMasterLength = SrtpAesGcm128MasterLength;
break; break;
} }
default: default:
{ {
MS_ABORT("unknown SRTP crypto suite"); MS_ABORT("unknown SRTP crypto suite");
} }
} }
auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; auto* srtpMaterial = new uint8_t[srtpMasterLength * 2];
uint8_t* srtpLocalKey{ nullptr }; uint8_t* srtpLocalKey{ nullptr };
uint8_t* srtpLocalSalt{ nullptr }; uint8_t* srtpLocalSalt{ nullptr };
uint8_t* srtpRemoteKey{ nullptr }; uint8_t* srtpRemoteKey{ nullptr };
uint8_t* srtpRemoteSalt{ nullptr }; uint8_t* srtpRemoteSalt{ nullptr };
auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength];
auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength];
int ret; int ret;
ret = SSL_export_keying_material( ret = SSL_export_keying_material(
this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0);
MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); MS_ASSERT(ret != 0, "SSL_export_keying_material() failed");
switch (this->localRole) switch (this->localRole)
{ {
case Role::SERVER: case Role::SERVER:
{ {
srtpRemoteKey = srtpMaterial; srtpRemoteKey = srtpMaterial;
srtpLocalKey = srtpRemoteKey + srtpKeyLength; srtpLocalKey = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalKey + srtpKeyLength; srtpRemoteSalt = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; srtpLocalSalt = srtpRemoteSalt + srtpSaltLength;
break; break;
} }
case Role::CLIENT: case Role::CLIENT:
{ {
srtpLocalKey = srtpMaterial; srtpLocalKey = srtpMaterial;
srtpRemoteKey = srtpLocalKey + srtpKeyLength; srtpRemoteKey = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteKey + srtpKeyLength; srtpLocalSalt = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; srtpRemoteSalt = srtpLocalSalt + srtpSaltLength;
break; break;
} }
default: default:
{ {
MS_ABORT("no DTLS role set"); MS_ABORT("no DTLS role set");
} }
} }
// Create the SRTP local master key. // Create the SRTP local master key.
std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength);
std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength);
// Create the SRTP remote master key. // Create the SRTP remote master key.
std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength);
std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength);
// Set state and notify the listener. // Set state and notify the listener.
this->state = DtlsState::CONNECTED; this->state = DtlsState::CONNECTED;
this->listener->OnDtlsTransportConnected( this->listener->OnDtlsTransportConnected(
this, this,
srtpCryptoSuite, srtpCryptoSuite,
srtpLocalMasterKey, srtpLocalMasterKey,
srtpMasterLength, srtpMasterLength,
srtpRemoteMasterKey, srtpRemoteMasterKey,
srtpMasterLength, srtpMasterLength,
this->remoteCert); this->remoteCert);
delete[] srtpMaterial; delete[] srtpMaterial;
delete[] srtpLocalMasterKey; delete[] srtpLocalMasterKey;
delete[] srtpRemoteMasterKey; delete[] srtpRemoteMasterKey;
} }
inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite()
{ {
MS_TRACE(); MS_TRACE();
RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE;
// Ensure that the SRTP crypto suite has been negotiated. // Ensure that the SRTP crypto suite has been negotiated.
// NOTE: This is a OpenSSL type. // NOTE: This is a OpenSSL type.
SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl);
if (!sslSrtpCryptoSuite) if (!sslSrtpCryptoSuite)
return negotiatedSrtpCryptoSuite; return negotiatedSrtpCryptoSuite;
// Get the negotiated SRTP crypto suite. // Get the negotiated SRTP crypto suite.
for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites)
{ {
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite);
if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0)
{ {
MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name);
negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite;
} }
} }
MS_ASSERT( MS_ASSERT(
negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE,
"chosen SRTP crypto suite is not an available one"); "chosen SRTP crypto suite is not an available one");
return negotiatedSrtpCryptoSuite; return negotiatedSrtpCryptoSuite;
} }
inline void DtlsTransport::OnSslInfo(int where, int ret) inline void DtlsTransport::OnSslInfo(int where, int ret)
{ {
MS_TRACE(); MS_TRACE();
int w = where & -SSL_ST_MASK; int w = where & -SSL_ST_MASK;
const char* role; const char* role;
if ((w & SSL_ST_CONNECT) != 0) if ((w & SSL_ST_CONNECT) != 0)
role = "client"; role = "client";
else if ((w & SSL_ST_ACCEPT) != 0) else if ((w & SSL_ST_ACCEPT) != 0)
role = "server"; role = "server";
else else
role = "undefined"; role = "undefined";
if ((where & SSL_CB_LOOP) != 0) if ((where & SSL_CB_LOOP) != 0)
{ {
MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl));
} }
else if ((where & SSL_CB_ALERT) != 0) else if ((where & SSL_CB_ALERT) != 0)
{ {
const char* alertType; const char* alertType;
switch (*SSL_alert_type_string(ret)) switch (*SSL_alert_type_string(ret))
{ {
case 'W': case 'W':
alertType = "warning"; alertType = "warning";
break; break;
case 'F': case 'F':
alertType = "fatal"; alertType = "fatal";
break; break;
default: default:
alertType = "undefined"; alertType = "undefined";
} }
if ((where & SSL_CB_READ) != 0) if ((where & SSL_CB_READ) != 0)
{ {
MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
} }
else if ((where & SSL_CB_WRITE) != 0) else if ((where & SSL_CB_WRITE) != 0)
{ {
MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
} }
else else
{ {
MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
} }
} }
else if ((where & SSL_CB_EXIT) != 0) else if ((where & SSL_CB_EXIT) != 0)
{ {
if (ret == 0) if (ret == 0)
MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl));
else if (ret < 0) else if (ret < 0)
MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl));
} }
else if ((where & SSL_CB_HANDSHAKE_START) != 0) else if ((where & SSL_CB_HANDSHAKE_START) != 0)
{ {
MS_DEBUG_TAG(dtls, "DTLS handshake start"); MS_DEBUG_TAG(dtls, "DTLS handshake start");
} }
else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) else if ((where & SSL_CB_HANDSHAKE_DONE) != 0)
{ {
MS_DEBUG_TAG(dtls, "DTLS handshake done"); MS_DEBUG_TAG(dtls, "DTLS handshake done");
this->handshakeDoneNow = true; this->handshakeDoneNow = true;
} }
// NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon
// receipt of a close alert does not work (the flag is set after this callback). // receipt of a close alert does not work (the flag is set after this callback).
} }
inline void DtlsTransport::OnTimer() inline void DtlsTransport::OnTimer()
{ {
MS_TRACE(); MS_TRACE();
// Workaround for https://github.com/openssl/openssl/issues/7998. // Workaround for https://github.com/openssl/openssl/issues/7998.
if (this->handshakeDone) if (this->handshakeDone)
{ {
MS_DEBUG_DEV("handshake is done so return"); MS_DEBUG_DEV("handshake is done so return");
return; return;
} }
DTLSv1_handle_timeout(this->ssl); DTLSv1_handle_timeout(this->ssl);
// If required, send DTLS data. // If required, send DTLS data.
SendPendingOutgoingDtlsData(); SendPendingOutgoingDtlsData();
// Set the DTLS timer again. // Set the DTLS timer again.
SetTimeout(); SetTimeout();
} }
} // namespace RTC } // namespace RTC
...@@ -33,50 +33,50 @@ using namespace toolkit; ...@@ -33,50 +33,50 @@ using namespace toolkit;
namespace RTC namespace RTC
{ {
class DtlsTransport : public std::enable_shared_from_this<DtlsTransport> class DtlsTransport : public std::enable_shared_from_this<DtlsTransport>
{ {
public: public:
enum class DtlsState enum class DtlsState
{ {
NEW = 1, NEW = 1,
CONNECTING, CONNECTING,
CONNECTED, CONNECTED,
FAILED, FAILED,
CLOSED CLOSED
}; };
public: public:
enum class Role enum class Role
{ {
NONE = 0, NONE = 0,
AUTO = 1, AUTO = 1,
CLIENT, CLIENT,
SERVER SERVER
}; };
public: public:
enum class FingerprintAlgorithm enum class FingerprintAlgorithm
{ {
NONE = 0, NONE = 0,
SHA1 = 1, SHA1 = 1,
SHA224, SHA224,
SHA256, SHA256,
SHA384, SHA384,
SHA512 SHA512
}; };
public: public:
struct Fingerprint struct Fingerprint
{ {
FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE };
std::string value; std::string value;
}; };
private: private:
struct SrtpCryptoSuiteMapEntry struct SrtpCryptoSuiteMapEntry
{ {
RTC::SrtpSession::CryptoSuite cryptoSuite; RTC::SrtpSession::CryptoSuite cryptoSuite;
const char* name; const char* name;
}; };
class DtlsEnvironment : public std::enable_shared_from_this<DtlsEnvironment> class DtlsEnvironment : public std::enable_shared_from_this<DtlsEnvironment>
{ {
...@@ -99,154 +99,154 @@ namespace RTC ...@@ -99,154 +99,154 @@ namespace RTC
std::vector<Fingerprint> localFingerprints; std::vector<Fingerprint> localFingerprints;
}; };
public: public:
class Listener class Listener
{ {
public: public:
// DTLS is in the process of negotiating a secure connection. Incoming // DTLS is in the process of negotiating a secure connection. Incoming
// media can flow through. // media can flow through.
// NOTE: The caller MUST NOT call any method during this callback. // NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0;
// DTLS has completed negotiation of a secure connection (including DTLS-SRTP // DTLS has completed negotiation of a secure connection (including DTLS-SRTP
// and remote fingerprint verification). Outgoing media can now flow through. // and remote fingerprint verification). Outgoing media can now flow through.
// NOTE: The caller MUST NOT call any method during this callback. // NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnected( virtual void OnDtlsTransportConnected(
const RTC::DtlsTransport* dtlsTransport, const RTC::DtlsTransport* dtlsTransport,
RTC::SrtpSession::CryptoSuite srtpCryptoSuite, RTC::SrtpSession::CryptoSuite srtpCryptoSuite,
uint8_t* srtpLocalKey, uint8_t* srtpLocalKey,
size_t srtpLocalKeyLen, size_t srtpLocalKeyLen,
uint8_t* srtpRemoteKey, uint8_t* srtpRemoteKey,
size_t srtpRemoteKeyLen, size_t srtpRemoteKeyLen,
std::string& remoteCert) = 0; std::string& remoteCert) = 0;
// The DTLS connection has been closed as the result of an error (such as a // The DTLS connection has been closed as the result of an error (such as a
// DTLS alert or a failure to validate the remote fingerprint). // DTLS alert or a failure to validate the remote fingerprint).
virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0;
// The DTLS connection has been closed due to receipt of a close_notify alert. // The DTLS connection has been closed due to receipt of a close_notify alert.
virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0;
// Need to send DTLS data to the peer. // Need to send DTLS data to the peer.
virtual void OnDtlsTransportSendData( virtual void OnDtlsTransportSendData(
const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
// DTLS application data received. // DTLS application data received.
virtual void OnDtlsTransportApplicationDataReceived( virtual void OnDtlsTransportApplicationDataReceived(
const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
}; };
public: public:
static Role StringToRole(const std::string& role) static Role StringToRole(const std::string& role)
{ {
auto it = DtlsTransport::string2Role.find(role); auto it = DtlsTransport::string2Role.find(role);
if (it != DtlsTransport::string2Role.end()) if (it != DtlsTransport::string2Role.end())
return it->second; return it->second;
else else
return DtlsTransport::Role::NONE; return DtlsTransport::Role::NONE;
} }
static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint)
{ {
auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint);
if (it != DtlsTransport::string2FingerprintAlgorithm.end()) if (it != DtlsTransport::string2FingerprintAlgorithm.end())
return it->second; return it->second;
else else
return DtlsTransport::FingerprintAlgorithm::NONE; return DtlsTransport::FingerprintAlgorithm::NONE;
} }
static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint)
{ {
auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint);
return it->second; return it->second;
} }
static bool IsDtls(const uint8_t* data, size_t len) static bool IsDtls(const uint8_t* data, size_t len)
{ {
// clang-format off // clang-format off
return ( return (
// Minimum DTLS record length is 13 bytes. // Minimum DTLS record length is 13 bytes.
(len >= 13) && (len >= 13) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] > 19 && data[0] < 64) (data[0] > 19 && data[0] < 64)
); );
// clang-format on // clang-format on
} }
private: private:
static std::map<std::string, Role> string2Role; static std::map<std::string, Role> string2Role;
static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm; static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm;
static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String; static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String;
static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites; static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites;
public: public:
DtlsTransport(EventPoller::Ptr poller, Listener* listener); DtlsTransport(EventPoller::Ptr poller, Listener* listener);
~DtlsTransport(); ~DtlsTransport();
public: public:
void Dump() const; void Dump() const;
void Run(Role localRole); void Run(Role localRole);
std::vector<Fingerprint>& GetLocalFingerprints() const std::vector<Fingerprint>& GetLocalFingerprints() const
{ {
return env->localFingerprints; return env->localFingerprints;
} }
bool SetRemoteFingerprint(Fingerprint fingerprint); bool SetRemoteFingerprint(Fingerprint fingerprint);
void ProcessDtlsData(const uint8_t* data, size_t len); void ProcessDtlsData(const uint8_t* data, size_t len);
DtlsState GetState() const DtlsState GetState() const
{ {
return this->state; return this->state;
} }
Role GetLocalRole() const Role GetLocalRole() const
{ {
return this->localRole; return this->localRole;
} }
void SendApplicationData(const uint8_t* data, size_t len); void SendApplicationData(const uint8_t* data, size_t len);
private:
bool IsRunning() const
{
switch (this->state)
{
case DtlsState::NEW:
return false;
case DtlsState::CONNECTING:
case DtlsState::CONNECTED:
return true;
case DtlsState::FAILED:
case DtlsState::CLOSED:
return false;
}
// Make GCC 4.9 happy.
return false;
}
void Reset();
bool CheckStatus(int returnCode);
void SendPendingOutgoingDtlsData();
bool SetTimeout();
bool ProcessHandshake();
bool CheckRemoteFingerprint();
void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite);
RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite();
private: private:
void OnSslInfo(int where, int ret); bool IsRunning() const
void OnTimer(); {
switch (this->state)
{
case DtlsState::NEW:
return false;
case DtlsState::CONNECTING:
case DtlsState::CONNECTED:
return true;
case DtlsState::FAILED:
case DtlsState::CLOSED:
return false;
}
// Make GCC 4.9 happy.
return false;
}
void Reset();
bool CheckStatus(int returnCode);
void SendPendingOutgoingDtlsData();
bool SetTimeout();
bool ProcessHandshake();
bool CheckRemoteFingerprint();
void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite);
RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite();
private:
void OnSslInfo(int where, int ret);
void OnTimer();
private: private:
DtlsEnvironment::Ptr env; DtlsEnvironment::Ptr env;
EventPoller::Ptr poller; EventPoller::Ptr poller;
// Passed by argument. // Passed by argument.
Listener* listener{ nullptr }; Listener* listener{ nullptr };
// Allocated by this. // Allocated by this.
SSL* ssl{ nullptr }; SSL* ssl{ nullptr };
BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads.
BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes.
Timer::Ptr timer; Timer::Ptr timer;
// Others. // Others.
DtlsState state{ DtlsState::NEW }; DtlsState state{ DtlsState::NEW };
Role localRole{ Role::NONE }; Role localRole{ Role::NONE };
Fingerprint remoteFingerprint; Fingerprint remoteFingerprint;
bool handshakeDone{ false }; bool handshakeDone{ false };
bool handshakeDoneNow{ false }; bool handshakeDoneNow{ false };
std::string remoteCert; std::string remoteCert;
//最大不超过mtu //最大不超过mtu
static constexpr int SslReadBufferSize{ 2000 }; static constexpr int SslReadBufferSize{ 2000 };
uint8_t sslReadBuffer[SslReadBufferSize]; uint8_t sslReadBuffer[SslReadBufferSize];
}; };
} // namespace RTC } // namespace RTC
......
...@@ -24,505 +24,505 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ...@@ -24,505 +24,505 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
namespace RTC namespace RTC
{ {
/* Static. */ /* Static. */
/* Instance methods. */ /* Instance methods. */
IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password)
: listener(listener), usernameFragment(usernameFragment), password(password) : listener(listener), usernameFragment(usernameFragment), password(password)
{ {
MS_TRACE(); MS_TRACE();
} }
void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple)
{ {
MS_TRACE(); MS_TRACE();
// Must be a Binding method. // Must be a Binding method.
if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) if (packet->GetMethod() != RTC::StunPacket::Method::BINDING)
{ {
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
{ {
MS_WARN_TAG( MS_WARN_TAG(
ice, ice,
"unknown method %#.3x in STUN Request => 400", "unknown method %#.3x in STUN Request => 400",
static_cast<unsigned int>(packet->GetMethod())); static_cast<unsigned int>(packet->GetMethod()));
// Reply 400. // Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400); RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
} }
else else
{ {
MS_WARN_TAG( MS_WARN_TAG(
ice, ice,
"ignoring STUN Indication or Response with unknown method %#.3x", "ignoring STUN Indication or Response with unknown method %#.3x",
static_cast<unsigned int>(packet->GetMethod())); static_cast<unsigned int>(packet->GetMethod()));
} }
return; return;
} }
// Must use FINGERPRINT (optional for ICE STUN indications). // Must use FINGERPRINT (optional for ICE STUN indications).
if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION)
{ {
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
{ {
MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400");
// Reply 400. // Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400); RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
} }
else else
{ {
MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT");
} }
return; return;
} }
switch (packet->GetClass()) switch (packet->GetClass())
{ {
case RTC::StunPacket::Class::REQUEST: case RTC::StunPacket::Class::REQUEST:
{ {
// USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required.
if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty())
{ {
MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400");
// Reply 400. // Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400); RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
return; return;
} }
// Check authentication. // Check authentication.
switch (packet->CheckAuthentication(this->usernameFragment, this->password)) switch (packet->CheckAuthentication(this->usernameFragment, this->password))
{ {
case RTC::StunPacket::Authentication::OK: case RTC::StunPacket::Authentication::OK:
{ {
if (!this->oldPassword.empty()) if (!this->oldPassword.empty())
{ {
MS_DEBUG_TAG(ice, "new ICE credentials applied"); MS_DEBUG_TAG(ice, "new ICE credentials applied");
this->oldUsernameFragment.clear(); this->oldUsernameFragment.clear();
this->oldPassword.clear(); this->oldPassword.clear();
} }
break; break;
} }
case RTC::StunPacket::Authentication::UNAUTHORIZED: case RTC::StunPacket::Authentication::UNAUTHORIZED:
{ {
// We may have changed our usernameFragment and password, so check // We may have changed our usernameFragment and password, so check
// the old ones. // the old ones.
// clang-format off // clang-format off
if ( if (
!this->oldUsernameFragment.empty() && !this->oldUsernameFragment.empty() &&
!this->oldPassword.empty() && !this->oldPassword.empty() &&
packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK
) )
// clang-format on // clang-format on
{ {
MS_DEBUG_TAG(ice, "using old ICE credentials"); MS_DEBUG_TAG(ice, "using old ICE credentials");
break; break;
} }
MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401");
// Reply 401. // Reply 401.
RTC::StunPacket* response = packet->CreateErrorResponse(401); RTC::StunPacket* response = packet->CreateErrorResponse(401);
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
return; return;
} }
case RTC::StunPacket::Authentication::BAD_REQUEST: case RTC::StunPacket::Authentication::BAD_REQUEST:
{ {
MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400");
// Reply 400. // Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400); RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
return; return;
} }
} }
#if 0 #if 0
// The remote peer must be ICE controlling. // The remote peer must be ICE controlling.
if (packet->GetIceControlled()) if (packet->GetIceControlled())
{ {
MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487");
// Reply 487 (Role Conflict). // Reply 487 (Role Conflict).
RTC::StunPacket* response = packet->CreateErrorResponse(487); RTC::StunPacket* response = packet->CreateErrorResponse(487);
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
return; return;
} }
#endif #endif
//MS_DEBUG_DEV( //MS_DEBUG_DEV(
// "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", // "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]",
// static_cast<uint32_t>(packet->GetPriority()), // static_cast<uint32_t>(packet->GetPriority()),
// packet->HasUseCandidate() ? "true" : "false"); // packet->HasUseCandidate() ? "true" : "false");
// Create a success response. // Create a success response.
RTC::StunPacket* response = packet->CreateSuccessResponse(); RTC::StunPacket* response = packet->CreateSuccessResponse();
sockaddr_storage peerAddr; sockaddr_storage peerAddr;
socklen_t addr_len = sizeof(peerAddr); socklen_t addr_len = sizeof(peerAddr);
getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len); getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len);
// Add XOR-MAPPED-ADDRESS. // Add XOR-MAPPED-ADDRESS.
response->SetXorMappedAddress((struct sockaddr *)&peerAddr); response->SetXorMappedAddress((struct sockaddr *)&peerAddr);
// Authenticate the response. // Authenticate the response.
if (this->oldPassword.empty()) if (this->oldPassword.empty())
response->Authenticate(this->password); response->Authenticate(this->password);
else else
response->Authenticate(this->oldPassword); response->Authenticate(this->oldPassword);
// Send back. // Send back.
response->Serialize(StunSerializeBuffer); response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple); this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response; delete response;
// Handle the tuple. // Handle the tuple.
HandleTuple(tuple, packet->HasUseCandidate()); HandleTuple(tuple, packet->HasUseCandidate());
break; break;
} }
case RTC::StunPacket::Class::INDICATION: case RTC::StunPacket::Class::INDICATION:
{ {
MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); MS_DEBUG_TAG(ice, "STUN Binding Indication processed");
break; break;
} }
case RTC::StunPacket::Class::SUCCESS_RESPONSE: case RTC::StunPacket::Class::SUCCESS_RESPONSE:
{ {
MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); MS_DEBUG_TAG(ice, "STUN Binding Success Response processed");
break; break;
} }
case RTC::StunPacket::Class::ERROR_RESPONSE: case RTC::StunPacket::Class::ERROR_RESPONSE:
{ {
MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); MS_DEBUG_TAG(ice, "STUN Binding Error Response processed");
break; break;
} }
} }
} }
bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const
{ {
MS_TRACE(); MS_TRACE();
return HasTuple(tuple) != nullptr; return HasTuple(tuple) != nullptr;
} }
void IceServer::RemoveTuple(RTC::TransportTuple* tuple) void IceServer::RemoveTuple(RTC::TransportTuple* tuple)
{ {
MS_TRACE(); MS_TRACE();
RTC::TransportTuple* removedTuple{ nullptr }; RTC::TransportTuple* removedTuple{ nullptr };
// Find the removed tuple. // Find the removed tuple.
auto it = this->tuples.begin(); auto it = this->tuples.begin();
for (; it != this->tuples.end(); ++it) for (; it != this->tuples.end(); ++it)
{ {
RTC::TransportTuple* storedTuple = *it; RTC::TransportTuple* storedTuple = *it;
if (storedTuple == tuple) if (storedTuple == tuple)
{ {
removedTuple = storedTuple; removedTuple = storedTuple;
break; break;
} }
} }
// If not found, ignore. // If not found, ignore.
if (!removedTuple) if (!removedTuple)
return; return;
// Remove from the list of tuples. // Remove from the list of tuples.
this->tuples.erase(it); this->tuples.erase(it);
// If this is not the selected tuple, stop here. // If this is not the selected tuple, stop here.
if (removedTuple != this->selectedTuple) if (removedTuple != this->selectedTuple)
return; return;
// Otherwise this was the selected tuple. // Otherwise this was the selected tuple.
this->selectedTuple = nullptr; this->selectedTuple = nullptr;
// Mark the first tuple as selected tuple (if any). // Mark the first tuple as selected tuple (if any).
if (!this->tuples.empty()) if (!this->tuples.empty())
{ {
SetSelectedTuple(this->tuples.front()); SetSelectedTuple(this->tuples.front());
} }
// Or just emit 'disconnected'. // Or just emit 'disconnected'.
else else
{ {
// Update state. // Update state.
this->state = IceState::DISCONNECTED; this->state = IceState::DISCONNECTED;
// Notify the listener. // Notify the listener.
this->listener->OnIceServerDisconnected(this); this->listener->OnIceServerDisconnected(this);
} }
} }
void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple)
{ {
MS_TRACE(); MS_TRACE();
MS_ASSERT( MS_ASSERT(
this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple");
auto* storedTuple = HasTuple(tuple); auto* storedTuple = HasTuple(tuple);
MS_ASSERT( MS_ASSERT(
storedTuple, storedTuple,
"cannot force the selected tuple if the given tuple was not already a valid tuple"); "cannot force the selected tuple if the given tuple was not already a valid tuple");
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
} }
void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate)
{ {
MS_TRACE(); MS_TRACE();
switch (this->state) switch (this->state)
{ {
case IceState::NEW: case IceState::NEW:
{ {
// There should be no tuples. // There should be no tuples.
MS_ASSERT( MS_ASSERT(
this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size());
// There shouldn't be a selected tuple. // There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple");
if (!hasUseCandidate) if (!hasUseCandidate)
{ {
MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'");
// Store the tuple. // Store the tuple.
auto* storedTuple = AddTuple(tuple); auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
// Update state. // Update state.
this->state = IceState::CONNECTED; this->state = IceState::CONNECTED;
// Notify the listener. // Notify the listener.
this->listener->OnIceServerConnected(this); this->listener->OnIceServerConnected(this);
} }
else else
{ {
MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'");
// Store the tuple. // Store the tuple.
auto* storedTuple = AddTuple(tuple); auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
// Update state. // Update state.
this->state = IceState::COMPLETED; this->state = IceState::COMPLETED;
// Notify the listener. // Notify the listener.
this->listener->OnIceServerCompleted(this); this->listener->OnIceServerCompleted(this);
} }
break; break;
} }
case IceState::DISCONNECTED: case IceState::DISCONNECTED:
{ {
// There should be no tuples. // There should be no tuples.
MS_ASSERT( MS_ASSERT(
this->tuples.empty(), this->tuples.empty(),
"state is 'disconnected' but there are %zu tuples", "state is 'disconnected' but there are %zu tuples",
this->tuples.size()); this->tuples.size());
// There shouldn't be a selected tuple. // There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple");
if (!hasUseCandidate) if (!hasUseCandidate)
{ {
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'");
// Store the tuple. // Store the tuple.
auto* storedTuple = AddTuple(tuple); auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
// Update state. // Update state.
this->state = IceState::CONNECTED; this->state = IceState::CONNECTED;
// Notify the listener. // Notify the listener.
this->listener->OnIceServerConnected(this); this->listener->OnIceServerConnected(this);
} }
else else
{ {
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'");
// Store the tuple. // Store the tuple.
auto* storedTuple = AddTuple(tuple); auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
// Update state. // Update state.
this->state = IceState::COMPLETED; this->state = IceState::COMPLETED;
// Notify the listener. // Notify the listener.
this->listener->OnIceServerCompleted(this); this->listener->OnIceServerCompleted(this);
} }
break; break;
} }
case IceState::CONNECTED: case IceState::CONNECTED:
{ {
// There should be some tuples. // There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples");
// There should be a selected tuple. // There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple");
if (!hasUseCandidate) if (!hasUseCandidate)
{ {
// If a new tuple store it. // If a new tuple store it.
if (!HasTuple(tuple)) if (!HasTuple(tuple))
AddTuple(tuple); AddTuple(tuple);
} }
else else
{ {
MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'");
auto* storedTuple = HasTuple(tuple); auto* storedTuple = HasTuple(tuple);
// If a new tuple store it. // If a new tuple store it.
if (!storedTuple) if (!storedTuple)
storedTuple = AddTuple(tuple); storedTuple = AddTuple(tuple);
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
// Update state. // Update state.
this->state = IceState::COMPLETED; this->state = IceState::COMPLETED;
// Notify the listener. // Notify the listener.
this->listener->OnIceServerCompleted(this); this->listener->OnIceServerCompleted(this);
} }
break; break;
} }
case IceState::COMPLETED: case IceState::COMPLETED:
{ {
// There should be some tuples. // There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples");
// There should be a selected tuple. // There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple");
if (!hasUseCandidate) if (!hasUseCandidate)
{ {
// If a new tuple store it. // If a new tuple store it.
if (!HasTuple(tuple)) if (!HasTuple(tuple))
AddTuple(tuple); AddTuple(tuple);
} }
else else
{ {
auto* storedTuple = HasTuple(tuple); auto* storedTuple = HasTuple(tuple);
// If a new tuple store it. // If a new tuple store it.
if (!storedTuple) if (!storedTuple)
storedTuple = AddTuple(tuple); storedTuple = AddTuple(tuple);
// Mark it as selected tuple. // Mark it as selected tuple.
SetSelectedTuple(storedTuple); SetSelectedTuple(storedTuple);
} }
break; break;
} }
} }
} }
inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple)
{ {
MS_TRACE(); MS_TRACE();
// Add the new tuple at the beginning of the list. // Add the new tuple at the beginning of the list.
this->tuples.push_front(tuple); this->tuples.push_front(tuple);
// Return the address of the inserted tuple. // Return the address of the inserted tuple.
return tuple; return tuple;
} }
inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const
{ {
MS_TRACE(); MS_TRACE();
// If there is no selected tuple yet then we know that the tuples list // If there is no selected tuple yet then we know that the tuples list
// is empty. // is empty.
if (!this->selectedTuple) if (!this->selectedTuple)
return nullptr; return nullptr;
// Check the current selected tuple. // Check the current selected tuple.
if (selectedTuple == tuple) if (selectedTuple == tuple)
return this->selectedTuple; return this->selectedTuple;
// Otherwise check other stored tuples. // Otherwise check other stored tuples.
for (const auto& it : this->tuples) for (const auto& it : this->tuples)
{ {
auto& storedTuple = it; auto& storedTuple = it;
if (storedTuple == tuple) if (storedTuple == tuple)
return storedTuple; return storedTuple;
} }
return nullptr; return nullptr;
} }
inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple)
{ {
MS_TRACE(); MS_TRACE();
// If already the selected tuple do nothing. // If already the selected tuple do nothing.
if (storedTuple == this->selectedTuple) if (storedTuple == this->selectedTuple)
return; return;
this->selectedTuple = storedTuple; this->selectedTuple = storedTuple;
this->lastSelectedTuple = storedTuple->shared_from_this(); this->lastSelectedTuple = storedTuple->shared_from_this();
// Notify the listener. // Notify the listener.
this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); this->listener->OnIceServerSelectedTuple(this, this->selectedTuple);
} }
} // namespace RTC } // namespace RTC
...@@ -30,109 +30,109 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ...@@ -30,109 +30,109 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
namespace RTC namespace RTC
{ {
using TransportTuple = toolkit::Session; using TransportTuple = toolkit::Session;
class IceServer class IceServer
{ {
public: public:
enum class IceState enum class IceState
{ {
NEW = 1, NEW = 1,
CONNECTED, CONNECTED,
COMPLETED, COMPLETED,
DISCONNECTED DISCONNECTED
}; };
public: public:
class Listener class Listener
{ {
public: public:
virtual ~Listener() = default; virtual ~Listener() = default;
public: public:
/** /**
* These callbacks are guaranteed to be called before ProcessStunPacket() * These callbacks are guaranteed to be called before ProcessStunPacket()
* returns, so the given pointers are still usable. * returns, so the given pointers are still usable.
*/ */
virtual void OnIceServerSendStunPacket( virtual void OnIceServerSendStunPacket(
const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerSelectedTuple( virtual void OnIceServerSelectedTuple(
const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0;
}; };
public: public:
IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password);
public: public:
void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple);
const std::string& GetUsernameFragment() const const std::string& GetUsernameFragment() const
{ {
return this->usernameFragment; return this->usernameFragment;
} }
const std::string& GetPassword() const const std::string& GetPassword() const
{ {
return this->password; return this->password;
} }
IceState GetState() const IceState GetState() const
{ {
return this->state; return this->state;
} }
RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const
{ {
return try_last_tuple ? this->lastSelectedTuple.lock().get() : this->selectedTuple; return try_last_tuple ? this->lastSelectedTuple.lock().get() : this->selectedTuple;
} }
void SetUsernameFragment(const std::string& usernameFragment) void SetUsernameFragment(const std::string& usernameFragment)
{ {
this->oldUsernameFragment = this->usernameFragment; this->oldUsernameFragment = this->usernameFragment;
this->usernameFragment = usernameFragment; this->usernameFragment = usernameFragment;
} }
void SetPassword(const std::string& password) void SetPassword(const std::string& password)
{ {
this->oldPassword = this->password; this->oldPassword = this->password;
this->password = password; this->password = password;
} }
bool IsValidTuple(const RTC::TransportTuple* tuple) const; bool IsValidTuple(const RTC::TransportTuple* tuple) const;
void RemoveTuple(RTC::TransportTuple* tuple); void RemoveTuple(RTC::TransportTuple* tuple);
// This should be just called in 'connected' or completed' state // This should be just called in 'connected' or completed' state
// and the given tuple must be an already valid tuple. // and the given tuple must be an already valid tuple.
void ForceSelectedTuple(const RTC::TransportTuple* tuple); void ForceSelectedTuple(const RTC::TransportTuple* tuple);
const std::list<RTC::TransportTuple *>& GetTuples() const { return tuples; } const std::list<RTC::TransportTuple *>& GetTuples() const { return tuples; }
private: private:
void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate);
/** /**
* Store the given tuple and return its stored address. * Store the given tuple and return its stored address.
*/ */
RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple);
/** /**
* If the given tuple exists return its stored address, nullptr otherwise. * If the given tuple exists return its stored address, nullptr otherwise.
*/ */
RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const;
/** /**
* Set the given tuple as the selected tuple. * Set the given tuple as the selected tuple.
* NOTE: The given tuple MUST be already stored within the list. * NOTE: The given tuple MUST be already stored within the list.
*/ */
void SetSelectedTuple(RTC::TransportTuple* storedTuple); void SetSelectedTuple(RTC::TransportTuple* storedTuple);
private: private:
// Passed by argument. // Passed by argument.
Listener* listener{ nullptr }; Listener* listener{ nullptr };
// Others. // Others.
std::string usernameFragment; std::string usernameFragment;
std::string password; std::string password;
std::string oldUsernameFragment; std::string oldUsernameFragment;
std::string oldPassword; std::string oldPassword;
IceState state{ IceState::NEW }; IceState state{ IceState::NEW };
std::list<RTC::TransportTuple *> tuples; std::list<RTC::TransportTuple *> tuples;
RTC::TransportTuple *selectedTuple; RTC::TransportTuple *selectedTuple;
std::weak_ptr<RTC::TransportTuple> lastSelectedTuple; std::weak_ptr<RTC::TransportTuple> lastSelectedTuple;
//最大不超过mtu //最大不超过mtu
static constexpr size_t StunSerializeBufferSize{ 1600 }; static constexpr size_t StunSerializeBufferSize{ 1600 };
uint8_t StunSerializeBuffer[StunSerializeBufferSize]; uint8_t StunSerializeBuffer[StunSerializeBufferSize];
}; };
} // namespace RTC } // namespace RTC
#endif #endif
...@@ -23,14 +23,14 @@ static constexpr uint16_t MaxSctpStreams{ 65535 }; ...@@ -23,14 +23,14 @@ static constexpr uint16_t MaxSctpStreams{ 65535 };
/* clang-format off */ /* clang-format off */
static constexpr uint16_t EventTypes[] = static constexpr uint16_t EventTypes[] =
{ {
SCTP_ADAPTATION_INDICATION, SCTP_ADAPTATION_INDICATION,
SCTP_ASSOC_CHANGE, SCTP_ASSOC_CHANGE,
SCTP_ASSOC_RESET_EVENT, SCTP_ASSOC_RESET_EVENT,
SCTP_REMOTE_ERROR, SCTP_REMOTE_ERROR,
SCTP_SHUTDOWN_EVENT, SCTP_SHUTDOWN_EVENT,
SCTP_SEND_FAILED_EVENT, SCTP_SEND_FAILED_EVENT,
SCTP_STREAM_RESET_EVENT, SCTP_STREAM_RESET_EVENT,
SCTP_STREAM_CHANGE_EVENT SCTP_STREAM_CHANGE_EVENT
}; };
/* clang-format on */ /* clang-format on */
...@@ -44,45 +44,45 @@ inline static int onRecvSctpData( ...@@ -44,45 +44,45 @@ inline static int onRecvSctpData(
int flags, int flags,
void* ulpInfo) void* ulpInfo)
{ {
auto* sctpAssociation = static_cast<RTC::SctpAssociation*>(ulpInfo); auto* sctpAssociation = static_cast<RTC::SctpAssociation*>(ulpInfo);
if (sctpAssociation == nullptr) if (sctpAssociation == nullptr)
{ {
std::free(data); std::free(data);
return 0; return 0;
} }
if (flags & MSG_NOTIFICATION) if (flags & MSG_NOTIFICATION)
{ {
sctpAssociation->OnUsrSctpReceiveSctpNotification( sctpAssociation->OnUsrSctpReceiveSctpNotification(
static_cast<union sctp_notification*>(data), len); static_cast<union sctp_notification*>(data), len);
} }
else else
{ {
uint16_t streamId = rcv.rcv_sid; uint16_t streamId = rcv.rcv_sid;
uint32_t ppid = ntohl(rcv.rcv_ppid); uint32_t ppid = ntohl(rcv.rcv_ppid);
uint16_t ssn = rcv.rcv_ssn; uint16_t ssn = rcv.rcv_ssn;
MS_DEBUG_TAG( MS_DEBUG_TAG(
sctp, sctp,
"data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32 "data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32
", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]", ", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]",
len, len,
rcv.rcv_sid, rcv.rcv_sid,
rcv.rcv_ssn, rcv.rcv_ssn,
rcv.rcv_tsn, rcv.rcv_tsn,
ntohl(rcv.rcv_ppid), ntohl(rcv.rcv_ppid),
rcv.rcv_context, rcv.rcv_context,
flags); flags);
sctpAssociation->OnUsrSctpReceiveSctpData( sctpAssociation->OnUsrSctpReceiveSctpData(
streamId, ssn, ppid, flags, static_cast<uint8_t*>(data), len); streamId, ssn, ppid, flags, static_cast<uint8_t*>(data), len);
} }
std::free(data); std::free(data);
return 1; return 1;
} }
/* Static methods for usrsctp global callbacks. */ /* Static methods for usrsctp global callbacks. */
...@@ -136,824 +136,824 @@ namespace RTC ...@@ -136,824 +136,824 @@ namespace RTC
//////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////
/* Instance methods. */ /* Instance methods. */
SctpAssociation::SctpAssociation( SctpAssociation::SctpAssociation(
Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel) Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel)
: listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize), : listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize),
isDataChannel(isDataChannel) isDataChannel(isDataChannel)
{ {
MS_TRACE(); MS_TRACE();
_env = SctpEnv::Instance().shared_from_this(); _env = SctpEnv::Instance().shared_from_this();
// Register ourselves in usrsctp. // Register ourselves in usrsctp.
usrsctp_register_address(static_cast<void*>(this)); usrsctp_register_address(static_cast<void*>(this));
int ret; int ret;
this->socket = usrsctp_socket( this->socket = usrsctp_socket(
AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast<void*>(this)); AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast<void*>(this));
if (this->socket == nullptr) if (this->socket == nullptr)
MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno));
usrsctp_set_ulpinfo(this->socket, static_cast<void*>(this)); usrsctp_set_ulpinfo(this->socket, static_cast<void*>(this));
// Make the socket non-blocking. // Make the socket non-blocking.
ret = usrsctp_set_non_blocking(this->socket, 1); ret = usrsctp_set_non_blocking(this->socket, 1);
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno));
// Set SO_LINGER. // Set SO_LINGER.
// This ensures that the usrsctp close call deletes the association. This // This ensures that the usrsctp close call deletes the association. This
// prevents usrsctp from calling the global send callback with references to // prevents usrsctp from calling the global send callback with references to
// this class as the address. // this class as the address.
struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init) struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init)
lingerOpt.l_onoff = 1; lingerOpt.l_onoff = 1;
lingerOpt.l_linger = 0; lingerOpt.l_linger = 0;
ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt)); ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt));
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno));
// Set SCTP_ENABLE_STREAM_RESET. // Set SCTP_ENABLE_STREAM_RESET.
struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init)
av.assoc_value = av.assoc_value =
SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ; SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ;
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av)); ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av));
if (ret < 0) if (ret < 0)
{ {
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno));
} }
// Set SCTP_NODELAY. // Set SCTP_NODELAY.
uint32_t noDelay = 1; uint32_t noDelay = 1;
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay)); ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay));
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno));
// Enable events. // Enable events.
struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&event, 0, sizeof(event)); std::memset(&event, 0, sizeof(event));
event.se_on = 1; event.se_on = 1;
for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i) for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i)
{ {
event.se_type = EventTypes[i]; event.se_type = EventTypes[i];
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event)); ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event));
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno));
} }
// Init message. // Init message.
struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&initmsg, 0, sizeof(initmsg)); std::memset(&initmsg, 0, sizeof(initmsg));
initmsg.sinit_num_ostreams = this->os; initmsg.sinit_num_ostreams = this->os;
initmsg.sinit_max_instreams = this->mis; initmsg.sinit_max_instreams = this->mis;
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg)); ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg));
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno));
// Server side. // Server side.
struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&sconn, 0, sizeof(sconn)); std::memset(&sconn, 0, sizeof(sconn));
sconn.sconn_family = AF_CONN; sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(5000); sconn.sconn_port = htons(5000);
sconn.sconn_addr = static_cast<void*>(this); sconn.sconn_addr = static_cast<void*>(this);
#ifdef HAVE_SCONN_LEN #ifdef HAVE_SCONN_LEN
sconn.sconn_len = sizeof(sconn); sconn.sconn_len = sizeof(sconn);
#endif #endif
ret = usrsctp_bind(this->socket, reinterpret_cast<struct sockaddr*>(&sconn), sizeof(sconn)); ret = usrsctp_bind(this->socket, reinterpret_cast<struct sockaddr*>(&sconn), sizeof(sconn));
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno));
} }
SctpAssociation::~SctpAssociation() SctpAssociation::~SctpAssociation()
{ {
MS_TRACE(); MS_TRACE();
usrsctp_set_ulpinfo(this->socket, nullptr); usrsctp_set_ulpinfo(this->socket, nullptr);
usrsctp_close(this->socket); usrsctp_close(this->socket);
// Deregister ourselves from usrsctp. // Deregister ourselves from usrsctp.
usrsctp_deregister_address(static_cast<void*>(this)); usrsctp_deregister_address(static_cast<void*>(this));
delete[] this->messageBuffer; delete[] this->messageBuffer;
} }
void SctpAssociation::TransportConnected() void SctpAssociation::TransportConnected()
{ {
MS_TRACE(); MS_TRACE();
// Just run the SCTP stack if our state is 'new'. // Just run the SCTP stack if our state is 'new'.
if (this->state != SctpState::NEW) if (this->state != SctpState::NEW)
return; return;
try try
{ {
int ret; int ret;
struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&rconn, 0, sizeof(rconn)); std::memset(&rconn, 0, sizeof(rconn));
rconn.sconn_family = AF_CONN; rconn.sconn_family = AF_CONN;
rconn.sconn_port = htons(5000); rconn.sconn_port = htons(5000);
rconn.sconn_addr = static_cast<void*>(this); rconn.sconn_addr = static_cast<void*>(this);
#ifdef HAVE_SCONN_LEN #ifdef HAVE_SCONN_LEN
rconn.sconn_len = sizeof(rconn); rconn.sconn_len = sizeof(rconn);
#endif #endif
ret = usrsctp_connect(this->socket, reinterpret_cast<struct sockaddr*>(&rconn), sizeof(rconn)); ret = usrsctp_connect(this->socket, reinterpret_cast<struct sockaddr*>(&rconn), sizeof(rconn));
if (ret < 0 && errno != EINPROGRESS) if (ret < 0 && errno != EINPROGRESS)
MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno));
// Disable MTU discovery. // Disable MTU discovery.
sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init) sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&peerAddrParams, 0, sizeof(peerAddrParams)); std::memset(&peerAddrParams, 0, sizeof(peerAddrParams));
std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn)); std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn));
peerAddrParams.spp_flags = SPP_PMTUD_DISABLE; peerAddrParams.spp_flags = SPP_PMTUD_DISABLE;
// The MTU value provided specifies the space available for chunks in the // The MTU value provided specifies the space available for chunks in the
// packet, so let's subtract the SCTP header size. // packet, so let's subtract the SCTP header size.
peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams); peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams);
ret = usrsctp_setsockopt( ret = usrsctp_setsockopt(
this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams)); this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams));
if (ret < 0) if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno)); MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno));
// Announce connecting state. // Announce connecting state.
this->state = SctpState::CONNECTING; this->state = SctpState::CONNECTING;
this->listener->OnSctpAssociationConnecting(this); this->listener->OnSctpAssociationConnecting(this);
} }
catch (... /*error*/) catch (... /*error*/)
{ {
this->state = SctpState::FAILED; this->state = SctpState::FAILED;
this->listener->OnSctpAssociationFailed(this); this->listener->OnSctpAssociationFailed(this);
throw; throw;
} }
} }
void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len)
{ {
MS_TRACE(); MS_TRACE();
#if MS_LOG_DEV_LEVEL == 3 #if MS_LOG_DEV_LEVEL == 3
MS_DUMP_DATA(data, len); MS_DUMP_DATA(data, len);
#endif #endif
usrsctp_conninput(static_cast<void*>(this), data, len, 0); usrsctp_conninput(static_cast<void*>(this), data, len, 0);
} }
void SctpAssociation::SendSctpMessage( void SctpAssociation::SendSctpMessage(
const RTC::SctpStreamParameters &parameters, uint32_t ppid, const uint8_t* msg, size_t len) const RTC::SctpStreamParameters &parameters, uint32_t ppid, const uint8_t* msg, size_t len)
{ {
MS_TRACE(); MS_TRACE();
// This must be controlled by the DataConsumer. // This must be controlled by the DataConsumer.
MS_ASSERT( MS_ASSERT(
len <= this->maxSctpMessageSize, len <= this->maxSctpMessageSize,
"given message exceeds max allowed message size [message size:%zu, max message size:%zu]", "given message exceeds max allowed message size [message size:%zu, max message size:%zu]",
len, len,
this->maxSctpMessageSize); this->maxSctpMessageSize);
// Fill stcp_sendv_spa. // Fill stcp_sendv_spa.
struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&spa, 0, sizeof(spa)); std::memset(&spa, 0, sizeof(spa));
spa.sendv_flags = SCTP_SEND_SNDINFO_VALID; spa.sendv_flags = SCTP_SEND_SNDINFO_VALID;
spa.sendv_sndinfo.snd_sid = parameters.streamId; spa.sendv_sndinfo.snd_sid = parameters.streamId;
spa.sendv_sndinfo.snd_ppid = htonl(ppid); spa.sendv_sndinfo.snd_ppid = htonl(ppid);
spa.sendv_sndinfo.snd_flags = SCTP_EOR; spa.sendv_sndinfo.snd_flags = SCTP_EOR;
// If ordered it must be reliable. // If ordered it must be reliable.
if (parameters.ordered) if (parameters.ordered)
{ {
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE;
spa.sendv_prinfo.pr_value = 0; spa.sendv_prinfo.pr_value = 0;
} }
// Configure reliability: https://tools.ietf.org/html/rfc3758 // Configure reliability: https://tools.ietf.org/html/rfc3758
else else
{ {
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
if (parameters.maxPacketLifeTime != 0) if (parameters.maxPacketLifeTime != 0)
{ {
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL;
spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime; spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime;
} }
else if (parameters.maxRetransmits != 0) else if (parameters.maxRetransmits != 0)
{ {
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
spa.sendv_prinfo.pr_value = parameters.maxRetransmits; spa.sendv_prinfo.pr_value = parameters.maxRetransmits;
} }
} }
int ret = usrsctp_sendv( int ret = usrsctp_sendv(
this->socket, msg, len, nullptr, 0, &spa, static_cast<socklen_t>(sizeof(spa)), SCTP_SENDV_SPA, 0); this->socket, msg, len, nullptr, 0, &spa, static_cast<socklen_t>(sizeof(spa)), SCTP_SENDV_SPA, 0);
if (ret < 0) if (ret < 0)
{ {
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s", "error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s",
parameters.streamId, parameters.streamId,
ppid, ppid,
len, len,
std::strerror(errno)); std::strerror(errno));
} }
} }
void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters &params) void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters &params)
{ {
MS_TRACE(); MS_TRACE();
auto streamId = params.streamId; auto streamId = params.streamId;
// We need more OS. // We need more OS.
if (streamId > this->os - 1) if (streamId > this->os - 1)
AddOutgoingStreams(/*force*/ false); AddOutgoingStreams(/*force*/ false);
} }
void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters &params) void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters &params)
{ {
MS_TRACE(); MS_TRACE();
auto streamId = params.streamId; auto streamId = params.streamId;
// Send SCTP_RESET_STREAMS to the remote. // Send SCTP_RESET_STREAMS to the remote.
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
if (this->isDataChannel) if (this->isDataChannel)
ResetSctpStream(streamId, StreamDirection::OUTGOING); ResetSctpStream(streamId, StreamDirection::OUTGOING);
else else
ResetSctpStream(streamId, StreamDirection::INCOMING); ResetSctpStream(streamId, StreamDirection::INCOMING);
} }
void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters &params) void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters &params)
{ {
MS_TRACE(); MS_TRACE();
auto streamId = params.streamId; auto streamId = params.streamId;
// Send SCTP_RESET_STREAMS to the remote. // Send SCTP_RESET_STREAMS to the remote.
ResetSctpStream(streamId, StreamDirection::OUTGOING); ResetSctpStream(streamId, StreamDirection::OUTGOING);
} }
void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction) void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction)
{ {
MS_TRACE(); MS_TRACE();
// Do nothing if an outgoing stream that could not be allocated by us. // Do nothing if an outgoing stream that could not be allocated by us.
if (direction == StreamDirection::OUTGOING && streamId > this->os - 1) if (direction == StreamDirection::OUTGOING && streamId > this->os - 1)
return; return;
int ret; int ret;
struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init)
socklen_t len = sizeof(av); socklen_t len = sizeof(av);
#ifndef SCTP_RECONFIG_SUPPORTED #ifndef SCTP_RECONFIG_SUPPORTED
#define SCTP_RECONFIG_SUPPORTED 0x00000029 #define SCTP_RECONFIG_SUPPORTED 0x00000029
#endif #endif
ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len); ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len);
if (ret == 0)
{
if (av.assoc_value != 1)
{
MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated");
return; if (ret == 0)
} {
} if (av.assoc_value != 1)
else {
{ MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated");
MS_WARN_TAG(
sctp,
"could not retrieve whether stream reconfiguration has been negotiated: %s\n",
std::strerror(errno));
return; return;
} }
}
else
{
MS_WARN_TAG(
sctp,
"could not retrieve whether stream reconfiguration has been negotiated: %s\n",
std::strerror(errno));
return;
}
// As per spec: https://tools.ietf.org/html/rfc6525#section-4.1 // As per spec: https://tools.ietf.org/html/rfc6525#section-4.1
len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t); len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t);
auto* srs = static_cast<struct sctp_reset_streams*>(std::malloc(len)); auto* srs = static_cast<struct sctp_reset_streams*>(std::malloc(len));
switch (direction) switch (direction)
{ {
case StreamDirection::INCOMING: case StreamDirection::INCOMING:
srs->srs_flags = SCTP_STREAM_RESET_INCOMING; srs->srs_flags = SCTP_STREAM_RESET_INCOMING;
break; break;
case StreamDirection::OUTGOING: case StreamDirection::OUTGOING:
srs->srs_flags = SCTP_STREAM_RESET_OUTGOING; srs->srs_flags = SCTP_STREAM_RESET_OUTGOING;
break; break;
} }
srs->srs_number_streams = 1; srs->srs_number_streams = 1;
srs->srs_stream_list[0] = streamId; // No need for htonl(). srs->srs_stream_list[0] = streamId; // No need for htonl().
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len); ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len);
if (ret == 0) if (ret == 0)
{ {
MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId); MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId);
} }
else else
{ {
MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno)); MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno));
} }
std::free(srs); std::free(srs);
} }
void SctpAssociation::AddOutgoingStreams(bool force) void SctpAssociation::AddOutgoingStreams(bool force)
{ {
MS_TRACE(); MS_TRACE();
uint16_t additionalOs{ 0 }; uint16_t additionalOs{ 0 };
if (MaxSctpStreams - this->os >= 32) if (MaxSctpStreams - this->os >= 32)
additionalOs = 32; additionalOs = 32;
else else
additionalOs = MaxSctpStreams - this->os; additionalOs = MaxSctpStreams - this->os;
if (additionalOs == 0) if (additionalOs == 0)
{ {
MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os); MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os);
return; return;
} }
auto nextDesiredOs = this->os + additionalOs; auto nextDesiredOs = this->os + additionalOs;
// Already in progress, ignore (unless forced). // Already in progress, ignore (unless forced).
if (!force && nextDesiredOs == this->desiredOs) if (!force && nextDesiredOs == this->desiredOs)
return; return;
// Update desired value. // Update desired value.
this->desiredOs = nextDesiredOs; this->desiredOs = nextDesiredOs;
// If not connected, defer it. // If not connected, defer it.
if (this->state != SctpState::CONNECTED) if (this->state != SctpState::CONNECTED)
{ {
MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase"); MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase");
return; return;
} }
struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init) struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&sas, 0, sizeof(sas)); std::memset(&sas, 0, sizeof(sas));
sas.sas_instrms = 0; sas.sas_instrms = 0;
sas.sas_outstrms = additionalOs; sas.sas_outstrms = additionalOs;
MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs); MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs);
int ret = usrsctp_setsockopt( int ret = usrsctp_setsockopt(
this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast<socklen_t>(sizeof(sas))); this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast<socklen_t>(sizeof(sas)));
if (ret < 0) if (ret < 0)
MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno)); MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno));
} }
void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len) void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len)
{ {
MS_TRACE(); MS_TRACE();
const uint8_t* data = static_cast<uint8_t*>(buffer); const uint8_t* data = static_cast<uint8_t*>(buffer);
#if MS_LOG_DEV_LEVEL == 3 #if MS_LOG_DEV_LEVEL == 3
MS_DUMP_DATA(data, len); MS_DUMP_DATA(data, len);
#endif #endif
this->listener->OnSctpAssociationSendData(this, data, len); this->listener->OnSctpAssociationSendData(this, data, len);
} }
void SctpAssociation::OnUsrSctpReceiveSctpData( void SctpAssociation::OnUsrSctpReceiveSctpData(
uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len) uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len)
{ {
// Ignore WebRTC DataChannel Control DATA chunks. // Ignore WebRTC DataChannel Control DATA chunks.
if (ppid == 50) if (ppid == 50)
{ {
MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)"); MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)");
return; return;
} }
if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived) if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived)
{ {
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16 "message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16
", last ssn received:%" PRIu16 "]", ", last ssn received:%" PRIu16 "]",
ssn, ssn,
this->lastSsnReceived); this->lastSsnReceived);
this->messageBufferLen = 0; this->messageBufferLen = 0;
} }
// Update last SSN received. // Update last SSN received.
this->lastSsnReceived = ssn; this->lastSsnReceived = ssn;
auto eor = static_cast<bool>(flags & MSG_EOR); auto eor = static_cast<bool>(flags & MSG_EOR);
if (this->messageBufferLen + len > this->maxSctpMessageSize) if (this->messageBufferLen + len > this->maxSctpMessageSize)
{ {
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]", "ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]",
this->messageBufferLen + len, this->messageBufferLen + len,
this->maxSctpMessageSize, this->maxSctpMessageSize,
eor ? 1 : 0); eor ? 1 : 0);
this->lastSsnReceived = 0; this->lastSsnReceived = 0;
return; return;
} }
// If end of message and there is no buffered data, notify it directly. // If end of message and there is no buffered data, notify it directly.
if (eor && this->messageBufferLen == 0) if (eor && this->messageBufferLen == 0)
{ {
MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]"); MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]");
this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len); this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len);
} }
// If end of message and there is buffered data, append data and notify buffer. // If end of message and there is buffered data, append data and notify buffer.
else if (eor && this->messageBufferLen != 0) else if (eor && this->messageBufferLen != 0)
{ {
std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); std::memcpy(this->messageBuffer + this->messageBufferLen, data, len);
this->messageBufferLen += len; this->messageBufferLen += len;
MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen); MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen);
this->listener->OnSctpAssociationMessageReceived( this->listener->OnSctpAssociationMessageReceived(
this, streamId, ppid, this->messageBuffer, this->messageBufferLen); this, streamId, ppid, this->messageBuffer, this->messageBufferLen);
this->messageBufferLen = 0; this->messageBufferLen = 0;
} }
// If non end of message, append data to the buffer. // If non end of message, append data to the buffer.
else if (!eor) else if (!eor)
{ {
// Allocate the buffer if not already done. // Allocate the buffer if not already done.
if (!this->messageBuffer) if (!this->messageBuffer)
this->messageBuffer = new uint8_t[this->maxSctpMessageSize]; this->messageBuffer = new uint8_t[this->maxSctpMessageSize];
std::memcpy(this->messageBuffer + this->messageBufferLen, data, len); std::memcpy(this->messageBuffer + this->messageBufferLen, data, len);
this->messageBufferLen += len; this->messageBufferLen += len;
MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen); MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen);
} }
} }
void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len) void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len)
{ {
if (notification->sn_header.sn_length != (uint32_t)len) if (notification->sn_header.sn_length != (uint32_t)len)
return; return;
switch (notification->sn_header.sn_type) switch (notification->sn_header.sn_type)
{ {
case SCTP_ADAPTATION_INDICATION: case SCTP_ADAPTATION_INDICATION:
{ {
MS_DEBUG_TAG( MS_DEBUG_TAG(
sctp, sctp,
"SCTP adaptation indication [%x]", "SCTP adaptation indication [%x]",
notification->sn_adaptation_event.sai_adaptation_ind); notification->sn_adaptation_event.sai_adaptation_ind);
break; break;
} }
case SCTP_ASSOC_CHANGE: case SCTP_ASSOC_CHANGE:
{ {
switch (notification->sn_assoc_change.sac_state) switch (notification->sn_assoc_change.sac_state)
{ {
case SCTP_COMM_UP: case SCTP_COMM_UP:
{ {
MS_DEBUG_TAG( MS_DEBUG_TAG(
sctp, sctp,
"SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]", "SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]",
notification->sn_assoc_change.sac_outbound_streams, notification->sn_assoc_change.sac_outbound_streams,
notification->sn_assoc_change.sac_inbound_streams); notification->sn_assoc_change.sac_inbound_streams);
// Update our OS. // Update our OS.
this->os = notification->sn_assoc_change.sac_outbound_streams; this->os = notification->sn_assoc_change.sac_outbound_streams;
// Increase if requested before connected. // Increase if requested before connected.
if (this->desiredOs > this->os) if (this->desiredOs > this->os)
AddOutgoingStreams(/*force*/ true); AddOutgoingStreams(/*force*/ true);
if (this->state != SctpState::CONNECTED) if (this->state != SctpState::CONNECTED)
{ {
this->state = SctpState::CONNECTED; this->state = SctpState::CONNECTED;
this->listener->OnSctpAssociationConnected(this); this->listener->OnSctpAssociationConnected(this);
} }
break; break;
} }
case SCTP_COMM_LOST: case SCTP_COMM_LOST:
{ {
if (notification->sn_header.sn_length > 0) if (notification->sn_header.sn_length > 0)
{ {
static const size_t BufferSize{ 1024 }; static const size_t BufferSize{ 1024 };
static char buffer[BufferSize]; static char buffer[BufferSize];
uint32_t len = notification->sn_header.sn_length; uint32_t len = notification->sn_header.sn_length;
for (uint32_t i{ 0 }; i < len; ++i) for (uint32_t i{ 0 }; i < len; ++i)
{ {
std::snprintf( std::snprintf(
buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]);
} }
MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer); MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer);
} }
else else
{ {
MS_DEBUG_TAG(sctp, "SCTP communication lost"); MS_DEBUG_TAG(sctp, "SCTP communication lost");
} }
if (this->state != SctpState::CLOSED) if (this->state != SctpState::CLOSED)
{ {
this->state = SctpState::CLOSED; this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this); this->listener->OnSctpAssociationClosed(this);
} }
break; break;
} }
case SCTP_RESTART: case SCTP_RESTART:
{ {
MS_DEBUG_TAG( MS_DEBUG_TAG(
sctp, sctp,
"SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]", "SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]",
notification->sn_assoc_change.sac_outbound_streams, notification->sn_assoc_change.sac_outbound_streams,
notification->sn_assoc_change.sac_inbound_streams); notification->sn_assoc_change.sac_inbound_streams);
// Update our OS. // Update our OS.
this->os = notification->sn_assoc_change.sac_outbound_streams; this->os = notification->sn_assoc_change.sac_outbound_streams;
// Increase if requested before connected. // Increase if requested before connected.
if (this->desiredOs > this->os) if (this->desiredOs > this->os)
AddOutgoingStreams(/*force*/ true); AddOutgoingStreams(/*force*/ true);
if (this->state != SctpState::CONNECTED) if (this->state != SctpState::CONNECTED)
{ {
this->state = SctpState::CONNECTED; this->state = SctpState::CONNECTED;
this->listener->OnSctpAssociationConnected(this); this->listener->OnSctpAssociationConnected(this);
} }
break; break;
} }
case SCTP_SHUTDOWN_COMP: case SCTP_SHUTDOWN_COMP:
{ {
MS_DEBUG_TAG(sctp, "SCTP association gracefully closed"); MS_DEBUG_TAG(sctp, "SCTP association gracefully closed");
if (this->state != SctpState::CLOSED) if (this->state != SctpState::CLOSED)
{ {
this->state = SctpState::CLOSED; this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this); this->listener->OnSctpAssociationClosed(this);
} }
break; break;
} }
case SCTP_CANT_STR_ASSOC: case SCTP_CANT_STR_ASSOC:
{ {
if (notification->sn_header.sn_length > 0) if (notification->sn_header.sn_length > 0)
{ {
static const size_t BufferSize{ 1024 }; static const size_t BufferSize{ 1024 };
static char buffer[BufferSize]; static char buffer[BufferSize];
uint32_t len = notification->sn_header.sn_length; uint32_t len = notification->sn_header.sn_length;
for (uint32_t i{ 0 }; i < len; ++i) for (uint32_t i{ 0 }; i < len; ++i)
{ {
std::snprintf( std::snprintf(
buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]); buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]);
} }
MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer); MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer);
} }
if (this->state != SctpState::FAILED) if (this->state != SctpState::FAILED)
{ {
this->state = SctpState::FAILED; this->state = SctpState::FAILED;
this->listener->OnSctpAssociationFailed(this); this->listener->OnSctpAssociationFailed(this);
} }
break; break;
} }
default:; default:;
} }
break; break;
} }
// https://tools.ietf.org/html/rfc6525#section-6.1.2. // https://tools.ietf.org/html/rfc6525#section-6.1.2.
case SCTP_ASSOC_RESET_EVENT: case SCTP_ASSOC_RESET_EVENT:
{ {
MS_DEBUG_TAG(sctp, "SCTP association reset event received"); MS_DEBUG_TAG(sctp, "SCTP association reset event received");
break; break;
} }
// An Operation Error is not considered fatal in and of itself, but may be // An Operation Error is not considered fatal in and of itself, but may be
// used with an ABORT chunk to report a fatal condition. // used with an ABORT chunk to report a fatal condition.
case SCTP_REMOTE_ERROR: case SCTP_REMOTE_ERROR:
{ {
static const size_t BufferSize{ 1024 }; static const size_t BufferSize{ 1024 };
static char buffer[BufferSize]; static char buffer[BufferSize];
uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error); uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error);
for (uint32_t i{ 0 }; i < len; i++) for (uint32_t i{ 0 }; i < len; i++)
{ {
std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]); std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]);
} }
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"remote SCTP association error [type:0x%04x, data:%s]", "remote SCTP association error [type:0x%04x, data:%s]",
notification->sn_remote_error.sre_error, notification->sn_remote_error.sre_error,
buffer); buffer);
break; break;
} }
// When a peer sends a SHUTDOWN, SCTP delivers this notification to // When a peer sends a SHUTDOWN, SCTP delivers this notification to
// inform the application that it should cease sending data. // inform the application that it should cease sending data.
case SCTP_SHUTDOWN_EVENT: case SCTP_SHUTDOWN_EVENT:
{ {
MS_DEBUG_TAG(sctp, "remote SCTP association shutdown"); MS_DEBUG_TAG(sctp, "remote SCTP association shutdown");
if (this->state != SctpState::CLOSED) if (this->state != SctpState::CLOSED)
{ {
this->state = SctpState::CLOSED; this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this); this->listener->OnSctpAssociationClosed(this);
} }
break; break;
} }
case SCTP_SEND_FAILED_EVENT: case SCTP_SEND_FAILED_EVENT:
{ {
static const size_t BufferSize{ 1024 }; static const size_t BufferSize{ 1024 };
static char buffer[BufferSize]; static char buffer[BufferSize];
uint32_t len = uint32_t len =
notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event); notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event);
for (uint32_t i{ 0 }; i < len; ++i) for (uint32_t i{ 0 }; i < len; ++i)
{ {
std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]); std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]);
} }
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32 "SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32
", sent:%s, error:0x%08x, info:%s]", ", sent:%s, error:0x%08x, info:%s]",
notification->sn_send_failed_event.ssfe_info.snd_sid, notification->sn_send_failed_event.ssfe_info.snd_sid,
ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid), ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid),
(notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no", (notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no",
notification->sn_send_failed_event.ssfe_error, notification->sn_send_failed_event.ssfe_error,
buffer); buffer);
break; break;
} }
case SCTP_STREAM_RESET_EVENT: case SCTP_STREAM_RESET_EVENT:
{ {
bool incoming{ false }; bool incoming{ false };
bool outgoing{ false }; bool outgoing{ false };
uint16_t numStreams = uint16_t numStreams =
(notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) / (notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) /
sizeof(uint16_t); sizeof(uint16_t);
if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN)
incoming = true; incoming = true;
if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN)
outgoing = true; outgoing = true;
//todo 打印sctp调试信息 //todo 打印sctp调试信息
if (false /*MS_HAS_DEBUG_TAG(sctp)*/) if (false /*MS_HAS_DEBUG_TAG(sctp)*/)
{ {
std::string streamIds; std::string streamIds;
for (uint16_t i{ 0 }; i < numStreams; ++i) for (uint16_t i{ 0 }; i < numStreams; ++i)
{ {
auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; auto streamId = notification->sn_strreset_event.strreset_stream_list[i];
// Don't log more than 5 stream ids. // Don't log more than 5 stream ids.
if (i > 4) if (i > 4)
{ {
streamIds.append("..."); streamIds.append("...");
break; break;
} }
if (i > 0) if (i > 0)
streamIds.append(","); streamIds.append(",");
streamIds.append(std::to_string(streamId)); streamIds.append(std::to_string(streamId));
} }
MS_DEBUG_TAG( MS_DEBUG_TAG(
sctp, sctp,
"SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]", "SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]",
notification->sn_strreset_event.strreset_flags, notification->sn_strreset_event.strreset_flags,
incoming ? "true" : "false", incoming ? "true" : "false",
outgoing ? "true" : "false", outgoing ? "true" : "false",
numStreams, numStreams,
streamIds.c_str()); streamIds.c_str());
} }
// Special case for WebRTC DataChannels in which we must also reset our // Special case for WebRTC DataChannels in which we must also reset our
// outgoing SCTP stream. // outgoing SCTP stream.
if (incoming && !outgoing && this->isDataChannel) if (incoming && !outgoing && this->isDataChannel)
{ {
for (uint16_t i{ 0 }; i < numStreams; ++i) for (uint16_t i{ 0 }; i < numStreams; ++i)
{ {
auto streamId = notification->sn_strreset_event.strreset_stream_list[i]; auto streamId = notification->sn_strreset_event.strreset_stream_list[i];
ResetSctpStream(streamId, StreamDirection::OUTGOING); ResetSctpStream(streamId, StreamDirection::OUTGOING);
} }
} }
break; break;
} }
case SCTP_STREAM_CHANGE_EVENT: case SCTP_STREAM_CHANGE_EVENT:
{ {
if (notification->sn_strchange_event.strchange_flags == 0) if (notification->sn_strchange_event.strchange_flags == 0)
{ {
MS_DEBUG_TAG( MS_DEBUG_TAG(
sctp, sctp,
"SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", "SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms, notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms, notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags); notification->sn_strchange_event.strchange_flags);
} }
else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED) else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED)
{ {
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", "SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms, notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms, notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags); notification->sn_strchange_event.strchange_flags);
break; break;
} }
else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED) else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED)
{ {
MS_WARN_TAG( MS_WARN_TAG(
sctp, sctp,
"SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]", "SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms, notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms, notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags); notification->sn_strchange_event.strchange_flags);
break; break;
} }
// Update OS. // Update OS.
this->os = notification->sn_strchange_event.strchange_outstrms; this->os = notification->sn_strchange_event.strchange_outstrms;
break; break;
} }
default: default:
{ {
MS_WARN_TAG( MS_WARN_TAG(
sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type); sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type);
} }
} }
} }
//////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -18,104 +18,104 @@ namespace RTC ...@@ -18,104 +18,104 @@ namespace RTC
uint16_t maxRetransmits{ 0u }; uint16_t maxRetransmits{ 0u };
}; };
class SctpAssociation class SctpAssociation
{ {
public: public:
enum class SctpState enum class SctpState
{ {
NEW = 1, NEW = 1,
CONNECTING, CONNECTING,
CONNECTED, CONNECTED,
FAILED, FAILED,
CLOSED CLOSED
}; };
private: private:
enum class StreamDirection enum class StreamDirection
{ {
INCOMING = 1, INCOMING = 1,
OUTGOING OUTGOING
}; };
public: public:
class Listener class Listener
{ {
public: public:
virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0; virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0; virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0; virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0; virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationSendData( virtual void OnSctpAssociationSendData(
RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0; RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0;
virtual void OnSctpAssociationMessageReceived( virtual void OnSctpAssociationMessageReceived(
RTC::SctpAssociation* sctpAssociation, RTC::SctpAssociation* sctpAssociation,
uint16_t streamId, uint16_t streamId,
uint32_t ppid, uint32_t ppid,
const uint8_t* msg, const uint8_t* msg,
size_t len) = 0; size_t len) = 0;
}; };
public: public:
static bool IsSctp(const uint8_t* data, size_t len) static bool IsSctp(const uint8_t* data, size_t len)
{ {
// clang-format off // clang-format off
return ( return (
(len >= 12) && (len >= 12) &&
// Must have Source Port Number and Destination Port Number set to 5000 (hack). // Must have Source Port Number and Destination Port Number set to 5000 (hack).
(Utils::Byte::Get2Bytes(data, 0) == 5000) && (Utils::Byte::Get2Bytes(data, 0) == 5000) &&
(Utils::Byte::Get2Bytes(data, 2) == 5000) (Utils::Byte::Get2Bytes(data, 2) == 5000)
); );
// clang-format on // clang-format on
} }
public: public:
SctpAssociation( SctpAssociation(
Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel); Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel);
virtual ~SctpAssociation(); virtual ~SctpAssociation();
public: public:
void TransportConnected(); void TransportConnected();
size_t GetMaxSctpMessageSize() const size_t GetMaxSctpMessageSize() const
{ {
return this->maxSctpMessageSize; return this->maxSctpMessageSize;
} }
SctpState GetState() const SctpState GetState() const
{ {
return this->state; return this->state;
} }
void ProcessSctpData(const uint8_t* data, size_t len); void ProcessSctpData(const uint8_t* data, size_t len);
void SendSctpMessage(const RTC::SctpStreamParameters &params, uint32_t ppid, const uint8_t* msg, size_t len); void SendSctpMessage(const RTC::SctpStreamParameters &params, uint32_t ppid, const uint8_t* msg, size_t len);
void HandleDataConsumer(const RTC::SctpStreamParameters &params); void HandleDataConsumer(const RTC::SctpStreamParameters &params);
void DataProducerClosed(const RTC::SctpStreamParameters &params); void DataProducerClosed(const RTC::SctpStreamParameters &params);
void DataConsumerClosed(const RTC::SctpStreamParameters &params); void DataConsumerClosed(const RTC::SctpStreamParameters &params);
private: private:
void ResetSctpStream(uint16_t streamId, StreamDirection); void ResetSctpStream(uint16_t streamId, StreamDirection);
void AddOutgoingStreams(bool force = false); void AddOutgoingStreams(bool force = false);
public: public:
/* Callbacks fired by usrsctp events. */ /* Callbacks fired by usrsctp events. */
virtual void OnUsrSctpSendSctpData(void* buffer, size_t len); virtual void OnUsrSctpSendSctpData(void* buffer, size_t len);
virtual void OnUsrSctpReceiveSctpData(uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len); virtual void OnUsrSctpReceiveSctpData(uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len);
virtual void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len); virtual void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len);
private: private:
// Passed by argument. // Passed by argument.
Listener* listener{ nullptr }; Listener* listener{ nullptr };
uint16_t os{ 1024u }; uint16_t os{ 1024u };
uint16_t mis{ 1024u }; uint16_t mis{ 1024u };
size_t maxSctpMessageSize{ 262144u }; size_t maxSctpMessageSize{ 262144u };
bool isDataChannel{ false }; bool isDataChannel{ false };
// Allocated by this. // Allocated by this.
uint8_t* messageBuffer{ nullptr }; uint8_t* messageBuffer{ nullptr };
// Others. // Others.
SctpState state{ SctpState::NEW }; SctpState state{ SctpState::NEW };
struct socket* socket{ nullptr }; struct socket* socket{ nullptr };
uint16_t desiredOs{ 0u }; uint16_t desiredOs{ 0u };
size_t messageBufferLen{ 0u }; size_t messageBufferLen{ 0u };
uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support. uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support.
std::shared_ptr<SctpEnv> _env; std::shared_ptr<SctpEnv> _env;
}; };
//保证线程安全 //保证线程安全
class SctpAssociationImp : public SctpAssociation, public std::enable_shared_from_this<SctpAssociationImp>{ class SctpAssociationImp : public SctpAssociation, public std::enable_shared_from_this<SctpAssociationImp>{
......
...@@ -97,785 +97,785 @@ namespace RTC ...@@ -97,785 +97,785 @@ namespace RTC
return str; return str;
} }
/* Class variables. */ /* Class variables. */
const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 };
/* Class methods. */ /* Class methods. */
StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) StunPacket* StunPacket::Parse(const uint8_t* data, size_t len)
{ {
MS_TRACE(); MS_TRACE();
if (!StunPacket::IsStun(data, len)) if (!StunPacket::IsStun(data, len))
return nullptr; return nullptr;
/* /*
The message type field is decomposed further into the following The message type field is decomposed further into the following
structure: structure:
0 1 0 1
2 3 4 5 6 7 8 9 0 1 2 3 4 5 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+ +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|M |M |M|M|M|C|M|M|M|C|M|M|M|M| |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
|11|10|9|8|7|1|6|5|4|0|3|2|1|0| |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+ +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 3: Format of STUN Message Type Field Figure 3: Format of STUN Message Type Field
Here the bits in the message type field are shown as most significant Here the bits in the message type field are shown as most significant
(M11) through least significant (M0). M11 through M0 represent a 12- (M11) through least significant (M0). M11 through M0 represent a 12-
bit encoding of the method. C1 and C0 represent a 2-bit encoding of bit encoding of the method. C1 and C0 represent a 2-bit encoding of
the class. the class.
*/ */
// Get type field. // Get type field.
uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); uint16_t msgType = Utils::Byte::Get2Bytes(data, 0);
// Get length field. // Get length field.
uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2);
// length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes.
if ((static_cast<size_t>(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) if ((static_cast<size_t>(msgLength) != len - 20) || ((msgLength & 0x03) != 0))
{ {
MS_WARN_TAG( MS_WARN_TAG(
ice, ice,
"length field + 20 does not match total size (or it is not multiple of 4 bytes), " "length field + 20 does not match total size (or it is not multiple of 4 bytes), "
"packet discarded"); "packet discarded");
return nullptr; return nullptr;
} }
// Get STUN method. // Get STUN method.
uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2);
// Get STUN class. // Get STUN class.
uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4);
// Create a new StunPacket (data + 8 points to the received TransactionID field). // Create a new StunPacket (data + 8 points to the received TransactionID field).
auto* packet = new StunPacket( auto* packet = new StunPacket(
static_cast<Class>(msgClass), static_cast<Method>(msgMethod), data + 8, data, len); static_cast<Class>(msgClass), static_cast<Method>(msgMethod), data + 8, data, len);
/* /*
STUN Attributes STUN Attributes
After the STUN header are zero or more attributes. Each attribute After the STUN header are zero or more attributes. Each attribute
MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. MUST be TLV encoded, with a 16-bit type, 16-bit length, and value.
Each STUN attribute MUST end on a 32-bit boundary. As mentioned Each STUN attribute MUST end on a 32-bit boundary. As mentioned
above, all fields in an attribute are transmitted most significant above, all fields in an attribute are transmitted most significant
bit first. bit first.
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length | | Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Value (variable) .... | Value (variable) ....
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/ */
// Start looking for attributes after STUN header (Byte #20). // Start looking for attributes after STUN header (Byte #20).
size_t pos{ 20 }; size_t pos{ 20 };
// Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes.
bool hasMessageIntegrity{ false }; bool hasMessageIntegrity{ false };
bool hasFingerprint{ false }; bool hasFingerprint{ false };
size_t fingerprintAttrPos; // Will point to the beginning of the attribute. size_t fingerprintAttrPos; // Will point to the beginning of the attribute.
uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute.
// Ensure there are at least 4 remaining bytes (attribute with 0 length). // Ensure there are at least 4 remaining bytes (attribute with 0 length).
while (pos + 4 <= len) while (pos + 4 <= len)
{ {
// Get the attribute type. // Get the attribute type.
auto attrType = static_cast<Attribute>(Utils::Byte::Get2Bytes(data, pos)); auto attrType = static_cast<Attribute>(Utils::Byte::Get2Bytes(data, pos));
// Get the attribute length. // Get the attribute length.
uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2);
// Ensure the attribute length is not greater than the remaining size. // Ensure the attribute length is not greater than the remaining size.
if ((pos + 4 + attrLength) > len) if ((pos + 4 + attrLength) > len)
{ {
MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded"); MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
// FINGERPRINT must be the last attribute. // FINGERPRINT must be the last attribute.
if (hasFingerprint) if (hasFingerprint)
{ {
MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded"); MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
// After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed.
if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT)
{ {
MS_WARN_TAG( MS_WARN_TAG(
ice, ice,
"attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, "
"packet discarded"); "packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
const uint8_t* attrValuePos = data + pos + 4; const uint8_t* attrValuePos = data + pos + 4;
switch (attrType) switch (attrType)
{ {
case Attribute::USERNAME: case Attribute::USERNAME:
{ {
packet->SetUsername( packet->SetUsername(
reinterpret_cast<const char*>(attrValuePos), static_cast<size_t>(attrLength)); reinterpret_cast<const char*>(attrValuePos), static_cast<size_t>(attrLength));
break; break;
} }
case Attribute::PRIORITY: case Attribute::PRIORITY:
{ {
// Ensure attribute length is 4 bytes. // Ensure attribute length is 4 bytes.
if (attrLength != 4) if (attrLength != 4)
{ {
MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0));
break; break;
} }
case Attribute::ICE_CONTROLLING: case Attribute::ICE_CONTROLLING:
{ {
// Ensure attribute length is 8 bytes. // Ensure attribute length is 8 bytes.
if (attrLength != 8) if (attrLength != 8)
{ {
MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0));
break; break;
} }
case Attribute::ICE_CONTROLLED: case Attribute::ICE_CONTROLLED:
{ {
// Ensure attribute length is 8 bytes. // Ensure attribute length is 8 bytes.
if (attrLength != 8) if (attrLength != 8)
{ {
MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0));
break; break;
} }
case Attribute::USE_CANDIDATE: case Attribute::USE_CANDIDATE:
{ {
// Ensure attribute length is 0 bytes. // Ensure attribute length is 0 bytes.
if (attrLength != 0) if (attrLength != 0)
{ {
MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
packet->SetUseCandidate(); packet->SetUseCandidate();
break; break;
} }
case Attribute::MESSAGE_INTEGRITY: case Attribute::MESSAGE_INTEGRITY:
{ {
// Ensure attribute length is 20 bytes. // Ensure attribute length is 20 bytes.
if (attrLength != 20) if (attrLength != 20)
{ {
MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
hasMessageIntegrity = true; hasMessageIntegrity = true;
packet->SetMessageIntegrity(attrValuePos); packet->SetMessageIntegrity(attrValuePos);
break; break;
} }
case Attribute::FINGERPRINT: case Attribute::FINGERPRINT:
{ {
// Ensure attribute length is 4 bytes. // Ensure attribute length is 4 bytes.
if (attrLength != 4) if (attrLength != 4)
{ {
MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
hasFingerprint = true; hasFingerprint = true;
fingerprintAttrPos = pos; fingerprintAttrPos = pos;
fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0);
packet->SetFingerprint(); packet->SetFingerprint();
break; break;
} }
case Attribute::ERROR_CODE: case Attribute::ERROR_CODE:
{ {
// Ensure attribute length >= 4bytes. // Ensure attribute length >= 4bytes.
if (attrLength < 4) if (attrLength < 4)
{ {
MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded"); MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2);
uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3);
auto errorCode = static_cast<uint16_t>(errorClass * 100 + errorNumber); auto errorCode = static_cast<uint16_t>(errorClass * 100 + errorNumber);
packet->SetErrorCode(errorCode); packet->SetErrorCode(errorCode);
break; break;
} }
default:; default:;
} }
// Set next attribute position. // Set next attribute position.
pos = pos =
static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(pos + 4 + attrLength))); static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(pos + 4 + attrLength)));
} }
// Ensure current position matches the total length. // Ensure current position matches the total length.
if (pos != len) if (pos != len)
{ {
MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded"); MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
// If it has FINGERPRINT attribute then verify it. // If it has FINGERPRINT attribute then verify it.
if (hasFingerprint) if (hasFingerprint)
{ {
// Compute the CRC32 of the received packet up to (but excluding) the // Compute the CRC32 of the received packet up to (but excluding) the
// FINGERPRINT attribute and XOR it with 0x5354554e. // FINGERPRINT attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e;
// Compare with the FINGERPRINT value in the packet. // Compare with the FINGERPRINT value in the packet.
if (fingerprint != computedFingerprint) if (fingerprint != computedFingerprint)
{ {
MS_WARN_TAG( MS_WARN_TAG(
ice, ice,
"computed FINGERPRINT value does not match the value in the packet, " "computed FINGERPRINT value does not match the value in the packet, "
"packet discarded"); "packet discarded");
delete packet; delete packet;
return nullptr; return nullptr;
} }
} }
return packet; return packet;
} }
/* Instance methods. */ /* Instance methods. */
StunPacket::StunPacket( StunPacket::StunPacket(
Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size) Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size)
: klass(klass), method(method), transactionId(transactionId), data(const_cast<uint8_t*>(data)), : klass(klass), method(method), transactionId(transactionId), data(const_cast<uint8_t*>(data)),
size(size) size(size)
{ {
MS_TRACE(); MS_TRACE();
} }
StunPacket::~StunPacket() StunPacket::~StunPacket()
{ {
MS_TRACE(); MS_TRACE();
} }
#if 0 #if 0
void StunPacket::Dump() const void StunPacket::Dump() const
{ {
MS_TRACE(); MS_TRACE();
MS_DUMP("<StunPacket>"); MS_DUMP("<StunPacket>");
std::string klass; std::string klass;
switch (this->klass) switch (this->klass)
{ {
case Class::REQUEST: case Class::REQUEST:
klass = "Request"; klass = "Request";
break; break;
case Class::INDICATION: case Class::INDICATION:
klass = "Indication"; klass = "Indication";
break; break;
case Class::SUCCESS_RESPONSE: case Class::SUCCESS_RESPONSE:
klass = "SuccessResponse"; klass = "SuccessResponse";
break; break;
case Class::ERROR_RESPONSE: case Class::ERROR_RESPONSE:
klass = "ErrorResponse"; klass = "ErrorResponse";
break; break;
} }
if (this->method == Method::BINDING) if (this->method == Method::BINDING)
{ {
MS_DUMP(" Binding %s", klass.c_str()); MS_DUMP(" Binding %s", klass.c_str());
} }
else else
{ {
// This prints the unknown method number. Example: TURN Allocate => 0x003. // This prints the unknown method number. Example: TURN Allocate => 0x003.
MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast<uint16_t>(this->method)); MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast<uint16_t>(this->method));
} }
MS_DUMP(" size: %zu bytes", this->size); MS_DUMP(" size: %zu bytes", this->size);
static char transactionId[25]; static char transactionId[25];
for (int i{ 0 }; i < 12; ++i) for (int i{ 0 }; i < 12; ++i)
{ {
// NOTE: n must be 3 because snprintf adds a \0 after printed chars. // NOTE: n must be 3 because snprintf adds a \0 after printed chars.
std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]);
} }
MS_DUMP(" transactionId: %s", transactionId); MS_DUMP(" transactionId: %s", transactionId);
if (this->errorCode != 0u) if (this->errorCode != 0u)
MS_DUMP(" errorCode: %" PRIu16, this->errorCode); MS_DUMP(" errorCode: %" PRIu16, this->errorCode);
if (!this->username.empty()) if (!this->username.empty())
MS_DUMP(" username: %s", this->username.c_str()); MS_DUMP(" username: %s", this->username.c_str());
if (this->priority != 0u) if (this->priority != 0u)
MS_DUMP(" priority: %" PRIu32, this->priority); MS_DUMP(" priority: %" PRIu32, this->priority);
if (this->iceControlling != 0u) if (this->iceControlling != 0u)
MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling);
if (this->iceControlled != 0u) if (this->iceControlled != 0u)
MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled);
if (this->hasUseCandidate) if (this->hasUseCandidate)
MS_DUMP(" useCandidate"); MS_DUMP(" useCandidate");
if (this->xorMappedAddress != nullptr) if (this->xorMappedAddress != nullptr)
{ {
int family; int family;
uint16_t port; uint16_t port;
std::string ip; std::string ip;
Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port);
MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port);
} }
if (this->messageIntegrity != nullptr) if (this->messageIntegrity != nullptr)
{ {
static char messageIntegrity[41]; static char messageIntegrity[41];
for (int i{ 0 }; i < 20; ++i) for (int i{ 0 }; i < 20; ++i)
{ {
std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]);
} }
MS_DUMP(" messageIntegrity: %s", messageIntegrity); MS_DUMP(" messageIntegrity: %s", messageIntegrity);
} }
if (this->hasFingerprint) if (this->hasFingerprint)
MS_DUMP(" has fingerprint"); MS_DUMP(" has fingerprint");
MS_DUMP("</StunPacket>"); MS_DUMP("</StunPacket>");
} }
#endif #endif
StunPacket::Authentication StunPacket::CheckAuthentication( StunPacket::Authentication StunPacket::CheckAuthentication(
const std::string& localUsername, const std::string& localPassword) const std::string& localUsername, const std::string& localPassword)
{ {
MS_TRACE(); MS_TRACE();
switch (this->klass) switch (this->klass)
{ {
case Class::REQUEST: case Class::REQUEST:
case Class::INDICATION: case Class::INDICATION:
{ {
// Both USERNAME and MESSAGE-INTEGRITY must be present. // Both USERNAME and MESSAGE-INTEGRITY must be present.
if (!this->messageIntegrity || this->username.empty()) if (!this->messageIntegrity || this->username.empty())
return Authentication::BAD_REQUEST; return Authentication::BAD_REQUEST;
// Check that USERNAME attribute begins with our local username plus ":". // Check that USERNAME attribute begins with our local username plus ":".
size_t localUsernameLen = localUsername.length(); size_t localUsernameLen = localUsername.length();
if ( if (
this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' || this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' ||
(this->username.compare(0, localUsernameLen, localUsername) != 0)) (this->username.compare(0, localUsernameLen, localUsername) != 0))
{ {
return Authentication::UNAUTHORIZED; return Authentication::UNAUTHORIZED;
} }
break; break;
} }
// This method cannot check authentication in received responses (as we // This method cannot check authentication in received responses (as we
// are ICE-Lite and don't generate requests). // are ICE-Lite and don't generate requests).
case Class::SUCCESS_RESPONSE: case Class::SUCCESS_RESPONSE:
case Class::ERROR_RESPONSE: case Class::ERROR_RESPONSE:
{ {
MS_ERROR("cannot check authentication for a STUN response"); MS_ERROR("cannot check authentication for a STUN response");
return Authentication::BAD_REQUEST; return Authentication::BAD_REQUEST;
} }
} }
// If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation,
// so the header length field must be modified (and later restored). // so the header length field must be modified (and later restored).
if (this->hasFingerprint) if (this->hasFingerprint)
// Set the header length field: full size - header length (20) - FINGERPRINT length (8). // Set the header length field: full size - header length (20) - FINGERPRINT length (8).
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8)); Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules.
auto computedMessageIntegrity = openssl_HMACsha1( auto computedMessageIntegrity = openssl_HMACsha1(
localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data); localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data);
Authentication result; Authentication result;
// Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet.
if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0)
result = Authentication::OK; result = Authentication::OK;
else else
result = Authentication::UNAUTHORIZED; result = Authentication::UNAUTHORIZED;
// Restore the header length field. // Restore the header length field.
if (this->hasFingerprint) if (this->hasFingerprint)
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20)); Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20));
return result; return result;
} }
StunPacket* StunPacket::CreateSuccessResponse()
{
MS_TRACE();
MS_ASSERT(
this->klass == Class::REQUEST,
"attempt to create a success response for a non Request STUN packet");
return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0);
}
StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode)
{
MS_TRACE();
StunPacket* StunPacket::CreateSuccessResponse() MS_ASSERT(
{ this->klass == Class::REQUEST,
MS_TRACE(); "attempt to create an error response for a non Request STUN packet");
MS_ASSERT( auto* response =
this->klass == Class::REQUEST, new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0);
"attempt to create a success response for a non Request STUN packet");
return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); response->SetErrorCode(errorCode);
}
StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) return response;
{ }
MS_TRACE();
MS_ASSERT( void StunPacket::Authenticate(const std::string& password)
this->klass == Class::REQUEST, {
"attempt to create an error response for a non Request STUN packet"); // Just for Request, Indication and SuccessResponse messages.
if (this->klass == Class::ERROR_RESPONSE)
{
MS_ERROR("cannot set password for ErrorResponse messages");
auto* response = return;
new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); }
response->SetErrorCode(errorCode); this->password = password;
}
return response; void StunPacket::Serialize(uint8_t* buffer)
} {
MS_TRACE();
// Some useful variables.
uint16_t usernamePaddedLen{ 0 };
uint16_t xorMappedAddressPaddedLen{ 0 };
bool addXorMappedAddress =
((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING &&
this->klass == Class::SUCCESS_RESPONSE);
bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE);
bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty());
bool addFingerprint{ true }; // Do always.
// Update data pointer.
this->data = buffer;
// First calculate the total required size for the entire packet.
this->size = 20; // Header.
if (!this->username.empty())
{
usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(this->username.length()));
this->size += 4 + usernamePaddedLen;
}
if (this->priority != 0u)
this->size += 4 + 4;
if (this->iceControlling != 0u)
this->size += 4 + 8;
if (this->iceControlled != 0u)
this->size += 4 + 8;
if (this->hasUseCandidate)
this->size += 4;
if (addXorMappedAddress)
{
switch (this->xorMappedAddress->sa_family)
{
case AF_INET:
{
xorMappedAddressPaddedLen = 8;
this->size += 4 + 8;
break;
}
case AF_INET6:
{
xorMappedAddressPaddedLen = 20;
this->size += 4 + 20;
break;
}
default:
{
MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute");
addXorMappedAddress = false;
}
}
}
if (addErrorCode)
this->size += 4 + 4;
if (addMessageIntegrity)
this->size += 4 + 20;
if (addFingerprint)
this->size += 4 + 4;
// Merge class and method fields into type.
uint16_t typeField = (static_cast<uint16_t>(this->method) & 0x0f80) << 2;
typeField |= (static_cast<uint16_t>(this->method) & 0x0070) << 1;
typeField |= (static_cast<uint16_t>(this->method) & 0x000f);
typeField |= (static_cast<uint16_t>(this->klass) & 0x02) << 7;
typeField |= (static_cast<uint16_t>(this->klass) & 0x01) << 4;
// Set type field.
Utils::Byte::Set2Bytes(buffer, 0, typeField);
// Set length field.
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size) - 20);
// Set magic cookie.
std::memcpy(buffer + 4, StunPacket::magicCookie, 4);
// Set TransactionId field.
std::memcpy(buffer + 8, this->transactionId, 12);
// Update the transaction ID pointer.
this->transactionId = buffer + 8;
// Add atributes.
size_t pos{ 20 };
// Add USERNAME.
if (usernamePaddedLen != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USERNAME));
Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast<uint16_t>(this->username.length()));
std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length());
pos += 4 + usernamePaddedLen;
}
void StunPacket::Authenticate(const std::string& password) // Add PRIORITY.
{ if (this->priority != 0u)
// Just for Request, Indication and SuccessResponse messages. {
if (this->klass == Class::ERROR_RESPONSE) Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::PRIORITY));
{ Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
MS_ERROR("cannot set password for ErrorResponse messages"); Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority);
pos += 4 + 4;
return; }
}
// Add ICE-CONTROLLING.
this->password = password; if (this->iceControlling != 0u)
} {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLING));
void StunPacket::Serialize(uint8_t* buffer) Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
{ Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling);
MS_TRACE(); pos += 4 + 8;
}
// Some useful variables.
uint16_t usernamePaddedLen{ 0 }; // Add ICE-CONTROLLED.
uint16_t xorMappedAddressPaddedLen{ 0 }; if (this->iceControlled != 0u)
bool addXorMappedAddress = {
((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLED));
this->klass == Class::SUCCESS_RESPONSE); Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled);
bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); pos += 4 + 8;
bool addFingerprint{ true }; // Do always. }
// Update data pointer. // Add USE-CANDIDATE.
this->data = buffer; if (this->hasUseCandidate)
{
// First calculate the total required size for the entire packet. Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USE_CANDIDATE));
this->size = 20; // Header. Utils::Byte::Set2Bytes(buffer, pos + 2, 0);
pos += 4;
if (!this->username.empty()) }
{
usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(this->username.length())); // Add XOR-MAPPED-ADDRESS
this->size += 4 + usernamePaddedLen; if (addXorMappedAddress)
} {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::XOR_MAPPED_ADDRESS));
if (this->priority != 0u) Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen);
this->size += 4 + 4;
uint8_t* attrValue = buffer + pos + 4;
if (this->iceControlling != 0u)
this->size += 4 + 8; switch (this->xorMappedAddress->sa_family)
{
if (this->iceControlled != 0u) case AF_INET:
this->size += 4 + 8; {
// Set first byte to 0.
if (this->hasUseCandidate) attrValue[0] = 0;
this->size += 4; // Set inet family.
attrValue[1] = 0x01;
if (addXorMappedAddress) // Set port and XOR it.
{ std::memcpy(
switch (this->xorMappedAddress->sa_family) attrValue + 2,
{ &(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_port,
case AF_INET: 2);
{ attrValue[2] ^= StunPacket::magicCookie[0];
xorMappedAddressPaddedLen = 8; attrValue[3] ^= StunPacket::magicCookie[1];
this->size += 4 + 8; // Set address and XOR it.
std::memcpy(
break; attrValue + 4,
} &(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_addr.s_addr,
4);
case AF_INET6: attrValue[4] ^= StunPacket::magicCookie[0];
{ attrValue[5] ^= StunPacket::magicCookie[1];
xorMappedAddressPaddedLen = 20; attrValue[6] ^= StunPacket::magicCookie[2];
this->size += 4 + 20; attrValue[7] ^= StunPacket::magicCookie[3];
break; pos += 4 + 8;
}
break;
default: }
{
MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); case AF_INET6:
{
addXorMappedAddress = false; // Set first byte to 0.
} attrValue[0] = 0;
} // Set inet family.
} attrValue[1] = 0x02;
// Set port and XOR it.
if (addErrorCode) std::memcpy(
this->size += 4 + 4; attrValue + 2,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_port,
if (addMessageIntegrity) 2);
this->size += 4 + 20; attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
if (addFingerprint) // Set address and XOR it.
this->size += 4 + 4; std::memcpy(
attrValue + 4,
// Merge class and method fields into type. &(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_addr.s6_addr,
uint16_t typeField = (static_cast<uint16_t>(this->method) & 0x0f80) << 2; 16);
attrValue[4] ^= StunPacket::magicCookie[0];
typeField |= (static_cast<uint16_t>(this->method) & 0x0070) << 1; attrValue[5] ^= StunPacket::magicCookie[1];
typeField |= (static_cast<uint16_t>(this->method) & 0x000f); attrValue[6] ^= StunPacket::magicCookie[2];
typeField |= (static_cast<uint16_t>(this->klass) & 0x02) << 7; attrValue[7] ^= StunPacket::magicCookie[3];
typeField |= (static_cast<uint16_t>(this->klass) & 0x01) << 4; attrValue[8] ^= this->transactionId[0];
attrValue[9] ^= this->transactionId[1];
// Set type field. attrValue[10] ^= this->transactionId[2];
Utils::Byte::Set2Bytes(buffer, 0, typeField); attrValue[11] ^= this->transactionId[3];
// Set length field. attrValue[12] ^= this->transactionId[4];
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size) - 20); attrValue[13] ^= this->transactionId[5];
// Set magic cookie. attrValue[14] ^= this->transactionId[6];
std::memcpy(buffer + 4, StunPacket::magicCookie, 4); attrValue[15] ^= this->transactionId[7];
// Set TransactionId field. attrValue[16] ^= this->transactionId[8];
std::memcpy(buffer + 8, this->transactionId, 12); attrValue[17] ^= this->transactionId[9];
// Update the transaction ID pointer. attrValue[18] ^= this->transactionId[10];
this->transactionId = buffer + 8; attrValue[19] ^= this->transactionId[11];
// Add atributes.
size_t pos{ 20 }; pos += 4 + 20;
// Add USERNAME. break;
if (usernamePaddedLen != 0u) }
{ }
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USERNAME)); }
Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast<uint16_t>(this->username.length()));
std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); // Add ERROR-CODE.
pos += 4 + usernamePaddedLen; if (addErrorCode)
} {
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ERROR_CODE));
// Add PRIORITY. Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
if (this->priority != 0u)
{ auto codeClass = static_cast<uint8_t>(this->errorCode / 100);
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::PRIORITY)); uint8_t codeNumber = static_cast<uint8_t>(this->errorCode) - (codeClass * 100);
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); Utils::Byte::Set2Bytes(buffer, pos + 4, 0);
pos += 4 + 4; Utils::Byte::Set1Byte(buffer, pos + 6, codeClass);
} Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber);
pos += 4 + 4;
// Add ICE-CONTROLLING. }
if (this->iceControlling != 0u)
{ // Add MESSAGE-INTEGRITY.
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLING)); if (addMessageIntegrity)
Utils::Byte::Set2Bytes(buffer, pos + 2, 8); {
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); // Ignore FINGERPRINT.
pos += 4 + 8; if (addFingerprint)
} Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Add ICE-CONTROLLED. // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules.
if (this->iceControlled != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLED));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled);
pos += 4 + 8;
}
// Add USE-CANDIDATE.
if (this->hasUseCandidate)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USE_CANDIDATE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 0);
pos += 4;
}
// Add XOR-MAPPED-ADDRESS
if (addXorMappedAddress)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::XOR_MAPPED_ADDRESS));
Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen);
uint8_t* attrValue = buffer + pos + 4;
switch (this->xorMappedAddress->sa_family)
{
case AF_INET:
{
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x01;
// Set port and XOR it.
std::memcpy(
attrValue + 2,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_port,
2);
attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_addr.s_addr,
4);
attrValue[4] ^= StunPacket::magicCookie[0];
attrValue[5] ^= StunPacket::magicCookie[1];
attrValue[6] ^= StunPacket::magicCookie[2];
attrValue[7] ^= StunPacket::magicCookie[3];
pos += 4 + 8;
break;
}
case AF_INET6:
{
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x02;
// Set port and XOR it.
std::memcpy(
attrValue + 2,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_port,
2);
attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_addr.s6_addr,
16);
attrValue[4] ^= StunPacket::magicCookie[0];
attrValue[5] ^= StunPacket::magicCookie[1];
attrValue[6] ^= StunPacket::magicCookie[2];
attrValue[7] ^= StunPacket::magicCookie[3];
attrValue[8] ^= this->transactionId[0];
attrValue[9] ^= this->transactionId[1];
attrValue[10] ^= this->transactionId[2];
attrValue[11] ^= this->transactionId[3];
attrValue[12] ^= this->transactionId[4];
attrValue[13] ^= this->transactionId[5];
attrValue[14] ^= this->transactionId[6];
attrValue[15] ^= this->transactionId[7];
attrValue[16] ^= this->transactionId[8];
attrValue[17] ^= this->transactionId[9];
attrValue[18] ^= this->transactionId[10];
attrValue[19] ^= this->transactionId[11];
pos += 4 + 20;
break;
}
}
}
// Add ERROR-CODE.
if (addErrorCode)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ERROR_CODE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
auto codeClass = static_cast<uint8_t>(this->errorCode / 100);
uint8_t codeNumber = static_cast<uint8_t>(this->errorCode) - (codeClass * 100);
Utils::Byte::Set2Bytes(buffer, pos + 4, 0);
Utils::Byte::Set1Byte(buffer, pos + 6, codeClass);
Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber);
pos += 4 + 4;
}
// Add MESSAGE-INTEGRITY.
if (addMessageIntegrity)
{
// Ignore FINGERPRINT.
if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules.
auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos); auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos);
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::MESSAGE_INTEGRITY)); Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::MESSAGE_INTEGRITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 20); Utils::Byte::Set2Bytes(buffer, pos + 2, 20);
std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size());
// Update the pointer. // Update the pointer.
this->messageIntegrity = buffer + pos + 4; this->messageIntegrity = buffer + pos + 4;
pos += 4 + 20; pos += 4 + 20;
// Restore length field. // Restore length field.
if (addFingerprint) if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20)); Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20));
} }
else else
{ {
// Unset the pointer (if it was set). // Unset the pointer (if it was set).
this->messageIntegrity = nullptr; this->messageIntegrity = nullptr;
} }
// Add FINGERPRINT. // Add FINGERPRINT.
if (addFingerprint) if (addFingerprint)
{ {
// Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT
// attribute and XOR it with 0x5354554e. // attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e;
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT)); Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4); Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint);
pos += 4 + 4; pos += 4 + 4;
// Set flag. // Set flag.
this->hasFingerprint = true; this->hasFingerprint = true;
} }
else else
{ {
this->hasFingerprint = false; this->hasFingerprint = false;
} }
MS_ASSERT(pos == this->size, "pos != this->size"); MS_ASSERT(pos == this->size, "pos != this->size");
} }
} // namespace RTC } // namespace RTC
...@@ -26,188 +26,188 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. ...@@ -26,188 +26,188 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
namespace RTC namespace RTC
{ {
class StunPacket class StunPacket
{ {
public: public:
// STUN message class. // STUN message class.
enum class Class : uint16_t enum class Class : uint16_t
{ {
REQUEST = 0, REQUEST = 0,
INDICATION = 1, INDICATION = 1,
SUCCESS_RESPONSE = 2, SUCCESS_RESPONSE = 2,
ERROR_RESPONSE = 3 ERROR_RESPONSE = 3
}; };
// STUN message method. // STUN message method.
enum class Method : uint16_t enum class Method : uint16_t
{ {
BINDING = 1 BINDING = 1
}; };
// Attribute type. // Attribute type.
enum class Attribute : uint16_t enum class Attribute : uint16_t
{ {
MAPPED_ADDRESS = 0x0001, MAPPED_ADDRESS = 0x0001,
USERNAME = 0x0006, USERNAME = 0x0006,
MESSAGE_INTEGRITY = 0x0008, MESSAGE_INTEGRITY = 0x0008,
ERROR_CODE = 0x0009, ERROR_CODE = 0x0009,
UNKNOWN_ATTRIBUTES = 0x000A, UNKNOWN_ATTRIBUTES = 0x000A,
REALM = 0x0014, REALM = 0x0014,
NONCE = 0x0015, NONCE = 0x0015,
XOR_MAPPED_ADDRESS = 0x0020, XOR_MAPPED_ADDRESS = 0x0020,
PRIORITY = 0x0024, PRIORITY = 0x0024,
USE_CANDIDATE = 0x0025, USE_CANDIDATE = 0x0025,
SOFTWARE = 0x8022, SOFTWARE = 0x8022,
ALTERNATE_SERVER = 0x8023, ALTERNATE_SERVER = 0x8023,
FINGERPRINT = 0x8028, FINGERPRINT = 0x8028,
ICE_CONTROLLED = 0x8029, ICE_CONTROLLED = 0x8029,
ICE_CONTROLLING = 0x802A ICE_CONTROLLING = 0x802A
}; };
// Authentication result. // Authentication result.
enum class Authentication enum class Authentication
{ {
OK = 0, OK = 0,
UNAUTHORIZED = 1, UNAUTHORIZED = 1,
BAD_REQUEST = 2 BAD_REQUEST = 2
}; };
public: public:
static bool IsStun(const uint8_t* data, size_t len) static bool IsStun(const uint8_t* data, size_t len)
{ {
// clang-format off // clang-format off
return ( return (
// STUN headers are 20 bytes. // STUN headers are 20 bytes.
(len >= 20) && (len >= 20) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] < 3) && (data[0] < 3) &&
// Magic cookie must match. // Magic cookie must match.
(data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) && (data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) &&
(data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3]) (data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3])
); );
// clang-format on // clang-format on
} }
static StunPacket* Parse(const uint8_t* data, size_t len); static StunPacket* Parse(const uint8_t* data, size_t len);
private: private:
static const uint8_t magicCookie[]; static const uint8_t magicCookie[];
public: public:
StunPacket( StunPacket(
Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size); Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size);
~StunPacket(); ~StunPacket();
void Dump() const; void Dump() const;
Class GetClass() const Class GetClass() const
{ {
return this->klass; return this->klass;
} }
Method GetMethod() const Method GetMethod() const
{ {
return this->method; return this->method;
} }
const uint8_t* GetData() const const uint8_t* GetData() const
{ {
return this->data; return this->data;
} }
size_t GetSize() const size_t GetSize() const
{ {
return this->size; return this->size;
} }
void SetUsername(const char* username, size_t len) void SetUsername(const char* username, size_t len)
{ {
this->username.assign(username, len); this->username.assign(username, len);
} }
void SetPriority(uint32_t priority) void SetPriority(uint32_t priority)
{ {
this->priority = priority; this->priority = priority;
} }
void SetIceControlling(uint64_t iceControlling) void SetIceControlling(uint64_t iceControlling)
{ {
this->iceControlling = iceControlling; this->iceControlling = iceControlling;
} }
void SetIceControlled(uint64_t iceControlled) void SetIceControlled(uint64_t iceControlled)
{ {
this->iceControlled = iceControlled; this->iceControlled = iceControlled;
} }
void SetUseCandidate() void SetUseCandidate()
{ {
this->hasUseCandidate = true; this->hasUseCandidate = true;
} }
void SetXorMappedAddress(const struct sockaddr* xorMappedAddress) void SetXorMappedAddress(const struct sockaddr* xorMappedAddress)
{ {
this->xorMappedAddress = xorMappedAddress; this->xorMappedAddress = xorMappedAddress;
} }
void SetErrorCode(uint16_t errorCode) void SetErrorCode(uint16_t errorCode)
{ {
this->errorCode = errorCode; this->errorCode = errorCode;
} }
void SetMessageIntegrity(const uint8_t* messageIntegrity) void SetMessageIntegrity(const uint8_t* messageIntegrity)
{ {
this->messageIntegrity = messageIntegrity; this->messageIntegrity = messageIntegrity;
} }
void SetFingerprint() void SetFingerprint()
{ {
this->hasFingerprint = true; this->hasFingerprint = true;
} }
const std::string& GetUsername() const const std::string& GetUsername() const
{ {
return this->username; return this->username;
} }
uint32_t GetPriority() const uint32_t GetPriority() const
{ {
return this->priority; return this->priority;
} }
uint64_t GetIceControlling() const uint64_t GetIceControlling() const
{ {
return this->iceControlling; return this->iceControlling;
} }
uint64_t GetIceControlled() const uint64_t GetIceControlled() const
{ {
return this->iceControlled; return this->iceControlled;
} }
bool HasUseCandidate() const bool HasUseCandidate() const
{ {
return this->hasUseCandidate; return this->hasUseCandidate;
} }
uint16_t GetErrorCode() const uint16_t GetErrorCode() const
{ {
return this->errorCode; return this->errorCode;
} }
bool HasMessageIntegrity() const bool HasMessageIntegrity() const
{ {
return (this->messageIntegrity ? true : false); return (this->messageIntegrity ? true : false);
} }
bool HasFingerprint() const bool HasFingerprint() const
{ {
return this->hasFingerprint; return this->hasFingerprint;
} }
Authentication CheckAuthentication( Authentication CheckAuthentication(
const std::string& localUsername, const std::string& localPassword); const std::string& localUsername, const std::string& localPassword);
StunPacket* CreateSuccessResponse(); StunPacket* CreateSuccessResponse();
StunPacket* CreateErrorResponse(uint16_t errorCode); StunPacket* CreateErrorResponse(uint16_t errorCode);
void Authenticate(const std::string& password); void Authenticate(const std::string& password);
void Serialize(uint8_t* buffer); void Serialize(uint8_t* buffer);
private: private:
// Passed by argument. // Passed by argument.
Class klass; // 2 bytes. Class klass; // 2 bytes.
Method method; // 2 bytes. Method method; // 2 bytes.
const uint8_t* transactionId{ nullptr }; // 12 bytes. const uint8_t* transactionId{ nullptr }; // 12 bytes.
uint8_t* data{ nullptr }; // Pointer to binary data. uint8_t* data{ nullptr }; // Pointer to binary data.
size_t size{ 0u }; // The full message size (including header). size_t size{ 0u }; // The full message size (including header).
// STUN attributes. // STUN attributes.
std::string username; // Less than 513 bytes. std::string username; // Less than 513 bytes.
uint32_t priority{ 0u }; // 4 bytes unsigned integer. uint32_t priority{ 0u }; // 4 bytes unsigned integer.
uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer. uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer.
uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer. uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer.
bool hasUseCandidate{ false }; // 0 bytes. bool hasUseCandidate{ false }; // 0 bytes.
const uint8_t* messageIntegrity{ nullptr }; // 20 bytes. const uint8_t* messageIntegrity{ nullptr }; // 20 bytes.
bool hasFingerprint{ false }; // 4 bytes. bool hasFingerprint{ false }; // 4 bytes.
const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes. const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes.
uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase). uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase).
std::string password; std::string password;
}; };
} // namespace RTC } // namespace RTC
#endif #endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论