Commit 56d6eb0f by ziyue

批量替换tab为4个空格

parent a5c3db4e
......@@ -269,9 +269,9 @@ API_EXPORT int API_CALL mk_media_input_aac(mk_media ctx, const void *data, int l
}
API_EXPORT int API_CALL mk_media_input_pcm(mk_media ctx, void *data , int len, uint64_t pts){
assert(ctx && data && len > 0);
MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx;
return (*obj)->getChannel()->inputPCM((char*)data, len, pts);
assert(ctx && data && len > 0);
MediaHelper::Ptr* obj = (MediaHelper::Ptr*) ctx;
return (*obj)->getChannel()->inputPCM((char*)data, len, pts);
}
API_EXPORT int API_CALL mk_media_input_audio(mk_media ctx, const void* data, int len, uint64_t dts){
......
......@@ -33,1453 +33,1453 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
using namespace std;
#define LOG_OPENSSL_ERROR(desc) \
do \
{ \
if (ERR_peek_error() == 0) \
MS_ERROR("OpenSSL error [desc:'%s']", desc); \
else \
{ \
int64_t err; \
while ((err = ERR_get_error()) != 0) \
{ \
MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \
} \
ERR_clear_error(); \
} \
} while (false)
do \
{ \
if (ERR_peek_error() == 0) \
MS_ERROR("OpenSSL error [desc:'%s']", desc); \
else \
{ \
int64_t err; \
while ((err = ERR_get_error()) != 0) \
{ \
MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \
} \
ERR_clear_error(); \
} \
} while (false)
/* Static methods for OpenSSL callbacks. */
inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/)
{
MS_TRACE();
MS_TRACE();
// Always valid since DTLS certificates are self-signed.
return 1;
// Always valid since DTLS certificates are self-signed.
return 1;
}
inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs)
{
if (timerUs == 0)
return 100000;
else if (timerUs >= 4000000)
return 4000000;
else
return 2 * timerUs;
if (timerUs == 0)
return 100000;
else if (timerUs >= 4000000)
return 4000000;
else
return 2 * timerUs;
}
namespace RTC
{
/* Static. */
// clang-format off
static constexpr int DtlsMtu{ 1350 };
// AES-HMAC: http://tools.ietf.org/html/rfc3711
static constexpr size_t SrtpMasterKeyLength{ 16 };
static constexpr size_t SrtpMasterSaltLength{ 14 };
static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength };
// AES-GCM: http://tools.ietf.org/html/rfc7714
static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 };
static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength };
static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 };
static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength };
// clang-format on
/* Class variables. */
// clang-format off
std::map<std::string, DtlsTransport::FingerprintAlgorithm> DtlsTransport::string2FingerprintAlgorithm =
{
{ "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 },
{ "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 },
{ "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 },
{ "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 },
{ "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 }
};
std::map<DtlsTransport::FingerprintAlgorithm, std::string> DtlsTransport::fingerprintAlgorithm2String =
{
{ DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" },
{ DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" },
{ DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" },
{ DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" },
{ DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" }
};
std::map<std::string, DtlsTransport::Role> DtlsTransport::string2Role =
{
{ "auto", DtlsTransport::Role::AUTO },
{ "client", DtlsTransport::Role::CLIENT },
{ "server", DtlsTransport::Role::SERVER }
};
std::vector<DtlsTransport::SrtpCryptoSuiteMapEntry> DtlsTransport::srtpCryptoSuites =
{
{ RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" },
{ RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" },
{ RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" },
{ RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" }
};
// clang-format on
INSTANCE_IMP(DtlsTransport::DtlsEnvironment);
/* Class methods. */
/* Static. */
// clang-format off
static constexpr int DtlsMtu{ 1350 };
// AES-HMAC: http://tools.ietf.org/html/rfc3711
static constexpr size_t SrtpMasterKeyLength{ 16 };
static constexpr size_t SrtpMasterSaltLength{ 14 };
static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength };
// AES-GCM: http://tools.ietf.org/html/rfc7714
static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 };
static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength };
static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 };
static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 };
static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength };
// clang-format on
/* Class variables. */
// clang-format off
std::map<std::string, DtlsTransport::FingerprintAlgorithm> DtlsTransport::string2FingerprintAlgorithm =
{
{ "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 },
{ "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 },
{ "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 },
{ "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 },
{ "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 }
};
std::map<DtlsTransport::FingerprintAlgorithm, std::string> DtlsTransport::fingerprintAlgorithm2String =
{
{ DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" },
{ DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" },
{ DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" },
{ DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" },
{ DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" }
};
std::map<std::string, DtlsTransport::Role> DtlsTransport::string2Role =
{
{ "auto", DtlsTransport::Role::AUTO },
{ "client", DtlsTransport::Role::CLIENT },
{ "server", DtlsTransport::Role::SERVER }
};
std::vector<DtlsTransport::SrtpCryptoSuiteMapEntry> DtlsTransport::srtpCryptoSuites =
{
{ RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" },
{ RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" },
{ RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" },
{ RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" }
};
// clang-format on
INSTANCE_IMP(DtlsTransport::DtlsEnvironment);
/* Class methods. */
DtlsTransport::DtlsEnvironment::DtlsEnvironment()
{
MS_TRACE();
// Generate a X509 certificate and private key (unless PEM files are provided).
if (true /*
Settings::configuration.dtlsCertificateFile.empty() ||
Settings::configuration.dtlsPrivateKeyFile.empty()*/)
{
GenerateCertificateAndPrivateKey();
}
else
{
ReadCertificateAndPrivateKeyFromFiles();
}
// Create a global SSL_CTX.
CreateSslCtx();
// Generate certificate fingerprints.
GenerateFingerprints();
}
{
MS_TRACE();
// Generate a X509 certificate and private key (unless PEM files are provided).
if (true /*
Settings::configuration.dtlsCertificateFile.empty() ||
Settings::configuration.dtlsPrivateKeyFile.empty()*/)
{
GenerateCertificateAndPrivateKey();
}
else
{
ReadCertificateAndPrivateKeyFromFiles();
}
// Create a global SSL_CTX.
CreateSslCtx();
// Generate certificate fingerprints.
GenerateFingerprints();
}
DtlsTransport::DtlsEnvironment::~DtlsEnvironment()
{
MS_TRACE();
{
MS_TRACE();
if (privateKey)
EVP_PKEY_free(privateKey);
if (certificate)
X509_free(certificate);
if (sslCtx)
SSL_CTX_free(sslCtx);
}
if (privateKey)
EVP_PKEY_free(privateKey);
if (certificate)
X509_free(certificate);
if (sslCtx)
SSL_CTX_free(sslCtx);
}
void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey()
{
MS_TRACE();
void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey()
{
MS_TRACE();
int ret{ 0 };
EC_KEY* ecKey{ nullptr };
X509_NAME* certName{ nullptr };
std::string subject =
std::string("mediasoup") + to_string(rand() % 999999 + 100000);
int ret{ 0 };
EC_KEY* ecKey{ nullptr };
X509_NAME* certName{ nullptr };
std::string subject =
std::string("mediasoup") + to_string(rand() % 999999 + 100000);
// Create key with curve.
ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
// Create key with curve.
ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
if (!ecKey)
{
LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed");
if (!ecKey)
{
LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed");
goto error;
}
goto error;
}
EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE);
EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE);
// NOTE: This can take some time.
ret = EC_KEY_generate_key(ecKey);
// NOTE: This can take some time.
ret = EC_KEY_generate_key(ecKey);
if (ret == 0)
{
LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed");
goto error;
}
goto error;
}
// Create a private key object.
privateKey = EVP_PKEY_new();
// Create a private key object.
privateKey = EVP_PKEY_new();
if (!privateKey)
{
LOG_OPENSSL_ERROR("EVP_PKEY_new() failed");
if (!privateKey)
{
LOG_OPENSSL_ERROR("EVP_PKEY_new() failed");
goto error;
}
goto error;
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey);
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey);
if (ret == 0)
{
LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed");
goto error;
}
goto error;
}
// The EC key now belongs to the private key, so don't clean it up separately.
ecKey = nullptr;
// The EC key now belongs to the private key, so don't clean it up separately.
ecKey = nullptr;
// Create the X509 certificate.
certificate = X509_new();
// Create the X509 certificate.
certificate = X509_new();
if (!certificate)
{
LOG_OPENSSL_ERROR("X509_new() failed");
if (!certificate)
{
LOG_OPENSSL_ERROR("X509_new() failed");
goto error;
}
goto error;
}
// Set version 3 (note that 0 means version 1).
X509_set_version(certificate, 2);
// Set version 3 (note that 0 means version 1).
X509_set_version(certificate, 2);
// Set serial number (avoid default 0).
ASN1_INTEGER_set(
X509_get_serialNumber(certificate),
static_cast<uint64_t>(rand() % 999999 + 100000));
// Set serial number (avoid default 0).
ASN1_INTEGER_set(
X509_get_serialNumber(certificate),
static_cast<uint64_t>(rand() % 999999 + 100000));
// Set valid period.
X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years.
X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years.
// Set valid period.
X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years.
X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years.
// Set the public key for the certificate using the key.
ret = X509_set_pubkey(certificate, privateKey);
// Set the public key for the certificate using the key.
ret = X509_set_pubkey(certificate, privateKey);
if (ret == 0)
{
LOG_OPENSSL_ERROR("X509_set_pubkey() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("X509_set_pubkey() failed");
goto error;
}
goto error;
}
// Set certificate fields.
certName = X509_get_subject_name(certificate);
// Set certificate fields.
certName = X509_get_subject_name(certificate);
if (!certName)
{
LOG_OPENSSL_ERROR("X509_get_subject_name() failed");
if (!certName)
{
LOG_OPENSSL_ERROR("X509_get_subject_name() failed");
goto error;
}
goto error;
}
X509_NAME_add_entry_by_txt(
certName, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
X509_NAME_add_entry_by_txt(
certName, "CN", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
X509_NAME_add_entry_by_txt(
certName, "O", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
X509_NAME_add_entry_by_txt(
certName, "CN", MBSTRING_ASC, reinterpret_cast<const uint8_t*>(subject.c_str()), -1, -1, 0);
// It is self-signed so set the issuer name to be the same as the subject.
ret = X509_set_issuer_name(certificate, certName);
// It is self-signed so set the issuer name to be the same as the subject.
ret = X509_set_issuer_name(certificate, certName);
if (ret == 0)
{
LOG_OPENSSL_ERROR("X509_set_issuer_name() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("X509_set_issuer_name() failed");
goto error;
}
goto error;
}
// Sign the certificate with its own private key.
ret = X509_sign(certificate, privateKey, EVP_sha1());
// Sign the certificate with its own private key.
ret = X509_sign(certificate, privateKey, EVP_sha1());
if (ret == 0)
{
LOG_OPENSSL_ERROR("X509_sign() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("X509_sign() failed");
goto error;
}
goto error;
}
return;
return;
error:
error:
if (ecKey)
EC_KEY_free(ecKey);
if (ecKey)
EC_KEY_free(ecKey);
if (privateKey)
EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key.
if (privateKey)
EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key.
if (certificate)
X509_free(certificate);
if (certificate)
X509_free(certificate);
MS_THROW_ERROR("DTLS certificate and private key generation failed");
}
MS_THROW_ERROR("DTLS certificate and private key generation failed");
}
void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles()
{
void DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromFiles()
{
#if 0
MS_TRACE();
MS_TRACE();
FILE* file{ nullptr };
FILE* file{ nullptr };
file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r");
file = fopen(Settings::configuration.dtlsCertificateFile.c_str(), "r");
if (!file)
{
MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno));
if (!file)
{
MS_ERROR("error reading DTLS certificate file: %s", std::strerror(errno));
goto error;
}
goto error;
}
certificate = PEM_read_X509(file, nullptr, nullptr, nullptr);
certificate = PEM_read_X509(file, nullptr, nullptr, nullptr);
if (!certificate)
{
LOG_OPENSSL_ERROR("PEM_read_X509() failed");
if (!certificate)
{
LOG_OPENSSL_ERROR("PEM_read_X509() failed");
goto error;
}
goto error;
}
fclose(file);
fclose(file);
file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r");
file = fopen(Settings::configuration.dtlsPrivateKeyFile.c_str(), "r");
if (!file)
{
MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno));
if (!file)
{
MS_ERROR("error reading DTLS private key file: %s", std::strerror(errno));
goto error;
}
goto error;
}
privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr);
privateKey = PEM_read_PrivateKey(file, nullptr, nullptr, nullptr);
if (!privateKey)
{
LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed");
if (!privateKey)
{
LOG_OPENSSL_ERROR("PEM_read_PrivateKey() failed");
goto error;
}
goto error;
}
fclose(file);
fclose(file);
return;
return;
error:
error:
MS_THROW_ERROR("error reading DTLS certificate and private key PEM files");
MS_THROW_ERROR("error reading DTLS certificate and private key PEM files");
#endif
}
}
void DtlsTransport::DtlsEnvironment::CreateSslCtx()
{
MS_TRACE();
void DtlsTransport::DtlsEnvironment::CreateSslCtx()
{
MS_TRACE();
std::string dtlsSrtpCryptoSuites;
int ret;
std::string dtlsSrtpCryptoSuites;
int ret;
/* Set the global DTLS context. */
/* Set the global DTLS context. */
// Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0).
sslCtx = SSL_CTX_new(DTLS_method());
// Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0).
sslCtx = SSL_CTX_new(DTLS_method());
if (!sslCtx)
{
LOG_OPENSSL_ERROR("SSL_CTX_new() failed");
if (!sslCtx)
{
LOG_OPENSSL_ERROR("SSL_CTX_new() failed");
goto error;
}
goto error;
}
ret = SSL_CTX_use_certificate(sslCtx, certificate);
ret = SSL_CTX_use_certificate(sslCtx, certificate);
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed");
goto error;
}
goto error;
}
ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey);
ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey);
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed");
goto error;
}
goto error;
}
ret = SSL_CTX_check_private_key(sslCtx);
ret = SSL_CTX_check_private_key(sslCtx);
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed");
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed");
goto error;
}
goto error;
}
// Set options.
SSL_CTX_set_options(
sslCtx,
SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE |
SSL_OP_NO_QUERY_MTU);
// Set options.
SSL_CTX_set_options(
sslCtx,
SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE |
SSL_OP_NO_QUERY_MTU);
// Don't use sessions cache.
SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF);
// Don't use sessions cache.
SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF);
// Read always as much into the buffer as possible.
// NOTE: This is the default for DTLS, but a bug in non latest OpenSSL
// versions makes this call required.
SSL_CTX_set_read_ahead(sslCtx, 1);
// Read always as much into the buffer as possible.
// NOTE: This is the default for DTLS, but a bug in non latest OpenSSL
// versions makes this call required.
SSL_CTX_set_read_ahead(sslCtx, 1);
SSL_CTX_set_verify_depth(sslCtx, 4);
SSL_CTX_set_verify_depth(sslCtx, 4);
// Require certificate from peer.
SSL_CTX_set_verify(
sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify);
// Require certificate from peer.
SSL_CTX_set_verify(
sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify);
// Set SSL info callback.
SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){
// Set SSL info callback.
SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){
static_cast<RTC::DtlsTransport*>(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret);
});
// Set ciphers.
ret = SSL_CTX_set_cipher_list(
sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK");
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed");
goto error;
}
// Enable ECDH ciphers.
// DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters
// NOTE: https://code.google.com/p/chromium/issues/detail?id=406458
// NOTE: https://bugs.ruby-lang.org/issues/12324
// For OpenSSL >= 1.0.2.
SSL_CTX_set_ecdh_auto(sslCtx, 1);
// Set the "use_srtp" DTLS extension.
for (auto it = DtlsTransport::srtpCryptoSuites.begin();
it != DtlsTransport::srtpCryptoSuites.end();
++it)
{
if (it != DtlsTransport::srtpCryptoSuites.begin())
dtlsSrtpCryptoSuites += ":";
// Set ciphers.
ret = SSL_CTX_set_cipher_list(
sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK");
if (ret == 0)
{
LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed");
goto error;
}
// Enable ECDH ciphers.
// DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters
// NOTE: https://code.google.com/p/chromium/issues/detail?id=406458
// NOTE: https://bugs.ruby-lang.org/issues/12324
// For OpenSSL >= 1.0.2.
SSL_CTX_set_ecdh_auto(sslCtx, 1);
// Set the "use_srtp" DTLS extension.
for (auto it = DtlsTransport::srtpCryptoSuites.begin();
it != DtlsTransport::srtpCryptoSuites.end();
++it)
{
if (it != DtlsTransport::srtpCryptoSuites.begin())
dtlsSrtpCryptoSuites += ":";
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it);
dtlsSrtpCryptoSuites += cryptoSuiteEntry->name;
}
MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str());
// NOTE: This function returns 0 on success.
ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str());
if (ret != 0)
{
MS_ERROR(
"SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str());
LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed");
goto error;
}
return;
error:
if (sslCtx)
{
SSL_CTX_free(sslCtx);
sslCtx = nullptr;
}
MS_THROW_ERROR("SSL context creation failed");
}
void DtlsTransport::DtlsEnvironment::GenerateFingerprints()
{
MS_TRACE();
for (auto& kv : DtlsTransport::string2FingerprintAlgorithm)
{
const std::string& algorithmString = kv.first;
FingerprintAlgorithm algorithm = kv.second;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
switch (algorithm)
{
case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224();
break;
case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256();
break;
case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384();
break;
case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512();
break;
default:
MS_THROW_ERROR("unknown algorithm");
}
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0)
{
MS_ERROR("X509_digest() failed");
MS_THROW_ERROR("Fingerprints generation failed");
}
// Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i)
{
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint);
// Store it in the vector.
DtlsTransport::Fingerprint fingerprint;
fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString);
fingerprint.value = hexFingerprint;
localFingerprints.push_back(fingerprint);
}
}
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it);
dtlsSrtpCryptoSuites += cryptoSuiteEntry->name;
}
MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str());
// NOTE: This function returns 0 on success.
ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str());
if (ret != 0)
{
MS_ERROR(
"SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str());
LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed");
goto error;
}
return;
error:
if (sslCtx)
{
SSL_CTX_free(sslCtx);
sslCtx = nullptr;
}
MS_THROW_ERROR("SSL context creation failed");
}
void DtlsTransport::DtlsEnvironment::GenerateFingerprints()
{
MS_TRACE();
for (auto& kv : DtlsTransport::string2FingerprintAlgorithm)
{
const std::string& algorithmString = kv.first;
FingerprintAlgorithm algorithm = kv.second;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
switch (algorithm)
{
case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224();
break;
case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256();
break;
case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384();
break;
case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512();
break;
default:
MS_THROW_ERROR("unknown algorithm");
}
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0)
{
MS_ERROR("X509_digest() failed");
MS_THROW_ERROR("Fingerprints generation failed");
}
// Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i)
{
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint);
// Store it in the vector.
DtlsTransport::Fingerprint fingerprint;
fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString);
fingerprint.value = hexFingerprint;
localFingerprints.push_back(fingerprint);
}
}
/* Instance methods. */
/* Instance methods. */
DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener)
{
MS_TRACE();
DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener)
{
MS_TRACE();
env = DtlsEnvironment::Instance().shared_from_this();
/* Set SSL. */
/* Set SSL. */
this->ssl = SSL_new(env->sslCtx);
this->ssl = SSL_new(env->sslCtx);
if (!this->ssl)
{
LOG_OPENSSL_ERROR("SSL_new() failed");
if (!this->ssl)
{
LOG_OPENSSL_ERROR("SSL_new() failed");
goto error;
}
goto error;
}
// Set this as custom data.
SSL_set_ex_data(this->ssl, 0, static_cast<void*>(this));
// Set this as custom data.
SSL_set_ex_data(this->ssl, 0, static_cast<void*>(this));
this->sslBioFromNetwork = BIO_new(BIO_s_mem());
this->sslBioFromNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioFromNetwork)
{
LOG_OPENSSL_ERROR("BIO_new() failed");
if (!this->sslBioFromNetwork)
{
LOG_OPENSSL_ERROR("BIO_new() failed");
SSL_free(this->ssl);
SSL_free(this->ssl);
goto error;
}
goto error;
}
this->sslBioToNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioToNetwork)
{
LOG_OPENSSL_ERROR("BIO_new() failed");
BIO_free(this->sslBioFromNetwork);
SSL_free(this->ssl);
goto error;
}
SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork);
// Set the MTU so that we don't send packets that are too large with no fragmentation.
SSL_set_mtu(this->ssl, DtlsMtu);
DTLS_set_link_mtu(this->ssl, DtlsMtu);
// Set callback handler for setting DTLS timer interval.
DTLS_set_timer_cb(this->ssl, onSslDtlsTimer);
this->sslBioToNetwork = BIO_new(BIO_s_mem());
if (!this->sslBioToNetwork)
{
LOG_OPENSSL_ERROR("BIO_new() failed");
BIO_free(this->sslBioFromNetwork);
SSL_free(this->ssl);
goto error;
}
SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork);
// Set the MTU so that we don't send packets that are too large with no fragmentation.
SSL_set_mtu(this->ssl, DtlsMtu);
DTLS_set_link_mtu(this->ssl, DtlsMtu);
// Set callback handler for setting DTLS timer interval.
DTLS_set_timer_cb(this->ssl, onSslDtlsTimer);
return;
error:
return;
error:
// NOTE: At this point SSL_set_bio() was not called so we must free BIOs as
// well.
if (this->sslBioFromNetwork)
BIO_free(this->sslBioFromNetwork);
if (this->sslBioToNetwork)
BIO_free(this->sslBioToNetwork);
// NOTE: At this point SSL_set_bio() was not called so we must free BIOs as
// well.
if (this->sslBioFromNetwork)
BIO_free(this->sslBioFromNetwork);
if (this->sslBioToNetwork)
BIO_free(this->sslBioToNetwork);
if (this->ssl)
SSL_free(this->ssl);
if (this->ssl)
SSL_free(this->ssl);
// NOTE: If this is not catched by the caller the program will abort, but
// this should never happen.
MS_THROW_ERROR("DtlsTransport instance creation failed");
}
// NOTE: If this is not catched by the caller the program will abort, but
// this should never happen.
MS_THROW_ERROR("DtlsTransport instance creation failed");
}
DtlsTransport::~DtlsTransport()
{
MS_TRACE();
DtlsTransport::~DtlsTransport()
{
MS_TRACE();
if (IsRunning())
{
// Send close alert to the peer.
SSL_shutdown(this->ssl);
SendPendingOutgoingDtlsData();
}
if (this->ssl)
{
SSL_free(this->ssl);
if (IsRunning())
{
// Send close alert to the peer.
SSL_shutdown(this->ssl);
SendPendingOutgoingDtlsData();
}
if (this->ssl)
{
SSL_free(this->ssl);
this->ssl = nullptr;
this->sslBioFromNetwork = nullptr;
this->sslBioToNetwork = nullptr;
}
this->ssl = nullptr;
this->sslBioFromNetwork = nullptr;
this->sslBioToNetwork = nullptr;
}
// Close the DTLS timer.
this->timer = nullptr;
}
// Close the DTLS timer.
this->timer = nullptr;
}
void DtlsTransport::Dump() const
{
MS_TRACE();
void DtlsTransport::Dump() const
{
MS_TRACE();
std::string state{ "new" };
std::string role{ "none " };
switch (this->state)
{
case DtlsState::CONNECTING:
state = "connecting";
break;
case DtlsState::CONNECTED:
state = "connected";
break;
case DtlsState::FAILED:
state = "failed";
break;
case DtlsState::CLOSED:
state = "closed";
break;
default:;
}
switch (this->localRole)
{
case Role::AUTO:
role = "auto";
break;
case Role::SERVER:
role = "server";
break;
case Role::CLIENT:
role = "client";
break;
default:;
}
MS_DUMP("<DtlsTransport>");
MS_DUMP(" state : %s", state.c_str());
MS_DUMP(" role : %s", role.c_str());
MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no");
MS_DUMP("</DtlsTransport>");
}
void DtlsTransport::Run(Role localRole)
{
MS_TRACE();
MS_ASSERT(
localRole == Role::CLIENT || localRole == Role::SERVER,
"local DTLS role must be 'client' or 'server'");
Role previousLocalRole = this->localRole;
if (localRole == previousLocalRole)
{
MS_ERROR("same local DTLS role provided, doing nothing");
return;
}
// If the previous local DTLS role was 'client' or 'server' do reset.
if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER)
{
MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change");
Reset();
}
// Update local role.
this->localRole = localRole;
// Set state and notify the listener.
this->state = DtlsState::CONNECTING;
this->listener->OnDtlsTransportConnecting(this);
switch (this->localRole)
{
case Role::CLIENT:
{
MS_DEBUG_TAG(dtls, "running [role:client]");
SSL_set_connect_state(this->ssl);
SSL_do_handshake(this->ssl);
SendPendingOutgoingDtlsData();
SetTimeout();
break;
}
case Role::SERVER:
{
MS_DEBUG_TAG(dtls, "running [role:server]");
SSL_set_accept_state(this->ssl);
SSL_do_handshake(this->ssl);
break;
}
default:
{
MS_ABORT("invalid local DTLS role");
}
}
}
bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint)
{
MS_TRACE();
MS_ASSERT(
fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided");
this->remoteFingerprint = fingerprint;
// The remote fingerpring may have been set after DTLS handshake was done,
// so we may need to process it now.
if (this->handshakeDone && this->state != DtlsState::CONNECTED)
{
MS_DEBUG_TAG(dtls, "handshake already done, processing it right now");
return ProcessHandshake();
}
return true;
}
void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len)
{
MS_TRACE();
int written;
int read;
std::string state{ "new" };
std::string role{ "none " };
switch (this->state)
{
case DtlsState::CONNECTING:
state = "connecting";
break;
case DtlsState::CONNECTED:
state = "connected";
break;
case DtlsState::FAILED:
state = "failed";
break;
case DtlsState::CLOSED:
state = "closed";
break;
default:;
}
switch (this->localRole)
{
case Role::AUTO:
role = "auto";
break;
case Role::SERVER:
role = "server";
break;
case Role::CLIENT:
role = "client";
break;
default:;
}
MS_DUMP("<DtlsTransport>");
MS_DUMP(" state : %s", state.c_str());
MS_DUMP(" role : %s", role.c_str());
MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no");
MS_DUMP("</DtlsTransport>");
}
void DtlsTransport::Run(Role localRole)
{
MS_TRACE();
MS_ASSERT(
localRole == Role::CLIENT || localRole == Role::SERVER,
"local DTLS role must be 'client' or 'server'");
Role previousLocalRole = this->localRole;
if (localRole == previousLocalRole)
{
MS_ERROR("same local DTLS role provided, doing nothing");
return;
}
// If the previous local DTLS role was 'client' or 'server' do reset.
if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER)
{
MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change");
Reset();
}
// Update local role.
this->localRole = localRole;
// Set state and notify the listener.
this->state = DtlsState::CONNECTING;
this->listener->OnDtlsTransportConnecting(this);
switch (this->localRole)
{
case Role::CLIENT:
{
MS_DEBUG_TAG(dtls, "running [role:client]");
SSL_set_connect_state(this->ssl);
SSL_do_handshake(this->ssl);
SendPendingOutgoingDtlsData();
SetTimeout();
break;
}
case Role::SERVER:
{
MS_DEBUG_TAG(dtls, "running [role:server]");
SSL_set_accept_state(this->ssl);
SSL_do_handshake(this->ssl);
break;
}
default:
{
MS_ABORT("invalid local DTLS role");
}
}
}
bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint)
{
MS_TRACE();
MS_ASSERT(
fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided");
this->remoteFingerprint = fingerprint;
// The remote fingerpring may have been set after DTLS handshake was done,
// so we may need to process it now.
if (this->handshakeDone && this->state != DtlsState::CONNECTED)
{
MS_DEBUG_TAG(dtls, "handshake already done, processing it right now");
return ProcessHandshake();
}
return true;
}
void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len)
{
MS_TRACE();
int written;
int read;
if (!IsRunning())
{
MS_ERROR("cannot process data while not running");
if (!IsRunning())
{
MS_ERROR("cannot process data while not running");
return;
}
return;
}
// Write the received DTLS data into the sslBioFromNetwork.
written =
BIO_write(this->sslBioFromNetwork, static_cast<const void*>(data), static_cast<int>(len));
// Write the received DTLS data into the sslBioFromNetwork.
written =
BIO_write(this->sslBioFromNetwork, static_cast<const void*>(data), static_cast<int>(len));
if (written != static_cast<int>(len))
{
MS_WARN_TAG(
dtls,
"OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)",
static_cast<size_t>(written),
len);
}
// Must call SSL_read() to process received DTLS data.
read = SSL_read(this->ssl, static_cast<void*>(DtlsTransport::sslReadBuffer), SslReadBufferSize);
if (written != static_cast<int>(len))
{
MS_WARN_TAG(
dtls,
"OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)",
static_cast<size_t>(written),
len);
}
// Must call SSL_read() to process received DTLS data.
read = SSL_read(this->ssl, static_cast<void*>(DtlsTransport::sslReadBuffer), SslReadBufferSize);
// Send data if it's ready.
SendPendingOutgoingDtlsData();
// Send data if it's ready.
SendPendingOutgoingDtlsData();
// Check SSL status and return if it is bad/closed.
if (!CheckStatus(read))
return;
// Check SSL status and return if it is bad/closed.
if (!CheckStatus(read))
return;
// Set/update the DTLS timeout.
if (!SetTimeout())
return;
// Set/update the DTLS timeout.
if (!SetTimeout())
return;
// Application data received. Notify to the listener.
if (read > 0)
{
// It is allowed to receive DTLS data even before validating remote fingerprint.
if (!this->handshakeDone)
{
MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done");
return;
}
// Application data received. Notify to the listener.
if (read > 0)
{
// It is allowed to receive DTLS data even before validating remote fingerprint.
if (!this->handshakeDone)
{
MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done");
return;
}
// Notify the listener.
this->listener->OnDtlsTransportApplicationDataReceived(
this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast<size_t>(read));
}
}
// Notify the listener.
this->listener->OnDtlsTransportApplicationDataReceived(
this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast<size_t>(read));
}
}
void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len)
{
MS_TRACE();
void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len)
{
MS_TRACE();
// We cannot send data to the peer if its remote fingerprint is not validated.
if (this->state != DtlsState::CONNECTED)
{
MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected");
// We cannot send data to the peer if its remote fingerprint is not validated.
if (this->state != DtlsState::CONNECTED)
{
MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected");
return;
}
return;
}
if (len == 0)
{
MS_WARN_TAG(dtls, "ignoring 0 length data");
if (len == 0)
{
MS_WARN_TAG(dtls, "ignoring 0 length data");
return;
}
return;
}
int written;
int written;
written = SSL_write(this->ssl, static_cast<const void*>(data), static_cast<int>(len));
written = SSL_write(this->ssl, static_cast<const void*>(data), static_cast<int>(len));
if (written < 0)
{
LOG_OPENSSL_ERROR("SSL_write() failed");
if (written < 0)
{
LOG_OPENSSL_ERROR("SSL_write() failed");
if (!CheckStatus(written))
return;
}
else if (written != static_cast<int>(len))
{
MS_WARN_TAG(
dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len);
}
if (!CheckStatus(written))
return;
}
else if (written != static_cast<int>(len))
{
MS_WARN_TAG(
dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len);
}
// Send data.
SendPendingOutgoingDtlsData();
}
void DtlsTransport::Reset()
{
MS_TRACE();
int ret;
// Send data.
SendPendingOutgoingDtlsData();
}
void DtlsTransport::Reset()
{
MS_TRACE();
int ret;
if (!IsRunning())
return;
MS_WARN_TAG(dtls, "resetting DTLS transport");
// Stop the DTLS timer.
this->timer = nullptr;
// We need to reset the SSL instance so we need to "shutdown" it, but we
// don't want to send a Close Alert to the peer, so just don't call
// SendPendingOutgoingDTLSData().
SSL_shutdown(this->ssl);
this->localRole = Role::NONE;
this->state = DtlsState::NEW;
this->handshakeDone = false;
this->handshakeDoneNow = false;
// Reset SSL status.
// NOTE: For this to properly work, SSL_shutdown() must be called before.
// NOTE: This may fail if not enough DTLS handshake data has been received,
// but we don't care so just clear the error queue.
ret = SSL_clear(this->ssl);
if (ret == 0)
ERR_clear_error();
}
inline bool DtlsTransport::CheckStatus(int returnCode)
{
MS_TRACE();
int err;
bool wasHandshakeDone = this->handshakeDone;
err = SSL_get_error(this->ssl, returnCode);
switch (err)
{
case SSL_ERROR_NONE:
break;
if (!IsRunning())
return;
MS_WARN_TAG(dtls, "resetting DTLS transport");
// Stop the DTLS timer.
this->timer = nullptr;
// We need to reset the SSL instance so we need to "shutdown" it, but we
// don't want to send a Close Alert to the peer, so just don't call
// SendPendingOutgoingDTLSData().
SSL_shutdown(this->ssl);
this->localRole = Role::NONE;
this->state = DtlsState::NEW;
this->handshakeDone = false;
this->handshakeDoneNow = false;
// Reset SSL status.
// NOTE: For this to properly work, SSL_shutdown() must be called before.
// NOTE: This may fail if not enough DTLS handshake data has been received,
// but we don't care so just clear the error queue.
ret = SSL_clear(this->ssl);
if (ret == 0)
ERR_clear_error();
}
inline bool DtlsTransport::CheckStatus(int returnCode)
{
MS_TRACE();
int err;
bool wasHandshakeDone = this->handshakeDone;
err = SSL_get_error(this->ssl, returnCode);
switch (err)
{
case SSL_ERROR_NONE:
break;
case SSL_ERROR_SSL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL");
break;
case SSL_ERROR_SSL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL");
break;
case SSL_ERROR_WANT_READ:
break;
case SSL_ERROR_WANT_READ:
break;
case SSL_ERROR_WANT_WRITE:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE");
break;
case SSL_ERROR_WANT_WRITE:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE");
break;
case SSL_ERROR_WANT_X509_LOOKUP:
MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP");
break;
case SSL_ERROR_WANT_X509_LOOKUP:
MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP");
break;
case SSL_ERROR_SYSCALL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL");
break;
case SSL_ERROR_SYSCALL:
LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL");
break;
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_WANT_CONNECT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT");
break;
case SSL_ERROR_WANT_CONNECT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT");
break;
case SSL_ERROR_WANT_ACCEPT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT");
break;
case SSL_ERROR_WANT_ACCEPT:
MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT");
break;
default:
MS_WARN_TAG(dtls, "SSL status: unknown error");
}
default:
MS_WARN_TAG(dtls, "SSL status: unknown error");
}
// Check if the handshake (or re-handshake) has been done right now.
if (this->handshakeDoneNow)
{
this->handshakeDoneNow = false;
this->handshakeDone = true;
// Stop the timer.
this->timer = nullptr;
// Check if the handshake (or re-handshake) has been done right now.
if (this->handshakeDoneNow)
{
this->handshakeDoneNow = false;
this->handshakeDone = true;
// Stop the timer.
this->timer = nullptr;
// Process the handshake just once (ignore if DTLS renegotiation).
if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE)
return ProcessHandshake();
// Process the handshake just once (ignore if DTLS renegotiation).
if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE)
return ProcessHandshake();
return true;
}
// Check if the peer sent close alert or a fatal error happened.
else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL)
{
if (this->state == DtlsState::CONNECTED)
{
MS_DEBUG_TAG(dtls, "disconnected");
return true;
}
// Check if the peer sent close alert or a fatal error happened.
else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL)
{
if (this->state == DtlsState::CONNECTED)
{
MS_DEBUG_TAG(dtls, "disconnected");
Reset();
Reset();
// Set state and notify the listener.
this->state = DtlsState::CLOSED;
this->listener->OnDtlsTransportClosed(this);
}
else
{
MS_WARN_TAG(dtls, "connection failed");
// Set state and notify the listener.
this->state = DtlsState::CLOSED;
this->listener->OnDtlsTransportClosed(this);
}
else
{
MS_WARN_TAG(dtls, "connection failed");
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
}
return false;
}
else
{
return true;
}
}
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
}
return false;
}
else
{
return true;
}
}
inline void DtlsTransport::SendPendingOutgoingDtlsData()
{
MS_TRACE();
inline void DtlsTransport::SendPendingOutgoingDtlsData()
{
MS_TRACE();
if (BIO_eof(this->sslBioToNetwork))
return;
if (BIO_eof(this->sslBioToNetwork))
return;
int64_t read;
char* data{ nullptr };
int64_t read;
char* data{ nullptr };
read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT
read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT
if (read <= 0)
return;
MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read);
if (read <= 0)
return;
MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read);
// Notify the listener.
this->listener->OnDtlsTransportSendData(
this, reinterpret_cast<uint8_t*>(data), static_cast<size_t>(read));
// Notify the listener.
this->listener->OnDtlsTransportSendData(
this, reinterpret_cast<uint8_t*>(data), static_cast<size_t>(read));
// Clear the BIO buffer.
// NOTE: the (void) avoids the -Wunused-value warning.
(void)BIO_reset(this->sslBioToNetwork);
}
// Clear the BIO buffer.
// NOTE: the (void) avoids the -Wunused-value warning.
(void)BIO_reset(this->sslBioToNetwork);
}
inline bool DtlsTransport::SetTimeout()
{
MS_TRACE();
inline bool DtlsTransport::SetTimeout()
{
MS_TRACE();
MS_ASSERT(
this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED,
"invalid DTLS state");
MS_ASSERT(
this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED,
"invalid DTLS state");
int64_t ret;
int64_t ret;
struct timeval dtlsTimeout{ 0, 0 };
uint64_t timeoutMs;
uint64_t timeoutMs;
// NOTE: If ret == 0 then ignore the value in dtlsTimeout.
// NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev.
ret = DTLSv1_get_timeout(this->ssl, static_cast<void*>(&dtlsTimeout)); // NOLINT
// NOTE: If ret == 0 then ignore the value in dtlsTimeout.
// NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev.
ret = DTLSv1_get_timeout(this->ssl, static_cast<void*>(&dtlsTimeout)); // NOLINT
if (ret == 0)
return true;
if (ret == 0)
return true;
timeoutMs = (dtlsTimeout.tv_sec * static_cast<uint64_t>(1000)) + (dtlsTimeout.tv_usec / 1000);
timeoutMs = (dtlsTimeout.tv_sec * static_cast<uint64_t>(1000)) + (dtlsTimeout.tv_usec / 1000);
if (timeoutMs == 0)
{
return true;
}
else if (timeoutMs < 30000)
{
MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs);
if (timeoutMs == 0)
{
return true;
}
else if (timeoutMs < 30000)
{
MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs);
weak_ptr<DtlsTransport> weak_self = shared_from_this();
this->timer = std::make_shared<Timer>(timeoutMs / 1000.0f, [weak_self](){
auto strong_self = weak_self.lock();
if(strong_self){
weak_ptr<DtlsTransport> weak_self = shared_from_this();
this->timer = std::make_shared<Timer>(timeoutMs / 1000.0f, [weak_self](){
auto strong_self = weak_self.lock();
if(strong_self){
strong_self->OnTimer();
}
}
return true;
}, this->poller);
}, this->poller);
return true;
}
// NOTE: Don't start the timer again if the timeout is greater than 30 seconds.
else
{
MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs);
return true;
}
// NOTE: Don't start the timer again if the timeout is greater than 30 seconds.
else
{
MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs);
Reset();
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
return false;
}
}
return false;
}
}
inline bool DtlsTransport::ProcessHandshake()
{
MS_TRACE();
inline bool DtlsTransport::ProcessHandshake()
{
MS_TRACE();
MS_ASSERT(this->handshakeDone, "handshake not done yet");
MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
MS_ASSERT(this->handshakeDone, "handshake not done yet");
MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
// Validate the remote fingerprint.
if (!CheckRemoteFingerprint())
{
Reset();
// Validate the remote fingerprint.
if (!CheckRemoteFingerprint())
{
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
return false;
}
return false;
}
// Get the negotiated SRTP crypto suite.
RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite();
// Get the negotiated SRTP crypto suite.
RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite();
if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE)
{
// Extract the SRTP keys (will notify the listener with them).
ExtractSrtpKeys(srtpCryptoSuite);
if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE)
{
// Extract the SRTP keys (will notify the listener with them).
ExtractSrtpKeys(srtpCryptoSuite);
return true;
}
return true;
}
// NOTE: We assume that "use_srtp" DTLS extension is required even if
// there is no audio/video.
MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated");
// NOTE: We assume that "use_srtp" DTLS extension is required even if
// there is no audio/video.
MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated");
Reset();
Reset();
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
// Set state and notify the listener.
this->state = DtlsState::FAILED;
this->listener->OnDtlsTransportFailed(this);
return false;
}
return false;
}
inline bool DtlsTransport::CheckRemoteFingerprint()
{
MS_TRACE();
inline bool DtlsTransport::CheckRemoteFingerprint()
{
MS_TRACE();
MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
MS_ASSERT(
this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set");
X509* certificate;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
X509* certificate;
uint8_t binaryFingerprint[EVP_MAX_MD_SIZE];
unsigned int size{ 0 };
char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1];
const EVP_MD* hashFunction;
int ret;
certificate = SSL_get_peer_certificate(this->ssl);
certificate = SSL_get_peer_certificate(this->ssl);
if (!certificate)
{
MS_WARN_TAG(dtls, "no certificate was provided by the peer");
if (!certificate)
{
MS_WARN_TAG(dtls, "no certificate was provided by the peer");
return false;
}
return false;
}
switch (this->remoteFingerprint.algorithm)
{
case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
switch (this->remoteFingerprint.algorithm)
{
case FingerprintAlgorithm::SHA1:
hashFunction = EVP_sha1();
break;
case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224();
break;
case FingerprintAlgorithm::SHA224:
hashFunction = EVP_sha224();
break;
case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256();
break;
case FingerprintAlgorithm::SHA256:
hashFunction = EVP_sha256();
break;
case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384();
break;
case FingerprintAlgorithm::SHA384:
hashFunction = EVP_sha384();
break;
case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512();
break;
case FingerprintAlgorithm::SHA512:
hashFunction = EVP_sha512();
break;
default:
MS_ABORT("unknown algorithm");
}
default:
MS_ABORT("unknown algorithm");
}
// Compare the remote fingerprint with the value given via signaling.
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
// Compare the remote fingerprint with the value given via signaling.
ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size);
if (ret == 0)
{
MS_ERROR("X509_digest() failed");
if (ret == 0)
{
MS_ERROR("X509_digest() failed");
X509_free(certificate);
X509_free(certificate);
return false;
}
return false;
}
// Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i)
{
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
// Convert to hexadecimal format in uppercase with colons.
for (unsigned int i{ 0 }; i < size; ++i)
{
std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]);
}
hexFingerprint[(size * 3) - 1] = '\0';
if (this->remoteFingerprint.value != hexFingerprint)
{
MS_WARN_TAG(
dtls,
"fingerprint in the remote certificate (%s) does not match the announced one (%s)",
hexFingerprint,
this->remoteFingerprint.value.c_str());
X509_free(certificate);
return false;
}
if (this->remoteFingerprint.value != hexFingerprint)
{
MS_WARN_TAG(
dtls,
"fingerprint in the remote certificate (%s) does not match the announced one (%s)",
hexFingerprint,
this->remoteFingerprint.value.c_str());
X509_free(certificate);
return false;
}
MS_DEBUG_TAG(dtls, "valid remote fingerprint");
MS_DEBUG_TAG(dtls, "valid remote fingerprint");
// Get the remote certificate in PEM format.
// Get the remote certificate in PEM format.
BIO* bio = BIO_new(BIO_s_mem());
BIO* bio = BIO_new(BIO_s_mem());
// Ensure the underlying BUF_MEM structure is also freed.
// NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since
// BIO_set_close() always returns 1.
(void)BIO_set_close(bio, BIO_CLOSE);
// Ensure the underlying BUF_MEM structure is also freed.
// NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since
// BIO_set_close() always returns 1.
(void)BIO_set_close(bio, BIO_CLOSE);
ret = PEM_write_bio_X509(bio, certificate);
ret = PEM_write_bio_X509(bio, certificate);
if (ret != 1)
{
LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed");
if (ret != 1)
{
LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed");
X509_free(certificate);
BIO_free(bio);
return false;
}
BUF_MEM* mem;
X509_free(certificate);
BIO_free(bio);
return false;
}
BUF_MEM* mem;
BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast]
if (!mem || !mem->data || mem->length == 0u)
{
LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed");
X509_free(certificate);
BIO_free(bio);
return false;
}
this->remoteCert = std::string(mem->data, mem->length);
X509_free(certificate);
BIO_free(bio);
return true;
}
inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite)
{
MS_TRACE();
size_t srtpKeyLength{ 0 };
size_t srtpSaltLength{ 0 };
size_t srtpMasterLength{ 0 };
switch (srtpCryptoSuite)
{
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80:
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32:
{
srtpKeyLength = SrtpMasterKeyLength;
srtpSaltLength = SrtpMasterSaltLength;
srtpMasterLength = SrtpMasterLength;
break;
}
case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM:
{
srtpKeyLength = SrtpAesGcm256MasterKeyLength;
srtpSaltLength = SrtpAesGcm256MasterSaltLength;
srtpMasterLength = SrtpAesGcm256MasterLength;
break;
}
case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM:
{
srtpKeyLength = SrtpAesGcm128MasterKeyLength;
srtpSaltLength = SrtpAesGcm128MasterSaltLength;
srtpMasterLength = SrtpAesGcm128MasterLength;
break;
}
default:
{
MS_ABORT("unknown SRTP crypto suite");
}
}
auto* srtpMaterial = new uint8_t[srtpMasterLength * 2];
uint8_t* srtpLocalKey{ nullptr };
uint8_t* srtpLocalSalt{ nullptr };
uint8_t* srtpRemoteKey{ nullptr };
uint8_t* srtpRemoteSalt{ nullptr };
auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength];
auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength];
int ret;
ret = SSL_export_keying_material(
this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0);
MS_ASSERT(ret != 0, "SSL_export_keying_material() failed");
switch (this->localRole)
{
case Role::SERVER:
{
srtpRemoteKey = srtpMaterial;
srtpLocalKey = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteSalt + srtpSaltLength;
break;
}
case Role::CLIENT:
{
srtpLocalKey = srtpMaterial;
srtpRemoteKey = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalSalt + srtpSaltLength;
break;
}
default:
{
MS_ABORT("no DTLS role set");
}
}
// Create the SRTP local master key.
std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength);
std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength);
// Create the SRTP remote master key.
std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength);
std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength);
// Set state and notify the listener.
this->state = DtlsState::CONNECTED;
this->listener->OnDtlsTransportConnected(
this,
srtpCryptoSuite,
srtpLocalMasterKey,
srtpMasterLength,
srtpRemoteMasterKey,
srtpMasterLength,
this->remoteCert);
delete[] srtpMaterial;
delete[] srtpLocalMasterKey;
delete[] srtpRemoteMasterKey;
}
inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite()
{
MS_TRACE();
RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE;
// Ensure that the SRTP crypto suite has been negotiated.
// NOTE: This is a OpenSSL type.
SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl);
if (!sslSrtpCryptoSuite)
return negotiatedSrtpCryptoSuite;
// Get the negotiated SRTP crypto suite.
for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites)
{
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite);
if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0)
{
MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name);
negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite;
}
}
MS_ASSERT(
negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE,
"chosen SRTP crypto suite is not an available one");
return negotiatedSrtpCryptoSuite;
}
inline void DtlsTransport::OnSslInfo(int where, int ret)
{
MS_TRACE();
int w = where & -SSL_ST_MASK;
const char* role;
if ((w & SSL_ST_CONNECT) != 0)
role = "client";
else if ((w & SSL_ST_ACCEPT) != 0)
role = "server";
else
role = "undefined";
if ((where & SSL_CB_LOOP) != 0)
{
MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl));
}
else if ((where & SSL_CB_ALERT) != 0)
{
const char* alertType;
switch (*SSL_alert_type_string(ret))
{
case 'W':
alertType = "warning";
break;
case 'F':
alertType = "fatal";
break;
default:
alertType = "undefined";
}
if ((where & SSL_CB_READ) != 0)
{
MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
else if ((where & SSL_CB_WRITE) != 0)
{
MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
else
{
MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
}
else if ((where & SSL_CB_EXIT) != 0)
{
if (ret == 0)
MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl));
else if (ret < 0)
MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl));
}
else if ((where & SSL_CB_HANDSHAKE_START) != 0)
{
MS_DEBUG_TAG(dtls, "DTLS handshake start");
}
else if ((where & SSL_CB_HANDSHAKE_DONE) != 0)
{
MS_DEBUG_TAG(dtls, "DTLS handshake done");
this->handshakeDoneNow = true;
}
// NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon
// receipt of a close alert does not work (the flag is set after this callback).
}
inline void DtlsTransport::OnTimer()
{
MS_TRACE();
// Workaround for https://github.com/openssl/openssl/issues/7998.
if (this->handshakeDone)
{
MS_DEBUG_DEV("handshake is done so return");
return;
}
DTLSv1_handle_timeout(this->ssl);
// If required, send DTLS data.
SendPendingOutgoingDtlsData();
// Set the DTLS timer again.
SetTimeout();
}
BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast]
if (!mem || !mem->data || mem->length == 0u)
{
LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed");
X509_free(certificate);
BIO_free(bio);
return false;
}
this->remoteCert = std::string(mem->data, mem->length);
X509_free(certificate);
BIO_free(bio);
return true;
}
inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite)
{
MS_TRACE();
size_t srtpKeyLength{ 0 };
size_t srtpSaltLength{ 0 };
size_t srtpMasterLength{ 0 };
switch (srtpCryptoSuite)
{
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80:
case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32:
{
srtpKeyLength = SrtpMasterKeyLength;
srtpSaltLength = SrtpMasterSaltLength;
srtpMasterLength = SrtpMasterLength;
break;
}
case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM:
{
srtpKeyLength = SrtpAesGcm256MasterKeyLength;
srtpSaltLength = SrtpAesGcm256MasterSaltLength;
srtpMasterLength = SrtpAesGcm256MasterLength;
break;
}
case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM:
{
srtpKeyLength = SrtpAesGcm128MasterKeyLength;
srtpSaltLength = SrtpAesGcm128MasterSaltLength;
srtpMasterLength = SrtpAesGcm128MasterLength;
break;
}
default:
{
MS_ABORT("unknown SRTP crypto suite");
}
}
auto* srtpMaterial = new uint8_t[srtpMasterLength * 2];
uint8_t* srtpLocalKey{ nullptr };
uint8_t* srtpLocalSalt{ nullptr };
uint8_t* srtpRemoteKey{ nullptr };
uint8_t* srtpRemoteSalt{ nullptr };
auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength];
auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength];
int ret;
ret = SSL_export_keying_material(
this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0);
MS_ASSERT(ret != 0, "SSL_export_keying_material() failed");
switch (this->localRole)
{
case Role::SERVER:
{
srtpRemoteKey = srtpMaterial;
srtpLocalKey = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteSalt + srtpSaltLength;
break;
}
case Role::CLIENT:
{
srtpLocalKey = srtpMaterial;
srtpRemoteKey = srtpLocalKey + srtpKeyLength;
srtpLocalSalt = srtpRemoteKey + srtpKeyLength;
srtpRemoteSalt = srtpLocalSalt + srtpSaltLength;
break;
}
default:
{
MS_ABORT("no DTLS role set");
}
}
// Create the SRTP local master key.
std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength);
std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength);
// Create the SRTP remote master key.
std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength);
std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength);
// Set state and notify the listener.
this->state = DtlsState::CONNECTED;
this->listener->OnDtlsTransportConnected(
this,
srtpCryptoSuite,
srtpLocalMasterKey,
srtpMasterLength,
srtpRemoteMasterKey,
srtpMasterLength,
this->remoteCert);
delete[] srtpMaterial;
delete[] srtpLocalMasterKey;
delete[] srtpRemoteMasterKey;
}
inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite()
{
MS_TRACE();
RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE;
// Ensure that the SRTP crypto suite has been negotiated.
// NOTE: This is a OpenSSL type.
SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl);
if (!sslSrtpCryptoSuite)
return negotiatedSrtpCryptoSuite;
// Get the negotiated SRTP crypto suite.
for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites)
{
SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite);
if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0)
{
MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name);
negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite;
}
}
MS_ASSERT(
negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE,
"chosen SRTP crypto suite is not an available one");
return negotiatedSrtpCryptoSuite;
}
inline void DtlsTransport::OnSslInfo(int where, int ret)
{
MS_TRACE();
int w = where & -SSL_ST_MASK;
const char* role;
if ((w & SSL_ST_CONNECT) != 0)
role = "client";
else if ((w & SSL_ST_ACCEPT) != 0)
role = "server";
else
role = "undefined";
if ((where & SSL_CB_LOOP) != 0)
{
MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl));
}
else if ((where & SSL_CB_ALERT) != 0)
{
const char* alertType;
switch (*SSL_alert_type_string(ret))
{
case 'W':
alertType = "warning";
break;
case 'F':
alertType = "fatal";
break;
default:
alertType = "undefined";
}
if ((where & SSL_CB_READ) != 0)
{
MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
else if ((where & SSL_CB_WRITE) != 0)
{
MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
else
{
MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret));
}
}
else if ((where & SSL_CB_EXIT) != 0)
{
if (ret == 0)
MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl));
else if (ret < 0)
MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl));
}
else if ((where & SSL_CB_HANDSHAKE_START) != 0)
{
MS_DEBUG_TAG(dtls, "DTLS handshake start");
}
else if ((where & SSL_CB_HANDSHAKE_DONE) != 0)
{
MS_DEBUG_TAG(dtls, "DTLS handshake done");
this->handshakeDoneNow = true;
}
// NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon
// receipt of a close alert does not work (the flag is set after this callback).
}
inline void DtlsTransport::OnTimer()
{
MS_TRACE();
// Workaround for https://github.com/openssl/openssl/issues/7998.
if (this->handshakeDone)
{
MS_DEBUG_DEV("handshake is done so return");
return;
}
DTLSv1_handle_timeout(this->ssl);
// If required, send DTLS data.
SendPendingOutgoingDtlsData();
// Set the DTLS timer again.
SetTimeout();
}
} // namespace RTC
......@@ -33,50 +33,50 @@ using namespace toolkit;
namespace RTC
{
class DtlsTransport : public std::enable_shared_from_this<DtlsTransport>
{
public:
enum class DtlsState
{
NEW = 1,
CONNECTING,
CONNECTED,
FAILED,
CLOSED
};
public:
enum class Role
{
NONE = 0,
AUTO = 1,
CLIENT,
SERVER
};
public:
enum class FingerprintAlgorithm
{
NONE = 0,
SHA1 = 1,
SHA224,
SHA256,
SHA384,
SHA512
};
public:
struct Fingerprint
{
FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE };
std::string value;
};
private:
struct SrtpCryptoSuiteMapEntry
{
RTC::SrtpSession::CryptoSuite cryptoSuite;
const char* name;
};
{
public:
enum class DtlsState
{
NEW = 1,
CONNECTING,
CONNECTED,
FAILED,
CLOSED
};
public:
enum class Role
{
NONE = 0,
AUTO = 1,
CLIENT,
SERVER
};
public:
enum class FingerprintAlgorithm
{
NONE = 0,
SHA1 = 1,
SHA224,
SHA256,
SHA384,
SHA512
};
public:
struct Fingerprint
{
FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE };
std::string value;
};
private:
struct SrtpCryptoSuiteMapEntry
{
RTC::SrtpSession::CryptoSuite cryptoSuite;
const char* name;
};
class DtlsEnvironment : public std::enable_shared_from_this<DtlsEnvironment>
{
......@@ -99,154 +99,154 @@ namespace RTC
std::vector<Fingerprint> localFingerprints;
};
public:
class Listener
{
public:
// DTLS is in the process of negotiating a secure connection. Incoming
// media can flow through.
// NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0;
// DTLS has completed negotiation of a secure connection (including DTLS-SRTP
// and remote fingerprint verification). Outgoing media can now flow through.
// NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnected(
const RTC::DtlsTransport* dtlsTransport,
RTC::SrtpSession::CryptoSuite srtpCryptoSuite,
uint8_t* srtpLocalKey,
size_t srtpLocalKeyLen,
uint8_t* srtpRemoteKey,
size_t srtpRemoteKeyLen,
std::string& remoteCert) = 0;
// The DTLS connection has been closed as the result of an error (such as a
// DTLS alert or a failure to validate the remote fingerprint).
virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0;
// The DTLS connection has been closed due to receipt of a close_notify alert.
virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0;
// Need to send DTLS data to the peer.
virtual void OnDtlsTransportSendData(
const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
// DTLS application data received.
virtual void OnDtlsTransportApplicationDataReceived(
const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
};
public:
static Role StringToRole(const std::string& role)
{
auto it = DtlsTransport::string2Role.find(role);
if (it != DtlsTransport::string2Role.end())
return it->second;
else
return DtlsTransport::Role::NONE;
}
static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint)
{
auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint);
if (it != DtlsTransport::string2FingerprintAlgorithm.end())
return it->second;
else
return DtlsTransport::FingerprintAlgorithm::NONE;
}
static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint)
{
auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint);
return it->second;
}
static bool IsDtls(const uint8_t* data, size_t len)
{
// clang-format off
return (
// Minimum DTLS record length is 13 bytes.
(len >= 13) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] > 19 && data[0] < 64)
);
// clang-format on
}
private:
static std::map<std::string, Role> string2Role;
static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm;
static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String;
static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites;
public:
DtlsTransport(EventPoller::Ptr poller, Listener* listener);
~DtlsTransport();
public:
void Dump() const;
void Run(Role localRole);
std::vector<Fingerprint>& GetLocalFingerprints() const
{
return env->localFingerprints;
}
bool SetRemoteFingerprint(Fingerprint fingerprint);
void ProcessDtlsData(const uint8_t* data, size_t len);
DtlsState GetState() const
{
return this->state;
}
Role GetLocalRole() const
{
return this->localRole;
}
void SendApplicationData(const uint8_t* data, size_t len);
private:
bool IsRunning() const
{
switch (this->state)
{
case DtlsState::NEW:
return false;
case DtlsState::CONNECTING:
case DtlsState::CONNECTED:
return true;
case DtlsState::FAILED:
case DtlsState::CLOSED:
return false;
}
// Make GCC 4.9 happy.
return false;
}
void Reset();
bool CheckStatus(int returnCode);
void SendPendingOutgoingDtlsData();
bool SetTimeout();
bool ProcessHandshake();
bool CheckRemoteFingerprint();
void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite);
RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite();
public:
class Listener
{
public:
// DTLS is in the process of negotiating a secure connection. Incoming
// media can flow through.
// NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0;
// DTLS has completed negotiation of a secure connection (including DTLS-SRTP
// and remote fingerprint verification). Outgoing media can now flow through.
// NOTE: The caller MUST NOT call any method during this callback.
virtual void OnDtlsTransportConnected(
const RTC::DtlsTransport* dtlsTransport,
RTC::SrtpSession::CryptoSuite srtpCryptoSuite,
uint8_t* srtpLocalKey,
size_t srtpLocalKeyLen,
uint8_t* srtpRemoteKey,
size_t srtpRemoteKeyLen,
std::string& remoteCert) = 0;
// The DTLS connection has been closed as the result of an error (such as a
// DTLS alert or a failure to validate the remote fingerprint).
virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0;
// The DTLS connection has been closed due to receipt of a close_notify alert.
virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0;
// Need to send DTLS data to the peer.
virtual void OnDtlsTransportSendData(
const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
// DTLS application data received.
virtual void OnDtlsTransportApplicationDataReceived(
const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0;
};
public:
static Role StringToRole(const std::string& role)
{
auto it = DtlsTransport::string2Role.find(role);
if (it != DtlsTransport::string2Role.end())
return it->second;
else
return DtlsTransport::Role::NONE;
}
static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint)
{
auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint);
if (it != DtlsTransport::string2FingerprintAlgorithm.end())
return it->second;
else
return DtlsTransport::FingerprintAlgorithm::NONE;
}
static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint)
{
auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint);
return it->second;
}
static bool IsDtls(const uint8_t* data, size_t len)
{
// clang-format off
return (
// Minimum DTLS record length is 13 bytes.
(len >= 13) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] > 19 && data[0] < 64)
);
// clang-format on
}
private:
static std::map<std::string, Role> string2Role;
static std::map<std::string, FingerprintAlgorithm> string2FingerprintAlgorithm;
static std::map<FingerprintAlgorithm, std::string> fingerprintAlgorithm2String;
static std::vector<SrtpCryptoSuiteMapEntry> srtpCryptoSuites;
public:
DtlsTransport(EventPoller::Ptr poller, Listener* listener);
~DtlsTransport();
public:
void Dump() const;
void Run(Role localRole);
std::vector<Fingerprint>& GetLocalFingerprints() const
{
return env->localFingerprints;
}
bool SetRemoteFingerprint(Fingerprint fingerprint);
void ProcessDtlsData(const uint8_t* data, size_t len);
DtlsState GetState() const
{
return this->state;
}
Role GetLocalRole() const
{
return this->localRole;
}
void SendApplicationData(const uint8_t* data, size_t len);
private:
void OnSslInfo(int where, int ret);
void OnTimer();
bool IsRunning() const
{
switch (this->state)
{
case DtlsState::NEW:
return false;
case DtlsState::CONNECTING:
case DtlsState::CONNECTED:
return true;
case DtlsState::FAILED:
case DtlsState::CLOSED:
return false;
}
// Make GCC 4.9 happy.
return false;
}
void Reset();
bool CheckStatus(int returnCode);
void SendPendingOutgoingDtlsData();
bool SetTimeout();
bool ProcessHandshake();
bool CheckRemoteFingerprint();
void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite);
RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite();
private:
void OnSslInfo(int where, int ret);
void OnTimer();
private:
private:
DtlsEnvironment::Ptr env;
EventPoller::Ptr poller;
// Passed by argument.
Listener* listener{ nullptr };
// Allocated by this.
SSL* ssl{ nullptr };
BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads.
BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes.
Timer::Ptr timer;
// Others.
DtlsState state{ DtlsState::NEW };
Role localRole{ Role::NONE };
Fingerprint remoteFingerprint;
bool handshakeDone{ false };
bool handshakeDoneNow{ false };
std::string remoteCert;
//最大不超过mtu
static constexpr int SslReadBufferSize{ 2000 };
Listener* listener{ nullptr };
// Allocated by this.
SSL* ssl{ nullptr };
BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads.
BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes.
Timer::Ptr timer;
// Others.
DtlsState state{ DtlsState::NEW };
Role localRole{ Role::NONE };
Fingerprint remoteFingerprint;
bool handshakeDone{ false };
bool handshakeDoneNow{ false };
std::string remoteCert;
//最大不超过mtu
static constexpr int SslReadBufferSize{ 2000 };
uint8_t sslReadBuffer[SslReadBufferSize];
};
} // namespace RTC
......
......@@ -24,505 +24,505 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
namespace RTC
{
/* Static. */
/* Instance methods. */
IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password)
: listener(listener), usernameFragment(usernameFragment), password(password)
{
MS_TRACE();
}
void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple)
{
MS_TRACE();
// Must be a Binding method.
if (packet->GetMethod() != RTC::StunPacket::Method::BINDING)
{
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
{
MS_WARN_TAG(
ice,
"unknown method %#.3x in STUN Request => 400",
static_cast<unsigned int>(packet->GetMethod()));
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
}
else
{
MS_WARN_TAG(
ice,
"ignoring STUN Indication or Response with unknown method %#.3x",
static_cast<unsigned int>(packet->GetMethod()));
}
return;
}
// Must use FINGERPRINT (optional for ICE STUN indications).
if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION)
{
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
{
MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
}
else
{
MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT");
}
return;
}
switch (packet->GetClass())
{
case RTC::StunPacket::Class::REQUEST:
{
// USERNAME, MESSAGE-INTEGRITY and PRIORITY are required.
if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty())
{
MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
// Check authentication.
switch (packet->CheckAuthentication(this->usernameFragment, this->password))
{
case RTC::StunPacket::Authentication::OK:
{
if (!this->oldPassword.empty())
{
MS_DEBUG_TAG(ice, "new ICE credentials applied");
this->oldUsernameFragment.clear();
this->oldPassword.clear();
}
break;
}
case RTC::StunPacket::Authentication::UNAUTHORIZED:
{
// We may have changed our usernameFragment and password, so check
// the old ones.
// clang-format off
if (
!this->oldUsernameFragment.empty() &&
!this->oldPassword.empty() &&
packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK
)
// clang-format on
{
MS_DEBUG_TAG(ice, "using old ICE credentials");
break;
}
MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401");
// Reply 401.
RTC::StunPacket* response = packet->CreateErrorResponse(401);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
case RTC::StunPacket::Authentication::BAD_REQUEST:
{
MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
/* Static. */
/* Instance methods. */
IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password)
: listener(listener), usernameFragment(usernameFragment), password(password)
{
MS_TRACE();
}
void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple)
{
MS_TRACE();
// Must be a Binding method.
if (packet->GetMethod() != RTC::StunPacket::Method::BINDING)
{
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
{
MS_WARN_TAG(
ice,
"unknown method %#.3x in STUN Request => 400",
static_cast<unsigned int>(packet->GetMethod()));
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
}
else
{
MS_WARN_TAG(
ice,
"ignoring STUN Indication or Response with unknown method %#.3x",
static_cast<unsigned int>(packet->GetMethod()));
}
return;
}
// Must use FINGERPRINT (optional for ICE STUN indications).
if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION)
{
if (packet->GetClass() == RTC::StunPacket::Class::REQUEST)
{
MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
}
else
{
MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT");
}
return;
}
switch (packet->GetClass())
{
case RTC::StunPacket::Class::REQUEST:
{
// USERNAME, MESSAGE-INTEGRITY and PRIORITY are required.
if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty())
{
MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
// Check authentication.
switch (packet->CheckAuthentication(this->usernameFragment, this->password))
{
case RTC::StunPacket::Authentication::OK:
{
if (!this->oldPassword.empty())
{
MS_DEBUG_TAG(ice, "new ICE credentials applied");
this->oldUsernameFragment.clear();
this->oldPassword.clear();
}
break;
}
case RTC::StunPacket::Authentication::UNAUTHORIZED:
{
// We may have changed our usernameFragment and password, so check
// the old ones.
// clang-format off
if (
!this->oldUsernameFragment.empty() &&
!this->oldPassword.empty() &&
packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK
)
// clang-format on
{
MS_DEBUG_TAG(ice, "using old ICE credentials");
break;
}
MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401");
// Reply 401.
RTC::StunPacket* response = packet->CreateErrorResponse(401);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
case RTC::StunPacket::Authentication::BAD_REQUEST:
{
MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400");
// Reply 400.
RTC::StunPacket* response = packet->CreateErrorResponse(400);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
return;
}
}
delete response;
return;
}
}
#if 0
// The remote peer must be ICE controlling.
if (packet->GetIceControlled())
{
MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487");
// The remote peer must be ICE controlling.
if (packet->GetIceControlled())
{
MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487");
// Reply 487 (Role Conflict).
RTC::StunPacket* response = packet->CreateErrorResponse(487);
// Reply 487 (Role Conflict).
RTC::StunPacket* response = packet->CreateErrorResponse(487);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
delete response;
return;
}
return;
}
#endif
//MS_DEBUG_DEV(
// "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]",
// static_cast<uint32_t>(packet->GetPriority()),
// packet->HasUseCandidate() ? "true" : "false");
//MS_DEBUG_DEV(
// "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]",
// static_cast<uint32_t>(packet->GetPriority()),
// packet->HasUseCandidate() ? "true" : "false");
// Create a success response.
RTC::StunPacket* response = packet->CreateSuccessResponse();
// Create a success response.
RTC::StunPacket* response = packet->CreateSuccessResponse();
sockaddr_storage peerAddr;
socklen_t addr_len = sizeof(peerAddr);
getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len);
sockaddr_storage peerAddr;
socklen_t addr_len = sizeof(peerAddr);
getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len);
// Add XOR-MAPPED-ADDRESS.
response->SetXorMappedAddress((struct sockaddr *)&peerAddr);
// Add XOR-MAPPED-ADDRESS.
response->SetXorMappedAddress((struct sockaddr *)&peerAddr);
// Authenticate the response.
if (this->oldPassword.empty())
response->Authenticate(this->password);
else
response->Authenticate(this->oldPassword);
// Authenticate the response.
if (this->oldPassword.empty())
response->Authenticate(this->password);
else
response->Authenticate(this->oldPassword);
// Send back.
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
// Send back.
response->Serialize(StunSerializeBuffer);
this->listener->OnIceServerSendStunPacket(this, response, tuple);
delete response;
delete response;
// Handle the tuple.
HandleTuple(tuple, packet->HasUseCandidate());
// Handle the tuple.
HandleTuple(tuple, packet->HasUseCandidate());
break;
}
break;
}
case RTC::StunPacket::Class::INDICATION:
{
MS_DEBUG_TAG(ice, "STUN Binding Indication processed");
case RTC::StunPacket::Class::INDICATION:
{
MS_DEBUG_TAG(ice, "STUN Binding Indication processed");
break;
}
break;
}
case RTC::StunPacket::Class::SUCCESS_RESPONSE:
{
MS_DEBUG_TAG(ice, "STUN Binding Success Response processed");
case RTC::StunPacket::Class::SUCCESS_RESPONSE:
{
MS_DEBUG_TAG(ice, "STUN Binding Success Response processed");
break;
}
break;
}
case RTC::StunPacket::Class::ERROR_RESPONSE:
{
MS_DEBUG_TAG(ice, "STUN Binding Error Response processed");
case RTC::StunPacket::Class::ERROR_RESPONSE:
{
MS_DEBUG_TAG(ice, "STUN Binding Error Response processed");
break;
}
}
}
break;
}
}
}
bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const
{
MS_TRACE();
return HasTuple(tuple) != nullptr;
}
void IceServer::RemoveTuple(RTC::TransportTuple* tuple)
{
MS_TRACE();
RTC::TransportTuple* removedTuple{ nullptr };
// Find the removed tuple.
auto it = this->tuples.begin();
for (; it != this->tuples.end(); ++it)
{
RTC::TransportTuple* storedTuple = *it;
if (storedTuple == tuple)
{
removedTuple = storedTuple;
break;
}
}
// If not found, ignore.
if (!removedTuple)
return;
// Remove from the list of tuples.
this->tuples.erase(it);
// If this is not the selected tuple, stop here.
if (removedTuple != this->selectedTuple)
return;
// Otherwise this was the selected tuple.
this->selectedTuple = nullptr;
// Mark the first tuple as selected tuple (if any).
if (!this->tuples.empty())
{
SetSelectedTuple(this->tuples.front());
}
// Or just emit 'disconnected'.
else
{
// Update state.
this->state = IceState::DISCONNECTED;
// Notify the listener.
this->listener->OnIceServerDisconnected(this);
}
}
void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple)
{
MS_TRACE();
MS_ASSERT(
this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple");
auto* storedTuple = HasTuple(tuple);
MS_ASSERT(
storedTuple,
"cannot force the selected tuple if the given tuple was not already a valid tuple");
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
}
void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate)
{
MS_TRACE();
switch (this->state)
{
case IceState::NEW:
{
// There should be no tuples.
MS_ASSERT(
this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size());
// There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple");
if (!hasUseCandidate)
{
MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::CONNECTED;
// Notify the listener.
this->listener->OnIceServerConnected(this);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::DISCONNECTED:
{
// There should be no tuples.
MS_ASSERT(
this->tuples.empty(),
"state is 'disconnected' but there are %zu tuples",
this->tuples.size());
// There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple");
if (!hasUseCandidate)
{
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::CONNECTED;
// Notify the listener.
this->listener->OnIceServerConnected(this);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::CONNECTED:
{
// There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples");
// There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple");
if (!hasUseCandidate)
{
// If a new tuple store it.
if (!HasTuple(tuple))
AddTuple(tuple);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'");
auto* storedTuple = HasTuple(tuple);
// If a new tuple store it.
if (!storedTuple)
storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::COMPLETED:
{
// There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples");
// There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple");
if (!hasUseCandidate)
{
// If a new tuple store it.
if (!HasTuple(tuple))
AddTuple(tuple);
}
else
{
auto* storedTuple = HasTuple(tuple);
// If a new tuple store it.
if (!storedTuple)
storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
}
break;
}
}
}
inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple)
{
MS_TRACE();
// Add the new tuple at the beginning of the list.
this->tuples.push_front(tuple);
// Return the address of the inserted tuple.
return tuple;
}
inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const
{
MS_TRACE();
// If there is no selected tuple yet then we know that the tuples list
// is empty.
if (!this->selectedTuple)
return nullptr;
// Check the current selected tuple.
if (selectedTuple == tuple)
return this->selectedTuple;
// Otherwise check other stored tuples.
for (const auto& it : this->tuples)
{
auto& storedTuple = it;
if (storedTuple == tuple)
return storedTuple;
}
return nullptr;
}
inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple)
{
MS_TRACE();
// If already the selected tuple do nothing.
if (storedTuple == this->selectedTuple)
return;
this->selectedTuple = storedTuple;
bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const
{
MS_TRACE();
return HasTuple(tuple) != nullptr;
}
void IceServer::RemoveTuple(RTC::TransportTuple* tuple)
{
MS_TRACE();
RTC::TransportTuple* removedTuple{ nullptr };
// Find the removed tuple.
auto it = this->tuples.begin();
for (; it != this->tuples.end(); ++it)
{
RTC::TransportTuple* storedTuple = *it;
if (storedTuple == tuple)
{
removedTuple = storedTuple;
break;
}
}
// If not found, ignore.
if (!removedTuple)
return;
// Remove from the list of tuples.
this->tuples.erase(it);
// If this is not the selected tuple, stop here.
if (removedTuple != this->selectedTuple)
return;
// Otherwise this was the selected tuple.
this->selectedTuple = nullptr;
// Mark the first tuple as selected tuple (if any).
if (!this->tuples.empty())
{
SetSelectedTuple(this->tuples.front());
}
// Or just emit 'disconnected'.
else
{
// Update state.
this->state = IceState::DISCONNECTED;
// Notify the listener.
this->listener->OnIceServerDisconnected(this);
}
}
void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple)
{
MS_TRACE();
MS_ASSERT(
this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple");
auto* storedTuple = HasTuple(tuple);
MS_ASSERT(
storedTuple,
"cannot force the selected tuple if the given tuple was not already a valid tuple");
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
}
void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate)
{
MS_TRACE();
switch (this->state)
{
case IceState::NEW:
{
// There should be no tuples.
MS_ASSERT(
this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size());
// There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple");
if (!hasUseCandidate)
{
MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::CONNECTED;
// Notify the listener.
this->listener->OnIceServerConnected(this);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::DISCONNECTED:
{
// There should be no tuples.
MS_ASSERT(
this->tuples.empty(),
"state is 'disconnected' but there are %zu tuples",
this->tuples.size());
// There shouldn't be a selected tuple.
MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple");
if (!hasUseCandidate)
{
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::CONNECTED;
// Notify the listener.
this->listener->OnIceServerConnected(this);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'");
// Store the tuple.
auto* storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::CONNECTED:
{
// There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples");
// There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple");
if (!hasUseCandidate)
{
// If a new tuple store it.
if (!HasTuple(tuple))
AddTuple(tuple);
}
else
{
MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'");
auto* storedTuple = HasTuple(tuple);
// If a new tuple store it.
if (!storedTuple)
storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
// Update state.
this->state = IceState::COMPLETED;
// Notify the listener.
this->listener->OnIceServerCompleted(this);
}
break;
}
case IceState::COMPLETED:
{
// There should be some tuples.
MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples");
// There should be a selected tuple.
MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple");
if (!hasUseCandidate)
{
// If a new tuple store it.
if (!HasTuple(tuple))
AddTuple(tuple);
}
else
{
auto* storedTuple = HasTuple(tuple);
// If a new tuple store it.
if (!storedTuple)
storedTuple = AddTuple(tuple);
// Mark it as selected tuple.
SetSelectedTuple(storedTuple);
}
break;
}
}
}
inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple)
{
MS_TRACE();
// Add the new tuple at the beginning of the list.
this->tuples.push_front(tuple);
// Return the address of the inserted tuple.
return tuple;
}
inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const
{
MS_TRACE();
// If there is no selected tuple yet then we know that the tuples list
// is empty.
if (!this->selectedTuple)
return nullptr;
// Check the current selected tuple.
if (selectedTuple == tuple)
return this->selectedTuple;
// Otherwise check other stored tuples.
for (const auto& it : this->tuples)
{
auto& storedTuple = it;
if (storedTuple == tuple)
return storedTuple;
}
return nullptr;
}
inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple)
{
MS_TRACE();
// If already the selected tuple do nothing.
if (storedTuple == this->selectedTuple)
return;
this->selectedTuple = storedTuple;
this->lastSelectedTuple = storedTuple->shared_from_this();
// Notify the listener.
this->listener->OnIceServerSelectedTuple(this, this->selectedTuple);
}
// Notify the listener.
this->listener->OnIceServerSelectedTuple(this, this->selectedTuple);
}
} // namespace RTC
......@@ -30,109 +30,109 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
namespace RTC
{
using TransportTuple = toolkit::Session;
class IceServer
{
public:
enum class IceState
{
NEW = 1,
CONNECTED,
COMPLETED,
DISCONNECTED
};
using TransportTuple = toolkit::Session;
class IceServer
{
public:
enum class IceState
{
NEW = 1,
CONNECTED,
COMPLETED,
DISCONNECTED
};
public:
class Listener
{
public:
virtual ~Listener() = default;
public:
class Listener
{
public:
virtual ~Listener() = default;
public:
/**
* These callbacks are guaranteed to be called before ProcessStunPacket()
* returns, so the given pointers are still usable.
*/
virtual void OnIceServerSendStunPacket(
const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerSelectedTuple(
const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0;
};
public:
/**
* These callbacks are guaranteed to be called before ProcessStunPacket()
* returns, so the given pointers are still usable.
*/
virtual void OnIceServerSendStunPacket(
const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerSelectedTuple(
const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0;
virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0;
virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0;
};
public:
IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password);
public:
IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password);
public:
void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple);
const std::string& GetUsernameFragment() const
{
return this->usernameFragment;
}
const std::string& GetPassword() const
{
return this->password;
}
IceState GetState() const
{
return this->state;
}
RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const
{
public:
void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple);
const std::string& GetUsernameFragment() const
{
return this->usernameFragment;
}
const std::string& GetPassword() const
{
return this->password;
}
IceState GetState() const
{
return this->state;
}
RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const
{
return try_last_tuple ? this->lastSelectedTuple.lock().get() : this->selectedTuple;
}
void SetUsernameFragment(const std::string& usernameFragment)
{
this->oldUsernameFragment = this->usernameFragment;
this->usernameFragment = usernameFragment;
}
void SetPassword(const std::string& password)
{
this->oldPassword = this->password;
this->password = password;
}
bool IsValidTuple(const RTC::TransportTuple* tuple) const;
void RemoveTuple(RTC::TransportTuple* tuple);
// This should be just called in 'connected' or completed' state
// and the given tuple must be an already valid tuple.
void ForceSelectedTuple(const RTC::TransportTuple* tuple);
void SetUsernameFragment(const std::string& usernameFragment)
{
this->oldUsernameFragment = this->usernameFragment;
this->usernameFragment = usernameFragment;
}
void SetPassword(const std::string& password)
{
this->oldPassword = this->password;
this->password = password;
}
bool IsValidTuple(const RTC::TransportTuple* tuple) const;
void RemoveTuple(RTC::TransportTuple* tuple);
// This should be just called in 'connected' or completed' state
// and the given tuple must be an already valid tuple.
void ForceSelectedTuple(const RTC::TransportTuple* tuple);
const std::list<RTC::TransportTuple *>& GetTuples() const { return tuples; }
private:
void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate);
/**
* Store the given tuple and return its stored address.
*/
RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple);
/**
* If the given tuple exists return its stored address, nullptr otherwise.
*/
RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const;
/**
* Set the given tuple as the selected tuple.
* NOTE: The given tuple MUST be already stored within the list.
*/
void SetSelectedTuple(RTC::TransportTuple* storedTuple);
void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate);
/**
* Store the given tuple and return its stored address.
*/
RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple);
/**
* If the given tuple exists return its stored address, nullptr otherwise.
*/
RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const;
/**
* Set the given tuple as the selected tuple.
* NOTE: The given tuple MUST be already stored within the list.
*/
void SetSelectedTuple(RTC::TransportTuple* storedTuple);
private:
// Passed by argument.
Listener* listener{ nullptr };
// Others.
std::string usernameFragment;
std::string password;
std::string oldUsernameFragment;
std::string oldPassword;
IceState state{ IceState::NEW };
std::list<RTC::TransportTuple *> tuples;
private:
// Passed by argument.
Listener* listener{ nullptr };
// Others.
std::string usernameFragment;
std::string password;
std::string oldUsernameFragment;
std::string oldPassword;
IceState state{ IceState::NEW };
std::list<RTC::TransportTuple *> tuples;
RTC::TransportTuple *selectedTuple;
std::weak_ptr<RTC::TransportTuple> lastSelectedTuple;
//最大不超过mtu
//最大不超过mtu
static constexpr size_t StunSerializeBufferSize{ 1600 };
uint8_t StunSerializeBuffer[StunSerializeBufferSize];
};
};
} // namespace RTC
#endif
......@@ -23,14 +23,14 @@ static constexpr uint16_t MaxSctpStreams{ 65535 };
/* clang-format off */
static constexpr uint16_t EventTypes[] =
{
SCTP_ADAPTATION_INDICATION,
SCTP_ASSOC_CHANGE,
SCTP_ASSOC_RESET_EVENT,
SCTP_REMOTE_ERROR,
SCTP_SHUTDOWN_EVENT,
SCTP_SEND_FAILED_EVENT,
SCTP_STREAM_RESET_EVENT,
SCTP_STREAM_CHANGE_EVENT
SCTP_ADAPTATION_INDICATION,
SCTP_ASSOC_CHANGE,
SCTP_ASSOC_RESET_EVENT,
SCTP_REMOTE_ERROR,
SCTP_SHUTDOWN_EVENT,
SCTP_SEND_FAILED_EVENT,
SCTP_STREAM_RESET_EVENT,
SCTP_STREAM_CHANGE_EVENT
};
/* clang-format on */
......@@ -44,45 +44,45 @@ inline static int onRecvSctpData(
int flags,
void* ulpInfo)
{
auto* sctpAssociation = static_cast<RTC::SctpAssociation*>(ulpInfo);
if (sctpAssociation == nullptr)
{
std::free(data);
return 0;
}
if (flags & MSG_NOTIFICATION)
{
sctpAssociation->OnUsrSctpReceiveSctpNotification(
static_cast<union sctp_notification*>(data), len);
}
else
{
uint16_t streamId = rcv.rcv_sid;
uint32_t ppid = ntohl(rcv.rcv_ppid);
uint16_t ssn = rcv.rcv_ssn;
MS_DEBUG_TAG(
sctp,
"data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32
", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]",
len,
rcv.rcv_sid,
rcv.rcv_ssn,
rcv.rcv_tsn,
ntohl(rcv.rcv_ppid),
rcv.rcv_context,
flags);
sctpAssociation->OnUsrSctpReceiveSctpData(
streamId, ssn, ppid, flags, static_cast<uint8_t*>(data), len);
}
std::free(data);
return 1;
auto* sctpAssociation = static_cast<RTC::SctpAssociation*>(ulpInfo);
if (sctpAssociation == nullptr)
{
std::free(data);
return 0;
}
if (flags & MSG_NOTIFICATION)
{
sctpAssociation->OnUsrSctpReceiveSctpNotification(
static_cast<union sctp_notification*>(data), len);
}
else
{
uint16_t streamId = rcv.rcv_sid;
uint32_t ppid = ntohl(rcv.rcv_ppid);
uint16_t ssn = rcv.rcv_ssn;
MS_DEBUG_TAG(
sctp,
"data chunk received [length:%zu, streamId:%" PRIu16 ", SSN:%" PRIu16 ", TSN:%" PRIu32
", PPID:%" PRIu32 ", context:%" PRIu32 ", flags:%d]",
len,
rcv.rcv_sid,
rcv.rcv_ssn,
rcv.rcv_tsn,
ntohl(rcv.rcv_ppid),
rcv.rcv_context,
flags);
sctpAssociation->OnUsrSctpReceiveSctpData(
streamId, ssn, ppid, flags, static_cast<uint8_t*>(data), len);
}
std::free(data);
return 1;
}
/* Static methods for usrsctp global callbacks. */
......@@ -136,824 +136,824 @@ namespace RTC
////////////////////////////////////////////////////////////////////////////////////
/* Instance methods. */
/* Instance methods. */
SctpAssociation::SctpAssociation(
Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel)
: listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize),
isDataChannel(isDataChannel)
{
MS_TRACE();
SctpAssociation::SctpAssociation(
Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel)
: listener(listener), os(os), mis(mis), maxSctpMessageSize(maxSctpMessageSize),
isDataChannel(isDataChannel)
{
MS_TRACE();
_env = SctpEnv::Instance().shared_from_this();
// Register ourselves in usrsctp.
usrsctp_register_address(static_cast<void*>(this));
// Register ourselves in usrsctp.
usrsctp_register_address(static_cast<void*>(this));
int ret;
int ret;
this->socket = usrsctp_socket(
AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast<void*>(this));
this->socket = usrsctp_socket(
AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast<void*>(this));
if (this->socket == nullptr)
MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno));
if (this->socket == nullptr)
MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno));
usrsctp_set_ulpinfo(this->socket, static_cast<void*>(this));
usrsctp_set_ulpinfo(this->socket, static_cast<void*>(this));
// Make the socket non-blocking.
ret = usrsctp_set_non_blocking(this->socket, 1);
// Make the socket non-blocking.
ret = usrsctp_set_non_blocking(this->socket, 1);
if (ret < 0)
MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno));
if (ret < 0)
MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno));
// Set SO_LINGER.
// This ensures that the usrsctp close call deletes the association. This
// prevents usrsctp from calling the global send callback with references to
// this class as the address.
struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Set SO_LINGER.
// This ensures that the usrsctp close call deletes the association. This
// prevents usrsctp from calling the global send callback with references to
// this class as the address.
struct linger lingerOpt; // NOLINT(cppcoreguidelines-pro-type-member-init)
lingerOpt.l_onoff = 1;
lingerOpt.l_linger = 0;
lingerOpt.l_onoff = 1;
lingerOpt.l_linger = 0;
ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt));
ret = usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_LINGER, &lingerOpt, sizeof(lingerOpt));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno));
// Set SCTP_ENABLE_STREAM_RESET.
struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Set SCTP_ENABLE_STREAM_RESET.
struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init)
av.assoc_value =
SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ;
av.assoc_value =
SCTP_ENABLE_RESET_STREAM_REQ | SCTP_ENABLE_RESET_ASSOC_REQ | SCTP_ENABLE_CHANGE_ASSOC_REQ;
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av));
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av));
if (ret < 0)
{
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno));
}
if (ret < 0)
{
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno));
}
// Set SCTP_NODELAY.
uint32_t noDelay = 1;
// Set SCTP_NODELAY.
uint32_t noDelay = 1;
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay));
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_NODELAY, &noDelay, sizeof(noDelay));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno));
// Enable events.
struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Enable events.
struct sctp_event event; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&event, 0, sizeof(event));
event.se_on = 1;
std::memset(&event, 0, sizeof(event));
event.se_on = 1;
for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i)
{
event.se_type = EventTypes[i];
for (size_t i{ 0 }; i < sizeof(EventTypes) / sizeof(uint16_t); ++i)
{
event.se_type = EventTypes[i];
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event));
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_EVENT, &event, sizeof(event));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno));
}
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno));
}
// Init message.
struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Init message.
struct sctp_initmsg initmsg; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&initmsg, 0, sizeof(initmsg));
initmsg.sinit_num_ostreams = this->os;
initmsg.sinit_max_instreams = this->mis;
std::memset(&initmsg, 0, sizeof(initmsg));
initmsg.sinit_num_ostreams = this->os;
initmsg.sinit_max_instreams = this->mis;
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg));
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_INITMSG, &initmsg, sizeof(initmsg));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno));
// Server side.
struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Server side.
struct sockaddr_conn sconn; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&sconn, 0, sizeof(sconn));
sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(5000);
sconn.sconn_addr = static_cast<void*>(this);
std::memset(&sconn, 0, sizeof(sconn));
sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(5000);
sconn.sconn_addr = static_cast<void*>(this);
#ifdef HAVE_SCONN_LEN
sconn.sconn_len = sizeof(sconn);
#endif
ret = usrsctp_bind(this->socket, reinterpret_cast<struct sockaddr*>(&sconn), sizeof(sconn));
ret = usrsctp_bind(this->socket, reinterpret_cast<struct sockaddr*>(&sconn), sizeof(sconn));
if (ret < 0)
MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno));
}
if (ret < 0)
MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno));
}
SctpAssociation::~SctpAssociation()
{
MS_TRACE();
SctpAssociation::~SctpAssociation()
{
MS_TRACE();
usrsctp_set_ulpinfo(this->socket, nullptr);
usrsctp_close(this->socket);
usrsctp_set_ulpinfo(this->socket, nullptr);
usrsctp_close(this->socket);
// Deregister ourselves from usrsctp.
usrsctp_deregister_address(static_cast<void*>(this));
// Deregister ourselves from usrsctp.
usrsctp_deregister_address(static_cast<void*>(this));
delete[] this->messageBuffer;
}
delete[] this->messageBuffer;
}
void SctpAssociation::TransportConnected()
{
MS_TRACE();
void SctpAssociation::TransportConnected()
{
MS_TRACE();
// Just run the SCTP stack if our state is 'new'.
if (this->state != SctpState::NEW)
return;
// Just run the SCTP stack if our state is 'new'.
if (this->state != SctpState::NEW)
return;
try
{
int ret;
struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init)
try
{
int ret;
struct sockaddr_conn rconn; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&rconn, 0, sizeof(rconn));
rconn.sconn_family = AF_CONN;
rconn.sconn_port = htons(5000);
rconn.sconn_addr = static_cast<void*>(this);
std::memset(&rconn, 0, sizeof(rconn));
rconn.sconn_family = AF_CONN;
rconn.sconn_port = htons(5000);
rconn.sconn_addr = static_cast<void*>(this);
#ifdef HAVE_SCONN_LEN
rconn.sconn_len = sizeof(rconn);
rconn.sconn_len = sizeof(rconn);
#endif
ret = usrsctp_connect(this->socket, reinterpret_cast<struct sockaddr*>(&rconn), sizeof(rconn));
ret = usrsctp_connect(this->socket, reinterpret_cast<struct sockaddr*>(&rconn), sizeof(rconn));
if (ret < 0 && errno != EINPROGRESS)
MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno));
if (ret < 0 && errno != EINPROGRESS)
MS_THROW_ERROR("usrsctp_connect() failed: %s", std::strerror(errno));
// Disable MTU discovery.
sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Disable MTU discovery.
sctp_paddrparams peerAddrParams; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&peerAddrParams, 0, sizeof(peerAddrParams));
std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn));
peerAddrParams.spp_flags = SPP_PMTUD_DISABLE;
std::memset(&peerAddrParams, 0, sizeof(peerAddrParams));
std::memcpy(&peerAddrParams.spp_address, &rconn, sizeof(rconn));
peerAddrParams.spp_flags = SPP_PMTUD_DISABLE;
// The MTU value provided specifies the space available for chunks in the
// packet, so let's subtract the SCTP header size.
peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams);
// The MTU value provided specifies the space available for chunks in the
// packet, so let's subtract the SCTP header size.
peerAddrParams.spp_pathmtu = SctpMtu - sizeof(peerAddrParams);
ret = usrsctp_setsockopt(
this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams));
ret = usrsctp_setsockopt(
this->socket, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &peerAddrParams, sizeof(peerAddrParams));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno));
if (ret < 0)
MS_THROW_ERROR("usrsctp_setsockopt(SCTP_PEER_ADDR_PARAMS) failed: %s", std::strerror(errno));
// Announce connecting state.
this->state = SctpState::CONNECTING;
this->listener->OnSctpAssociationConnecting(this);
}
catch (... /*error*/)
{
this->state = SctpState::FAILED;
this->listener->OnSctpAssociationFailed(this);
// Announce connecting state.
this->state = SctpState::CONNECTING;
this->listener->OnSctpAssociationConnecting(this);
}
catch (... /*error*/)
{
this->state = SctpState::FAILED;
this->listener->OnSctpAssociationFailed(this);
throw;
}
}
}
}
void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len)
{
MS_TRACE();
void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len)
{
MS_TRACE();
#if MS_LOG_DEV_LEVEL == 3
MS_DUMP_DATA(data, len);
MS_DUMP_DATA(data, len);
#endif
usrsctp_conninput(static_cast<void*>(this), data, len, 0);
}
usrsctp_conninput(static_cast<void*>(this), data, len, 0);
}
void SctpAssociation::SendSctpMessage(
void SctpAssociation::SendSctpMessage(
const RTC::SctpStreamParameters &parameters, uint32_t ppid, const uint8_t* msg, size_t len)
{
MS_TRACE();
// This must be controlled by the DataConsumer.
MS_ASSERT(
len <= this->maxSctpMessageSize,
"given message exceeds max allowed message size [message size:%zu, max message size:%zu]",
len,
this->maxSctpMessageSize);
// Fill stcp_sendv_spa.
struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&spa, 0, sizeof(spa));
spa.sendv_flags = SCTP_SEND_SNDINFO_VALID;
spa.sendv_sndinfo.snd_sid = parameters.streamId;
spa.sendv_sndinfo.snd_ppid = htonl(ppid);
spa.sendv_sndinfo.snd_flags = SCTP_EOR;
// If ordered it must be reliable.
if (parameters.ordered)
{
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE;
spa.sendv_prinfo.pr_value = 0;
}
// Configure reliability: https://tools.ietf.org/html/rfc3758
else
{
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
if (parameters.maxPacketLifeTime != 0)
{
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL;
spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime;
}
else if (parameters.maxRetransmits != 0)
{
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
spa.sendv_prinfo.pr_value = parameters.maxRetransmits;
}
}
int ret = usrsctp_sendv(
this->socket, msg, len, nullptr, 0, &spa, static_cast<socklen_t>(sizeof(spa)), SCTP_SENDV_SPA, 0);
if (ret < 0)
{
MS_WARN_TAG(
sctp,
"error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s",
parameters.streamId,
ppid,
len,
std::strerror(errno));
}
}
void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters &params)
{
MS_TRACE();
auto streamId = params.streamId;
// We need more OS.
if (streamId > this->os - 1)
AddOutgoingStreams(/*force*/ false);
}
void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters &params)
{
MS_TRACE();
auto streamId = params.streamId;
// Send SCTP_RESET_STREAMS to the remote.
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
if (this->isDataChannel)
ResetSctpStream(streamId, StreamDirection::OUTGOING);
else
ResetSctpStream(streamId, StreamDirection::INCOMING);
}
void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters &params)
{
MS_TRACE();
auto streamId = params.streamId;
// Send SCTP_RESET_STREAMS to the remote.
ResetSctpStream(streamId, StreamDirection::OUTGOING);
}
void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction)
{
MS_TRACE();
// Do nothing if an outgoing stream that could not be allocated by us.
if (direction == StreamDirection::OUTGOING && streamId > this->os - 1)
return;
int ret;
struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init)
socklen_t len = sizeof(av);
{
MS_TRACE();
// This must be controlled by the DataConsumer.
MS_ASSERT(
len <= this->maxSctpMessageSize,
"given message exceeds max allowed message size [message size:%zu, max message size:%zu]",
len,
this->maxSctpMessageSize);
// Fill stcp_sendv_spa.
struct sctp_sendv_spa spa; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&spa, 0, sizeof(spa));
spa.sendv_flags = SCTP_SEND_SNDINFO_VALID;
spa.sendv_sndinfo.snd_sid = parameters.streamId;
spa.sendv_sndinfo.snd_ppid = htonl(ppid);
spa.sendv_sndinfo.snd_flags = SCTP_EOR;
// If ordered it must be reliable.
if (parameters.ordered)
{
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE;
spa.sendv_prinfo.pr_value = 0;
}
// Configure reliability: https://tools.ietf.org/html/rfc3758
else
{
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED;
if (parameters.maxPacketLifeTime != 0)
{
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL;
spa.sendv_prinfo.pr_value = parameters.maxPacketLifeTime;
}
else if (parameters.maxRetransmits != 0)
{
spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX;
spa.sendv_prinfo.pr_value = parameters.maxRetransmits;
}
}
int ret = usrsctp_sendv(
this->socket, msg, len, nullptr, 0, &spa, static_cast<socklen_t>(sizeof(spa)), SCTP_SENDV_SPA, 0);
if (ret < 0)
{
MS_WARN_TAG(
sctp,
"error sending SCTP message [sid:%" PRIu16 ", ppid:%" PRIu32 ", message size:%zu]: %s",
parameters.streamId,
ppid,
len,
std::strerror(errno));
}
}
void SctpAssociation::HandleDataConsumer(const RTC::SctpStreamParameters &params)
{
MS_TRACE();
auto streamId = params.streamId;
// We need more OS.
if (streamId > this->os - 1)
AddOutgoingStreams(/*force*/ false);
}
void SctpAssociation::DataProducerClosed(const RTC::SctpStreamParameters &params)
{
MS_TRACE();
auto streamId = params.streamId;
// Send SCTP_RESET_STREAMS to the remote.
// https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
if (this->isDataChannel)
ResetSctpStream(streamId, StreamDirection::OUTGOING);
else
ResetSctpStream(streamId, StreamDirection::INCOMING);
}
void SctpAssociation::DataConsumerClosed(const RTC::SctpStreamParameters &params)
{
MS_TRACE();
auto streamId = params.streamId;
// Send SCTP_RESET_STREAMS to the remote.
ResetSctpStream(streamId, StreamDirection::OUTGOING);
}
void SctpAssociation::ResetSctpStream(uint16_t streamId, StreamDirection direction)
{
MS_TRACE();
// Do nothing if an outgoing stream that could not be allocated by us.
if (direction == StreamDirection::OUTGOING && streamId > this->os - 1)
return;
int ret;
struct sctp_assoc_value av; // NOLINT(cppcoreguidelines-pro-type-member-init)
socklen_t len = sizeof(av);
#ifndef SCTP_RECONFIG_SUPPORTED
#define SCTP_RECONFIG_SUPPORTED 0x00000029
#endif
ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len);
if (ret == 0)
{
if (av.assoc_value != 1)
{
MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated");
ret = usrsctp_getsockopt(this->socket, IPPROTO_SCTP, SCTP_RECONFIG_SUPPORTED, &av, &len);
return;
}
}
else
{
MS_WARN_TAG(
sctp,
"could not retrieve whether stream reconfiguration has been negotiated: %s\n",
std::strerror(errno));
if (ret == 0)
{
if (av.assoc_value != 1)
{
MS_DEBUG_TAG(sctp, "stream reconfiguration not negotiated");
return;
}
return;
}
}
else
{
MS_WARN_TAG(
sctp,
"could not retrieve whether stream reconfiguration has been negotiated: %s\n",
std::strerror(errno));
return;
}
// As per spec: https://tools.ietf.org/html/rfc6525#section-4.1
len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t);
// As per spec: https://tools.ietf.org/html/rfc6525#section-4.1
len = sizeof(sctp_assoc_t) + (2 + 1) * sizeof(uint16_t);
auto* srs = static_cast<struct sctp_reset_streams*>(std::malloc(len));
auto* srs = static_cast<struct sctp_reset_streams*>(std::malloc(len));
switch (direction)
{
case StreamDirection::INCOMING:
srs->srs_flags = SCTP_STREAM_RESET_INCOMING;
break;
switch (direction)
{
case StreamDirection::INCOMING:
srs->srs_flags = SCTP_STREAM_RESET_INCOMING;
break;
case StreamDirection::OUTGOING:
srs->srs_flags = SCTP_STREAM_RESET_OUTGOING;
break;
}
case StreamDirection::OUTGOING:
srs->srs_flags = SCTP_STREAM_RESET_OUTGOING;
break;
}
srs->srs_number_streams = 1;
srs->srs_stream_list[0] = streamId; // No need for htonl().
srs->srs_number_streams = 1;
srs->srs_stream_list[0] = streamId; // No need for htonl().
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len);
ret = usrsctp_setsockopt(this->socket, IPPROTO_SCTP, SCTP_RESET_STREAMS, srs, len);
if (ret == 0)
{
MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId);
}
else
{
MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno));
}
if (ret == 0)
{
MS_DEBUG_TAG(sctp, "SCTP_RESET_STREAMS sent [streamId:%" PRIu16 "]", streamId);
}
else
{
MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_RESET_STREAMS) failed: %s", std::strerror(errno));
}
std::free(srs);
}
std::free(srs);
}
void SctpAssociation::AddOutgoingStreams(bool force)
{
MS_TRACE();
void SctpAssociation::AddOutgoingStreams(bool force)
{
MS_TRACE();
uint16_t additionalOs{ 0 };
uint16_t additionalOs{ 0 };
if (MaxSctpStreams - this->os >= 32)
additionalOs = 32;
else
additionalOs = MaxSctpStreams - this->os;
if (MaxSctpStreams - this->os >= 32)
additionalOs = 32;
else
additionalOs = MaxSctpStreams - this->os;
if (additionalOs == 0)
{
MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os);
if (additionalOs == 0)
{
MS_WARN_TAG(sctp, "cannot add more outgoing streams [OS:%" PRIu16 "]", this->os);
return;
}
return;
}
auto nextDesiredOs = this->os + additionalOs;
auto nextDesiredOs = this->os + additionalOs;
// Already in progress, ignore (unless forced).
if (!force && nextDesiredOs == this->desiredOs)
return;
// Already in progress, ignore (unless forced).
if (!force && nextDesiredOs == this->desiredOs)
return;
// Update desired value.
this->desiredOs = nextDesiredOs;
// Update desired value.
this->desiredOs = nextDesiredOs;
// If not connected, defer it.
if (this->state != SctpState::CONNECTED)
{
MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase");
// If not connected, defer it.
if (this->state != SctpState::CONNECTED)
{
MS_DEBUG_TAG(sctp, "SCTP not connected, deferring OS increase");
return;
}
return;
}
struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init)
struct sctp_add_streams sas; // NOLINT(cppcoreguidelines-pro-type-member-init)
std::memset(&sas, 0, sizeof(sas));
sas.sas_instrms = 0;
sas.sas_outstrms = additionalOs;
std::memset(&sas, 0, sizeof(sas));
sas.sas_instrms = 0;
sas.sas_outstrms = additionalOs;
MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs);
MS_DEBUG_TAG(sctp, "adding %" PRIu16 " outgoing streams", additionalOs);
int ret = usrsctp_setsockopt(
this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast<socklen_t>(sizeof(sas)));
int ret = usrsctp_setsockopt(
this->socket, IPPROTO_SCTP, SCTP_ADD_STREAMS, &sas, static_cast<socklen_t>(sizeof(sas)));
if (ret < 0)
MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno));
}
if (ret < 0)
MS_WARN_TAG(sctp, "usrsctp_setsockopt(SCTP_ADD_STREAMS) failed: %s", std::strerror(errno));
}
void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len)
{
MS_TRACE();
void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len)
{
MS_TRACE();
const uint8_t* data = static_cast<uint8_t*>(buffer);
const uint8_t* data = static_cast<uint8_t*>(buffer);
#if MS_LOG_DEV_LEVEL == 3
MS_DUMP_DATA(data, len);
MS_DUMP_DATA(data, len);
#endif
this->listener->OnSctpAssociationSendData(this, data, len);
}
void SctpAssociation::OnUsrSctpReceiveSctpData(
uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len)
{
// Ignore WebRTC DataChannel Control DATA chunks.
if (ppid == 50)
{
MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)");
return;
}
if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived)
{
MS_WARN_TAG(
sctp,
"message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16
", last ssn received:%" PRIu16 "]",
ssn,
this->lastSsnReceived);
this->messageBufferLen = 0;
}
// Update last SSN received.
this->lastSsnReceived = ssn;
auto eor = static_cast<bool>(flags & MSG_EOR);
if (this->messageBufferLen + len > this->maxSctpMessageSize)
{
MS_WARN_TAG(
sctp,
"ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]",
this->messageBufferLen + len,
this->maxSctpMessageSize,
eor ? 1 : 0);
this->lastSsnReceived = 0;
return;
}
// If end of message and there is no buffered data, notify it directly.
if (eor && this->messageBufferLen == 0)
{
MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]");
this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len);
}
// If end of message and there is buffered data, append data and notify buffer.
else if (eor && this->messageBufferLen != 0)
{
std::memcpy(this->messageBuffer + this->messageBufferLen, data, len);
this->messageBufferLen += len;
MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen);
this->listener->OnSctpAssociationMessageReceived(
this, streamId, ppid, this->messageBuffer, this->messageBufferLen);
this->messageBufferLen = 0;
}
// If non end of message, append data to the buffer.
else if (!eor)
{
// Allocate the buffer if not already done.
if (!this->messageBuffer)
this->messageBuffer = new uint8_t[this->maxSctpMessageSize];
std::memcpy(this->messageBuffer + this->messageBufferLen, data, len);
this->messageBufferLen += len;
MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen);
}
}
void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len)
{
if (notification->sn_header.sn_length != (uint32_t)len)
return;
switch (notification->sn_header.sn_type)
{
case SCTP_ADAPTATION_INDICATION:
{
MS_DEBUG_TAG(
sctp,
"SCTP adaptation indication [%x]",
notification->sn_adaptation_event.sai_adaptation_ind);
break;
}
case SCTP_ASSOC_CHANGE:
{
switch (notification->sn_assoc_change.sac_state)
{
case SCTP_COMM_UP:
{
MS_DEBUG_TAG(
sctp,
"SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]",
notification->sn_assoc_change.sac_outbound_streams,
notification->sn_assoc_change.sac_inbound_streams);
// Update our OS.
this->os = notification->sn_assoc_change.sac_outbound_streams;
// Increase if requested before connected.
if (this->desiredOs > this->os)
AddOutgoingStreams(/*force*/ true);
if (this->state != SctpState::CONNECTED)
{
this->state = SctpState::CONNECTED;
this->listener->OnSctpAssociationConnected(this);
}
break;
}
case SCTP_COMM_LOST:
{
if (notification->sn_header.sn_length > 0)
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len = notification->sn_header.sn_length;
for (uint32_t i{ 0 }; i < len; ++i)
{
std::snprintf(
buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]);
}
MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer);
}
else
{
MS_DEBUG_TAG(sctp, "SCTP communication lost");
}
if (this->state != SctpState::CLOSED)
{
this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this);
}
break;
}
case SCTP_RESTART:
{
MS_DEBUG_TAG(
sctp,
"SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]",
notification->sn_assoc_change.sac_outbound_streams,
notification->sn_assoc_change.sac_inbound_streams);
// Update our OS.
this->os = notification->sn_assoc_change.sac_outbound_streams;
// Increase if requested before connected.
if (this->desiredOs > this->os)
AddOutgoingStreams(/*force*/ true);
if (this->state != SctpState::CONNECTED)
{
this->state = SctpState::CONNECTED;
this->listener->OnSctpAssociationConnected(this);
}
break;
}
case SCTP_SHUTDOWN_COMP:
{
MS_DEBUG_TAG(sctp, "SCTP association gracefully closed");
if (this->state != SctpState::CLOSED)
{
this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this);
}
break;
}
case SCTP_CANT_STR_ASSOC:
{
if (notification->sn_header.sn_length > 0)
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len = notification->sn_header.sn_length;
for (uint32_t i{ 0 }; i < len; ++i)
{
std::snprintf(
buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]);
}
MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer);
}
if (this->state != SctpState::FAILED)
{
this->state = SctpState::FAILED;
this->listener->OnSctpAssociationFailed(this);
}
break;
}
default:;
}
break;
}
// https://tools.ietf.org/html/rfc6525#section-6.1.2.
case SCTP_ASSOC_RESET_EVENT:
{
MS_DEBUG_TAG(sctp, "SCTP association reset event received");
break;
}
// An Operation Error is not considered fatal in and of itself, but may be
// used with an ABORT chunk to report a fatal condition.
case SCTP_REMOTE_ERROR:
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error);
for (uint32_t i{ 0 }; i < len; i++)
{
std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]);
}
MS_WARN_TAG(
sctp,
"remote SCTP association error [type:0x%04x, data:%s]",
notification->sn_remote_error.sre_error,
buffer);
break;
}
// When a peer sends a SHUTDOWN, SCTP delivers this notification to
// inform the application that it should cease sending data.
case SCTP_SHUTDOWN_EVENT:
{
MS_DEBUG_TAG(sctp, "remote SCTP association shutdown");
if (this->state != SctpState::CLOSED)
{
this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this);
}
break;
}
case SCTP_SEND_FAILED_EVENT:
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len =
notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event);
for (uint32_t i{ 0 }; i < len; ++i)
{
std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]);
}
MS_WARN_TAG(
sctp,
"SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32
", sent:%s, error:0x%08x, info:%s]",
notification->sn_send_failed_event.ssfe_info.snd_sid,
ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid),
(notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no",
notification->sn_send_failed_event.ssfe_error,
buffer);
break;
}
case SCTP_STREAM_RESET_EVENT:
{
bool incoming{ false };
bool outgoing{ false };
uint16_t numStreams =
(notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) /
sizeof(uint16_t);
if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN)
incoming = true;
if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN)
outgoing = true;
this->listener->OnSctpAssociationSendData(this, data, len);
}
void SctpAssociation::OnUsrSctpReceiveSctpData(
uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len)
{
// Ignore WebRTC DataChannel Control DATA chunks.
if (ppid == 50)
{
MS_WARN_TAG(sctp, "ignoring SCTP data with ppid:50 (WebRTC DataChannel Control)");
return;
}
if (this->messageBufferLen != 0 && ssn != this->lastSsnReceived)
{
MS_WARN_TAG(
sctp,
"message chunk received with different SSN while buffer not empty, buffer discarded [ssn:%" PRIu16
", last ssn received:%" PRIu16 "]",
ssn,
this->lastSsnReceived);
this->messageBufferLen = 0;
}
// Update last SSN received.
this->lastSsnReceived = ssn;
auto eor = static_cast<bool>(flags & MSG_EOR);
if (this->messageBufferLen + len > this->maxSctpMessageSize)
{
MS_WARN_TAG(
sctp,
"ongoing received message exceeds max allowed message size [message size:%zu, max message size:%zu, eor:%u]",
this->messageBufferLen + len,
this->maxSctpMessageSize,
eor ? 1 : 0);
this->lastSsnReceived = 0;
return;
}
// If end of message and there is no buffered data, notify it directly.
if (eor && this->messageBufferLen == 0)
{
MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]");
this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len);
}
// If end of message and there is buffered data, append data and notify buffer.
else if (eor && this->messageBufferLen != 0)
{
std::memcpy(this->messageBuffer + this->messageBufferLen, data, len);
this->messageBufferLen += len;
MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen);
this->listener->OnSctpAssociationMessageReceived(
this, streamId, ppid, this->messageBuffer, this->messageBufferLen);
this->messageBufferLen = 0;
}
// If non end of message, append data to the buffer.
else if (!eor)
{
// Allocate the buffer if not already done.
if (!this->messageBuffer)
this->messageBuffer = new uint8_t[this->maxSctpMessageSize];
std::memcpy(this->messageBuffer + this->messageBufferLen, data, len);
this->messageBufferLen += len;
MS_DEBUG_DEV("data buffered [eor:0, buffer len:%zu]", this->messageBufferLen);
}
}
void SctpAssociation::OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len)
{
if (notification->sn_header.sn_length != (uint32_t)len)
return;
switch (notification->sn_header.sn_type)
{
case SCTP_ADAPTATION_INDICATION:
{
MS_DEBUG_TAG(
sctp,
"SCTP adaptation indication [%x]",
notification->sn_adaptation_event.sai_adaptation_ind);
break;
}
case SCTP_ASSOC_CHANGE:
{
switch (notification->sn_assoc_change.sac_state)
{
case SCTP_COMM_UP:
{
MS_DEBUG_TAG(
sctp,
"SCTP association connected, streams [out:%" PRIu16 ", in:%" PRIu16 "]",
notification->sn_assoc_change.sac_outbound_streams,
notification->sn_assoc_change.sac_inbound_streams);
// Update our OS.
this->os = notification->sn_assoc_change.sac_outbound_streams;
// Increase if requested before connected.
if (this->desiredOs > this->os)
AddOutgoingStreams(/*force*/ true);
if (this->state != SctpState::CONNECTED)
{
this->state = SctpState::CONNECTED;
this->listener->OnSctpAssociationConnected(this);
}
break;
}
case SCTP_COMM_LOST:
{
if (notification->sn_header.sn_length > 0)
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len = notification->sn_header.sn_length;
for (uint32_t i{ 0 }; i < len; ++i)
{
std::snprintf(
buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]);
}
MS_DEBUG_TAG(sctp, "SCTP communication lost [info:%s]", buffer);
}
else
{
MS_DEBUG_TAG(sctp, "SCTP communication lost");
}
if (this->state != SctpState::CLOSED)
{
this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this);
}
break;
}
case SCTP_RESTART:
{
MS_DEBUG_TAG(
sctp,
"SCTP remote association restarted, streams [out:%" PRIu16 ", int:%" PRIu16 "]",
notification->sn_assoc_change.sac_outbound_streams,
notification->sn_assoc_change.sac_inbound_streams);
// Update our OS.
this->os = notification->sn_assoc_change.sac_outbound_streams;
// Increase if requested before connected.
if (this->desiredOs > this->os)
AddOutgoingStreams(/*force*/ true);
if (this->state != SctpState::CONNECTED)
{
this->state = SctpState::CONNECTED;
this->listener->OnSctpAssociationConnected(this);
}
break;
}
case SCTP_SHUTDOWN_COMP:
{
MS_DEBUG_TAG(sctp, "SCTP association gracefully closed");
if (this->state != SctpState::CLOSED)
{
this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this);
}
break;
}
case SCTP_CANT_STR_ASSOC:
{
if (notification->sn_header.sn_length > 0)
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len = notification->sn_header.sn_length;
for (uint32_t i{ 0 }; i < len; ++i)
{
std::snprintf(
buffer, BufferSize, " 0x%02x", notification->sn_assoc_change.sac_info[i]);
}
MS_WARN_TAG(sctp, "SCTP setup failed: %s", buffer);
}
if (this->state != SctpState::FAILED)
{
this->state = SctpState::FAILED;
this->listener->OnSctpAssociationFailed(this);
}
break;
}
default:;
}
break;
}
// https://tools.ietf.org/html/rfc6525#section-6.1.2.
case SCTP_ASSOC_RESET_EVENT:
{
MS_DEBUG_TAG(sctp, "SCTP association reset event received");
break;
}
// An Operation Error is not considered fatal in and of itself, but may be
// used with an ABORT chunk to report a fatal condition.
case SCTP_REMOTE_ERROR:
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len = notification->sn_remote_error.sre_length - sizeof(struct sctp_remote_error);
for (uint32_t i{ 0 }; i < len; i++)
{
std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_remote_error.sre_data[i]);
}
MS_WARN_TAG(
sctp,
"remote SCTP association error [type:0x%04x, data:%s]",
notification->sn_remote_error.sre_error,
buffer);
break;
}
// When a peer sends a SHUTDOWN, SCTP delivers this notification to
// inform the application that it should cease sending data.
case SCTP_SHUTDOWN_EVENT:
{
MS_DEBUG_TAG(sctp, "remote SCTP association shutdown");
if (this->state != SctpState::CLOSED)
{
this->state = SctpState::CLOSED;
this->listener->OnSctpAssociationClosed(this);
}
break;
}
case SCTP_SEND_FAILED_EVENT:
{
static const size_t BufferSize{ 1024 };
static char buffer[BufferSize];
uint32_t len =
notification->sn_send_failed_event.ssfe_length - sizeof(struct sctp_send_failed_event);
for (uint32_t i{ 0 }; i < len; ++i)
{
std::snprintf(buffer, BufferSize, "0x%02x", notification->sn_send_failed_event.ssfe_data[i]);
}
MS_WARN_TAG(
sctp,
"SCTP message sent failure [streamId:%" PRIu16 ", ppid:%" PRIu32
", sent:%s, error:0x%08x, info:%s]",
notification->sn_send_failed_event.ssfe_info.snd_sid,
ntohl(notification->sn_send_failed_event.ssfe_info.snd_ppid),
(notification->sn_send_failed_event.ssfe_flags & SCTP_DATA_SENT) ? "yes" : "no",
notification->sn_send_failed_event.ssfe_error,
buffer);
break;
}
case SCTP_STREAM_RESET_EVENT:
{
bool incoming{ false };
bool outgoing{ false };
uint16_t numStreams =
(notification->sn_strreset_event.strreset_length - sizeof(struct sctp_stream_reset_event)) /
sizeof(uint16_t);
if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN)
incoming = true;
if (notification->sn_strreset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN)
outgoing = true;
//todo 打印sctp调试信息
if (false /*MS_HAS_DEBUG_TAG(sctp)*/)
{
std::string streamIds;
for (uint16_t i{ 0 }; i < numStreams; ++i)
{
auto streamId = notification->sn_strreset_event.strreset_stream_list[i];
// Don't log more than 5 stream ids.
if (i > 4)
{
streamIds.append("...");
break;
}
if (i > 0)
streamIds.append(",");
streamIds.append(std::to_string(streamId));
}
MS_DEBUG_TAG(
sctp,
"SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]",
notification->sn_strreset_event.strreset_flags,
incoming ? "true" : "false",
outgoing ? "true" : "false",
numStreams,
streamIds.c_str());
}
// Special case for WebRTC DataChannels in which we must also reset our
// outgoing SCTP stream.
if (incoming && !outgoing && this->isDataChannel)
{
for (uint16_t i{ 0 }; i < numStreams; ++i)
{
auto streamId = notification->sn_strreset_event.strreset_stream_list[i];
ResetSctpStream(streamId, StreamDirection::OUTGOING);
}
}
break;
}
case SCTP_STREAM_CHANGE_EVENT:
{
if (notification->sn_strchange_event.strchange_flags == 0)
{
MS_DEBUG_TAG(
sctp,
"SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags);
}
else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED)
{
MS_WARN_TAG(
sctp,
"SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags);
break;
}
else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED)
{
MS_WARN_TAG(
sctp,
"SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags);
break;
}
// Update OS.
this->os = notification->sn_strchange_event.strchange_outstrms;
break;
}
default:
{
MS_WARN_TAG(
sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type);
}
}
}
if (false /*MS_HAS_DEBUG_TAG(sctp)*/)
{
std::string streamIds;
for (uint16_t i{ 0 }; i < numStreams; ++i)
{
auto streamId = notification->sn_strreset_event.strreset_stream_list[i];
// Don't log more than 5 stream ids.
if (i > 4)
{
streamIds.append("...");
break;
}
if (i > 0)
streamIds.append(",");
streamIds.append(std::to_string(streamId));
}
MS_DEBUG_TAG(
sctp,
"SCTP stream reset event [flags:%x, i|o:%s|%s, num streams:%" PRIu16 ", stream ids:%s]",
notification->sn_strreset_event.strreset_flags,
incoming ? "true" : "false",
outgoing ? "true" : "false",
numStreams,
streamIds.c_str());
}
// Special case for WebRTC DataChannels in which we must also reset our
// outgoing SCTP stream.
if (incoming && !outgoing && this->isDataChannel)
{
for (uint16_t i{ 0 }; i < numStreams; ++i)
{
auto streamId = notification->sn_strreset_event.strreset_stream_list[i];
ResetSctpStream(streamId, StreamDirection::OUTGOING);
}
}
break;
}
case SCTP_STREAM_CHANGE_EVENT:
{
if (notification->sn_strchange_event.strchange_flags == 0)
{
MS_DEBUG_TAG(
sctp,
"SCTP stream changed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags);
}
else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_DENIED)
{
MS_WARN_TAG(
sctp,
"SCTP stream change denied, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags);
break;
}
else if (notification->sn_strchange_event.strchange_flags & SCTP_STREAM_RESET_FAILED)
{
MS_WARN_TAG(
sctp,
"SCTP stream change failed, streams [out:%" PRIu16 ", in:%" PRIu16 ", flags:%x]",
notification->sn_strchange_event.strchange_outstrms,
notification->sn_strchange_event.strchange_instrms,
notification->sn_strchange_event.strchange_flags);
break;
}
// Update OS.
this->os = notification->sn_strchange_event.strchange_outstrms;
break;
}
default:
{
MS_WARN_TAG(
sctp, "unhandled SCTP event received [type:%" PRIu16 "]", notification->sn_header.sn_type);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -18,104 +18,104 @@ namespace RTC
uint16_t maxRetransmits{ 0u };
};
class SctpAssociation
{
public:
enum class SctpState
{
NEW = 1,
CONNECTING,
CONNECTED,
FAILED,
CLOSED
};
private:
enum class StreamDirection
{
INCOMING = 1,
OUTGOING
};
public:
class Listener
{
public:
virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationSendData(
RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0;
virtual void OnSctpAssociationMessageReceived(
RTC::SctpAssociation* sctpAssociation,
uint16_t streamId,
uint32_t ppid,
const uint8_t* msg,
size_t len) = 0;
};
public:
static bool IsSctp(const uint8_t* data, size_t len)
{
// clang-format off
return (
(len >= 12) &&
// Must have Source Port Number and Destination Port Number set to 5000 (hack).
(Utils::Byte::Get2Bytes(data, 0) == 5000) &&
(Utils::Byte::Get2Bytes(data, 2) == 5000)
);
// clang-format on
}
public:
SctpAssociation(
Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel);
virtual ~SctpAssociation();
public:
void TransportConnected();
size_t GetMaxSctpMessageSize() const
{
return this->maxSctpMessageSize;
}
SctpState GetState() const
{
return this->state;
}
void ProcessSctpData(const uint8_t* data, size_t len);
void SendSctpMessage(const RTC::SctpStreamParameters &params, uint32_t ppid, const uint8_t* msg, size_t len);
void HandleDataConsumer(const RTC::SctpStreamParameters &params);
void DataProducerClosed(const RTC::SctpStreamParameters &params);
void DataConsumerClosed(const RTC::SctpStreamParameters &params);
private:
void ResetSctpStream(uint16_t streamId, StreamDirection);
void AddOutgoingStreams(bool force = false);
public:
class SctpAssociation
{
public:
enum class SctpState
{
NEW = 1,
CONNECTING,
CONNECTED,
FAILED,
CLOSED
};
private:
enum class StreamDirection
{
INCOMING = 1,
OUTGOING
};
public:
class Listener
{
public:
virtual void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationFailed(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationClosed(RTC::SctpAssociation* sctpAssociation) = 0;
virtual void OnSctpAssociationSendData(
RTC::SctpAssociation* sctpAssociation, const uint8_t* data, size_t len) = 0;
virtual void OnSctpAssociationMessageReceived(
RTC::SctpAssociation* sctpAssociation,
uint16_t streamId,
uint32_t ppid,
const uint8_t* msg,
size_t len) = 0;
};
public:
static bool IsSctp(const uint8_t* data, size_t len)
{
// clang-format off
return (
(len >= 12) &&
// Must have Source Port Number and Destination Port Number set to 5000 (hack).
(Utils::Byte::Get2Bytes(data, 0) == 5000) &&
(Utils::Byte::Get2Bytes(data, 2) == 5000)
);
// clang-format on
}
public:
SctpAssociation(
Listener* listener, uint16_t os, uint16_t mis, size_t maxSctpMessageSize, bool isDataChannel);
virtual ~SctpAssociation();
public:
void TransportConnected();
size_t GetMaxSctpMessageSize() const
{
return this->maxSctpMessageSize;
}
SctpState GetState() const
{
return this->state;
}
void ProcessSctpData(const uint8_t* data, size_t len);
void SendSctpMessage(const RTC::SctpStreamParameters &params, uint32_t ppid, const uint8_t* msg, size_t len);
void HandleDataConsumer(const RTC::SctpStreamParameters &params);
void DataProducerClosed(const RTC::SctpStreamParameters &params);
void DataConsumerClosed(const RTC::SctpStreamParameters &params);
private:
void ResetSctpStream(uint16_t streamId, StreamDirection);
void AddOutgoingStreams(bool force = false);
public:
/* Callbacks fired by usrsctp events. */
virtual void OnUsrSctpSendSctpData(void* buffer, size_t len);
virtual void OnUsrSctpReceiveSctpData(uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len);
virtual void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len);
private:
// Passed by argument.
Listener* listener{ nullptr };
uint16_t os{ 1024u };
uint16_t mis{ 1024u };
size_t maxSctpMessageSize{ 262144u };
bool isDataChannel{ false };
// Allocated by this.
uint8_t* messageBuffer{ nullptr };
// Others.
SctpState state{ SctpState::NEW };
struct socket* socket{ nullptr };
uint16_t desiredOs{ 0u };
size_t messageBufferLen{ 0u };
uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support.
private:
// Passed by argument.
Listener* listener{ nullptr };
uint16_t os{ 1024u };
uint16_t mis{ 1024u };
size_t maxSctpMessageSize{ 262144u };
bool isDataChannel{ false };
// Allocated by this.
uint8_t* messageBuffer{ nullptr };
// Others.
SctpState state{ SctpState::NEW };
struct socket* socket{ nullptr };
uint16_t desiredOs{ 0u };
size_t messageBufferLen{ 0u };
uint16_t lastSsnReceived{ 0u }; // Valid for us since no SCTP I-DATA support.
std::shared_ptr<SctpEnv> _env;
};
};
//保证线程安全
class SctpAssociationImp : public SctpAssociation, public std::enable_shared_from_this<SctpAssociationImp>{
......
......@@ -97,785 +97,785 @@ namespace RTC
return str;
}
/* Class variables. */
const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 };
/* Class methods. */
StunPacket* StunPacket::Parse(const uint8_t* data, size_t len)
{
MS_TRACE();
if (!StunPacket::IsStun(data, len))
return nullptr;
/*
The message type field is decomposed further into the following
structure:
0 1
2 3 4 5 6 7 8 9 0 1 2 3 4 5
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|M |M |M|M|M|C|M|M|M|C|M|M|M|M|
|11|10|9|8|7|1|6|5|4|0|3|2|1|0|
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 3: Format of STUN Message Type Field
Here the bits in the message type field are shown as most significant
(M11) through least significant (M0). M11 through M0 represent a 12-
bit encoding of the method. C1 and C0 represent a 2-bit encoding of
the class.
*/
// Get type field.
uint16_t msgType = Utils::Byte::Get2Bytes(data, 0);
// Get length field.
uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2);
// length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes.
if ((static_cast<size_t>(msgLength) != len - 20) || ((msgLength & 0x03) != 0))
{
MS_WARN_TAG(
ice,
"length field + 20 does not match total size (or it is not multiple of 4 bytes), "
"packet discarded");
return nullptr;
}
// Get STUN method.
uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2);
// Get STUN class.
uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4);
// Create a new StunPacket (data + 8 points to the received TransactionID field).
auto* packet = new StunPacket(
static_cast<Class>(msgClass), static_cast<Method>(msgMethod), data + 8, data, len);
/*
STUN Attributes
After the STUN header are zero or more attributes. Each attribute
MUST be TLV encoded, with a 16-bit type, 16-bit length, and value.
Each STUN attribute MUST end on a 32-bit boundary. As mentioned
above, all fields in an attribute are transmitted most significant
bit first.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Value (variable) ....
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
// Start looking for attributes after STUN header (Byte #20).
size_t pos{ 20 };
// Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes.
bool hasMessageIntegrity{ false };
bool hasFingerprint{ false };
size_t fingerprintAttrPos; // Will point to the beginning of the attribute.
uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute.
// Ensure there are at least 4 remaining bytes (attribute with 0 length).
while (pos + 4 <= len)
{
// Get the attribute type.
auto attrType = static_cast<Attribute>(Utils::Byte::Get2Bytes(data, pos));
// Get the attribute length.
uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2);
// Ensure the attribute length is not greater than the remaining size.
if ((pos + 4 + attrLength) > len)
{
MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded");
delete packet;
return nullptr;
}
// FINGERPRINT must be the last attribute.
if (hasFingerprint)
{
MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded");
delete packet;
return nullptr;
}
// After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed.
if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT)
{
MS_WARN_TAG(
ice,
"attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, "
"packet discarded");
delete packet;
return nullptr;
}
const uint8_t* attrValuePos = data + pos + 4;
switch (attrType)
{
case Attribute::USERNAME:
{
packet->SetUsername(
reinterpret_cast<const char*>(attrValuePos), static_cast<size_t>(attrLength));
break;
}
case Attribute::PRIORITY:
{
// Ensure attribute length is 4 bytes.
if (attrLength != 4)
{
MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0));
break;
}
case Attribute::ICE_CONTROLLING:
{
// Ensure attribute length is 8 bytes.
if (attrLength != 8)
{
MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0));
break;
}
case Attribute::ICE_CONTROLLED:
{
// Ensure attribute length is 8 bytes.
if (attrLength != 8)
{
MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0));
break;
}
case Attribute::USE_CANDIDATE:
{
// Ensure attribute length is 0 bytes.
if (attrLength != 0)
{
MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetUseCandidate();
break;
}
case Attribute::MESSAGE_INTEGRITY:
{
// Ensure attribute length is 20 bytes.
if (attrLength != 20)
{
MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded");
delete packet;
return nullptr;
}
hasMessageIntegrity = true;
packet->SetMessageIntegrity(attrValuePos);
break;
}
case Attribute::FINGERPRINT:
{
// Ensure attribute length is 4 bytes.
if (attrLength != 4)
{
MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded");
delete packet;
return nullptr;
}
hasFingerprint = true;
fingerprintAttrPos = pos;
fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0);
packet->SetFingerprint();
break;
}
case Attribute::ERROR_CODE:
{
// Ensure attribute length >= 4bytes.
if (attrLength < 4)
{
MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded");
delete packet;
return nullptr;
}
uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2);
uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3);
auto errorCode = static_cast<uint16_t>(errorClass * 100 + errorNumber);
packet->SetErrorCode(errorCode);
break;
}
default:;
}
// Set next attribute position.
pos =
static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(pos + 4 + attrLength)));
}
// Ensure current position matches the total length.
if (pos != len)
{
MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded");
delete packet;
return nullptr;
}
// If it has FINGERPRINT attribute then verify it.
if (hasFingerprint)
{
// Compute the CRC32 of the received packet up to (but excluding) the
// FINGERPRINT attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e;
// Compare with the FINGERPRINT value in the packet.
if (fingerprint != computedFingerprint)
{
MS_WARN_TAG(
ice,
"computed FINGERPRINT value does not match the value in the packet, "
"packet discarded");
delete packet;
return nullptr;
}
}
return packet;
}
/* Instance methods. */
StunPacket::StunPacket(
Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size)
: klass(klass), method(method), transactionId(transactionId), data(const_cast<uint8_t*>(data)),
size(size)
{
MS_TRACE();
}
StunPacket::~StunPacket()
{
MS_TRACE();
}
/* Class variables. */
const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 };
/* Class methods. */
StunPacket* StunPacket::Parse(const uint8_t* data, size_t len)
{
MS_TRACE();
if (!StunPacket::IsStun(data, len))
return nullptr;
/*
The message type field is decomposed further into the following
structure:
0 1
2 3 4 5 6 7 8 9 0 1 2 3 4 5
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
|M |M |M|M|M|C|M|M|M|C|M|M|M|M|
|11|10|9|8|7|1|6|5|4|0|3|2|1|0|
+--+--+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 3: Format of STUN Message Type Field
Here the bits in the message type field are shown as most significant
(M11) through least significant (M0). M11 through M0 represent a 12-
bit encoding of the method. C1 and C0 represent a 2-bit encoding of
the class.
*/
// Get type field.
uint16_t msgType = Utils::Byte::Get2Bytes(data, 0);
// Get length field.
uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2);
// length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes.
if ((static_cast<size_t>(msgLength) != len - 20) || ((msgLength & 0x03) != 0))
{
MS_WARN_TAG(
ice,
"length field + 20 does not match total size (or it is not multiple of 4 bytes), "
"packet discarded");
return nullptr;
}
// Get STUN method.
uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2);
// Get STUN class.
uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4);
// Create a new StunPacket (data + 8 points to the received TransactionID field).
auto* packet = new StunPacket(
static_cast<Class>(msgClass), static_cast<Method>(msgMethod), data + 8, data, len);
/*
STUN Attributes
After the STUN header are zero or more attributes. Each attribute
MUST be TLV encoded, with a 16-bit type, 16-bit length, and value.
Each STUN attribute MUST end on a 32-bit boundary. As mentioned
above, all fields in an attribute are transmitted most significant
bit first.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Value (variable) ....
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
// Start looking for attributes after STUN header (Byte #20).
size_t pos{ 20 };
// Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes.
bool hasMessageIntegrity{ false };
bool hasFingerprint{ false };
size_t fingerprintAttrPos; // Will point to the beginning of the attribute.
uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute.
// Ensure there are at least 4 remaining bytes (attribute with 0 length).
while (pos + 4 <= len)
{
// Get the attribute type.
auto attrType = static_cast<Attribute>(Utils::Byte::Get2Bytes(data, pos));
// Get the attribute length.
uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2);
// Ensure the attribute length is not greater than the remaining size.
if ((pos + 4 + attrLength) > len)
{
MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded");
delete packet;
return nullptr;
}
// FINGERPRINT must be the last attribute.
if (hasFingerprint)
{
MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded");
delete packet;
return nullptr;
}
// After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed.
if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT)
{
MS_WARN_TAG(
ice,
"attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, "
"packet discarded");
delete packet;
return nullptr;
}
const uint8_t* attrValuePos = data + pos + 4;
switch (attrType)
{
case Attribute::USERNAME:
{
packet->SetUsername(
reinterpret_cast<const char*>(attrValuePos), static_cast<size_t>(attrLength));
break;
}
case Attribute::PRIORITY:
{
// Ensure attribute length is 4 bytes.
if (attrLength != 4)
{
MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0));
break;
}
case Attribute::ICE_CONTROLLING:
{
// Ensure attribute length is 8 bytes.
if (attrLength != 8)
{
MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0));
break;
}
case Attribute::ICE_CONTROLLED:
{
// Ensure attribute length is 8 bytes.
if (attrLength != 8)
{
MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0));
break;
}
case Attribute::USE_CANDIDATE:
{
// Ensure attribute length is 0 bytes.
if (attrLength != 0)
{
MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded");
delete packet;
return nullptr;
}
packet->SetUseCandidate();
break;
}
case Attribute::MESSAGE_INTEGRITY:
{
// Ensure attribute length is 20 bytes.
if (attrLength != 20)
{
MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded");
delete packet;
return nullptr;
}
hasMessageIntegrity = true;
packet->SetMessageIntegrity(attrValuePos);
break;
}
case Attribute::FINGERPRINT:
{
// Ensure attribute length is 4 bytes.
if (attrLength != 4)
{
MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded");
delete packet;
return nullptr;
}
hasFingerprint = true;
fingerprintAttrPos = pos;
fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0);
packet->SetFingerprint();
break;
}
case Attribute::ERROR_CODE:
{
// Ensure attribute length >= 4bytes.
if (attrLength < 4)
{
MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded");
delete packet;
return nullptr;
}
uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2);
uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3);
auto errorCode = static_cast<uint16_t>(errorClass * 100 + errorNumber);
packet->SetErrorCode(errorCode);
break;
}
default:;
}
// Set next attribute position.
pos =
static_cast<size_t>(Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(pos + 4 + attrLength)));
}
// Ensure current position matches the total length.
if (pos != len)
{
MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded");
delete packet;
return nullptr;
}
// If it has FINGERPRINT attribute then verify it.
if (hasFingerprint)
{
// Compute the CRC32 of the received packet up to (but excluding) the
// FINGERPRINT attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e;
// Compare with the FINGERPRINT value in the packet.
if (fingerprint != computedFingerprint)
{
MS_WARN_TAG(
ice,
"computed FINGERPRINT value does not match the value in the packet, "
"packet discarded");
delete packet;
return nullptr;
}
}
return packet;
}
/* Instance methods. */
StunPacket::StunPacket(
Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size)
: klass(klass), method(method), transactionId(transactionId), data(const_cast<uint8_t*>(data)),
size(size)
{
MS_TRACE();
}
StunPacket::~StunPacket()
{
MS_TRACE();
}
#if 0
void StunPacket::Dump() const
{
MS_TRACE();
MS_DUMP("<StunPacket>");
std::string klass;
switch (this->klass)
{
case Class::REQUEST:
klass = "Request";
break;
case Class::INDICATION:
klass = "Indication";
break;
case Class::SUCCESS_RESPONSE:
klass = "SuccessResponse";
break;
case Class::ERROR_RESPONSE:
klass = "ErrorResponse";
break;
}
if (this->method == Method::BINDING)
{
MS_DUMP(" Binding %s", klass.c_str());
}
else
{
// This prints the unknown method number. Example: TURN Allocate => 0x003.
MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast<uint16_t>(this->method));
}
MS_DUMP(" size: %zu bytes", this->size);
static char transactionId[25];
for (int i{ 0 }; i < 12; ++i)
{
// NOTE: n must be 3 because snprintf adds a \0 after printed chars.
std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]);
}
MS_DUMP(" transactionId: %s", transactionId);
if (this->errorCode != 0u)
MS_DUMP(" errorCode: %" PRIu16, this->errorCode);
if (!this->username.empty())
MS_DUMP(" username: %s", this->username.c_str());
if (this->priority != 0u)
MS_DUMP(" priority: %" PRIu32, this->priority);
if (this->iceControlling != 0u)
MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling);
if (this->iceControlled != 0u)
MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled);
if (this->hasUseCandidate)
MS_DUMP(" useCandidate");
if (this->xorMappedAddress != nullptr)
{
int family;
uint16_t port;
std::string ip;
Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port);
MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port);
}
if (this->messageIntegrity != nullptr)
{
static char messageIntegrity[41];
for (int i{ 0 }; i < 20; ++i)
{
std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]);
}
MS_DUMP(" messageIntegrity: %s", messageIntegrity);
}
if (this->hasFingerprint)
MS_DUMP(" has fingerprint");
MS_DUMP("</StunPacket>");
}
void StunPacket::Dump() const
{
MS_TRACE();
MS_DUMP("<StunPacket>");
std::string klass;
switch (this->klass)
{
case Class::REQUEST:
klass = "Request";
break;
case Class::INDICATION:
klass = "Indication";
break;
case Class::SUCCESS_RESPONSE:
klass = "SuccessResponse";
break;
case Class::ERROR_RESPONSE:
klass = "ErrorResponse";
break;
}
if (this->method == Method::BINDING)
{
MS_DUMP(" Binding %s", klass.c_str());
}
else
{
// This prints the unknown method number. Example: TURN Allocate => 0x003.
MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast<uint16_t>(this->method));
}
MS_DUMP(" size: %zu bytes", this->size);
static char transactionId[25];
for (int i{ 0 }; i < 12; ++i)
{
// NOTE: n must be 3 because snprintf adds a \0 after printed chars.
std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]);
}
MS_DUMP(" transactionId: %s", transactionId);
if (this->errorCode != 0u)
MS_DUMP(" errorCode: %" PRIu16, this->errorCode);
if (!this->username.empty())
MS_DUMP(" username: %s", this->username.c_str());
if (this->priority != 0u)
MS_DUMP(" priority: %" PRIu32, this->priority);
if (this->iceControlling != 0u)
MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling);
if (this->iceControlled != 0u)
MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled);
if (this->hasUseCandidate)
MS_DUMP(" useCandidate");
if (this->xorMappedAddress != nullptr)
{
int family;
uint16_t port;
std::string ip;
Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port);
MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port);
}
if (this->messageIntegrity != nullptr)
{
static char messageIntegrity[41];
for (int i{ 0 }; i < 20; ++i)
{
std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]);
}
MS_DUMP(" messageIntegrity: %s", messageIntegrity);
}
if (this->hasFingerprint)
MS_DUMP(" has fingerprint");
MS_DUMP("</StunPacket>");
}
#endif
StunPacket::Authentication StunPacket::CheckAuthentication(
const std::string& localUsername, const std::string& localPassword)
{
MS_TRACE();
switch (this->klass)
{
case Class::REQUEST:
case Class::INDICATION:
{
// Both USERNAME and MESSAGE-INTEGRITY must be present.
if (!this->messageIntegrity || this->username.empty())
return Authentication::BAD_REQUEST;
// Check that USERNAME attribute begins with our local username plus ":".
size_t localUsernameLen = localUsername.length();
if (
this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' ||
(this->username.compare(0, localUsernameLen, localUsername) != 0))
{
return Authentication::UNAUTHORIZED;
}
break;
}
// This method cannot check authentication in received responses (as we
// are ICE-Lite and don't generate requests).
case Class::SUCCESS_RESPONSE:
case Class::ERROR_RESPONSE:
{
MS_ERROR("cannot check authentication for a STUN response");
return Authentication::BAD_REQUEST;
}
}
// If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation,
// so the header length field must be modified (and later restored).
if (this->hasFingerprint)
// Set the header length field: full size - header length (20) - FINGERPRINT length (8).
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules.
StunPacket::Authentication StunPacket::CheckAuthentication(
const std::string& localUsername, const std::string& localPassword)
{
MS_TRACE();
switch (this->klass)
{
case Class::REQUEST:
case Class::INDICATION:
{
// Both USERNAME and MESSAGE-INTEGRITY must be present.
if (!this->messageIntegrity || this->username.empty())
return Authentication::BAD_REQUEST;
// Check that USERNAME attribute begins with our local username plus ":".
size_t localUsernameLen = localUsername.length();
if (
this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' ||
(this->username.compare(0, localUsernameLen, localUsername) != 0))
{
return Authentication::UNAUTHORIZED;
}
break;
}
// This method cannot check authentication in received responses (as we
// are ICE-Lite and don't generate requests).
case Class::SUCCESS_RESPONSE:
case Class::ERROR_RESPONSE:
{
MS_ERROR("cannot check authentication for a STUN response");
return Authentication::BAD_REQUEST;
}
}
// If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation,
// so the header length field must be modified (and later restored).
if (this->hasFingerprint)
// Set the header length field: full size - header length (20) - FINGERPRINT length (8).
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules.
auto computedMessageIntegrity = openssl_HMACsha1(
localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data);
Authentication result;
Authentication result;
// Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet.
if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0)
result = Authentication::OK;
else
result = Authentication::UNAUTHORIZED;
// Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet.
if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0)
result = Authentication::OK;
else
result = Authentication::UNAUTHORIZED;
// Restore the header length field.
if (this->hasFingerprint)
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20));
// Restore the header length field.
if (this->hasFingerprint)
Utils::Byte::Set2Bytes(this->data, 2, static_cast<uint16_t>(this->size - 20));
return result;
}
return result;
}
StunPacket* StunPacket::CreateSuccessResponse()
{
MS_TRACE();
MS_ASSERT(
this->klass == Class::REQUEST,
"attempt to create a success response for a non Request STUN packet");
return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0);
}
StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode)
{
MS_TRACE();
StunPacket* StunPacket::CreateSuccessResponse()
{
MS_TRACE();
MS_ASSERT(
this->klass == Class::REQUEST,
"attempt to create an error response for a non Request STUN packet");
MS_ASSERT(
this->klass == Class::REQUEST,
"attempt to create a success response for a non Request STUN packet");
auto* response =
new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0);
return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0);
}
response->SetErrorCode(errorCode);
StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode)
{
MS_TRACE();
return response;
}
MS_ASSERT(
this->klass == Class::REQUEST,
"attempt to create an error response for a non Request STUN packet");
void StunPacket::Authenticate(const std::string& password)
{
// Just for Request, Indication and SuccessResponse messages.
if (this->klass == Class::ERROR_RESPONSE)
{
MS_ERROR("cannot set password for ErrorResponse messages");
auto* response =
new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0);
return;
}
response->SetErrorCode(errorCode);
this->password = password;
}
return response;
}
void StunPacket::Serialize(uint8_t* buffer)
{
MS_TRACE();
// Some useful variables.
uint16_t usernamePaddedLen{ 0 };
uint16_t xorMappedAddressPaddedLen{ 0 };
bool addXorMappedAddress =
((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING &&
this->klass == Class::SUCCESS_RESPONSE);
bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE);
bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty());
bool addFingerprint{ true }; // Do always.
// Update data pointer.
this->data = buffer;
// First calculate the total required size for the entire packet.
this->size = 20; // Header.
if (!this->username.empty())
{
usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(this->username.length()));
this->size += 4 + usernamePaddedLen;
}
if (this->priority != 0u)
this->size += 4 + 4;
if (this->iceControlling != 0u)
this->size += 4 + 8;
if (this->iceControlled != 0u)
this->size += 4 + 8;
if (this->hasUseCandidate)
this->size += 4;
if (addXorMappedAddress)
{
switch (this->xorMappedAddress->sa_family)
{
case AF_INET:
{
xorMappedAddressPaddedLen = 8;
this->size += 4 + 8;
break;
}
case AF_INET6:
{
xorMappedAddressPaddedLen = 20;
this->size += 4 + 20;
break;
}
default:
{
MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute");
addXorMappedAddress = false;
}
}
}
if (addErrorCode)
this->size += 4 + 4;
if (addMessageIntegrity)
this->size += 4 + 20;
if (addFingerprint)
this->size += 4 + 4;
// Merge class and method fields into type.
uint16_t typeField = (static_cast<uint16_t>(this->method) & 0x0f80) << 2;
typeField |= (static_cast<uint16_t>(this->method) & 0x0070) << 1;
typeField |= (static_cast<uint16_t>(this->method) & 0x000f);
typeField |= (static_cast<uint16_t>(this->klass) & 0x02) << 7;
typeField |= (static_cast<uint16_t>(this->klass) & 0x01) << 4;
// Set type field.
Utils::Byte::Set2Bytes(buffer, 0, typeField);
// Set length field.
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size) - 20);
// Set magic cookie.
std::memcpy(buffer + 4, StunPacket::magicCookie, 4);
// Set TransactionId field.
std::memcpy(buffer + 8, this->transactionId, 12);
// Update the transaction ID pointer.
this->transactionId = buffer + 8;
// Add atributes.
size_t pos{ 20 };
// Add USERNAME.
if (usernamePaddedLen != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USERNAME));
Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast<uint16_t>(this->username.length()));
std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length());
pos += 4 + usernamePaddedLen;
}
void StunPacket::Authenticate(const std::string& password)
{
// Just for Request, Indication and SuccessResponse messages.
if (this->klass == Class::ERROR_RESPONSE)
{
MS_ERROR("cannot set password for ErrorResponse messages");
return;
}
this->password = password;
}
void StunPacket::Serialize(uint8_t* buffer)
{
MS_TRACE();
// Some useful variables.
uint16_t usernamePaddedLen{ 0 };
uint16_t xorMappedAddressPaddedLen{ 0 };
bool addXorMappedAddress =
((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING &&
this->klass == Class::SUCCESS_RESPONSE);
bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE);
bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty());
bool addFingerprint{ true }; // Do always.
// Update data pointer.
this->data = buffer;
// First calculate the total required size for the entire packet.
this->size = 20; // Header.
if (!this->username.empty())
{
usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast<uint16_t>(this->username.length()));
this->size += 4 + usernamePaddedLen;
}
if (this->priority != 0u)
this->size += 4 + 4;
if (this->iceControlling != 0u)
this->size += 4 + 8;
if (this->iceControlled != 0u)
this->size += 4 + 8;
if (this->hasUseCandidate)
this->size += 4;
if (addXorMappedAddress)
{
switch (this->xorMappedAddress->sa_family)
{
case AF_INET:
{
xorMappedAddressPaddedLen = 8;
this->size += 4 + 8;
break;
}
case AF_INET6:
{
xorMappedAddressPaddedLen = 20;
this->size += 4 + 20;
break;
}
default:
{
MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute");
addXorMappedAddress = false;
}
}
}
if (addErrorCode)
this->size += 4 + 4;
if (addMessageIntegrity)
this->size += 4 + 20;
if (addFingerprint)
this->size += 4 + 4;
// Merge class and method fields into type.
uint16_t typeField = (static_cast<uint16_t>(this->method) & 0x0f80) << 2;
typeField |= (static_cast<uint16_t>(this->method) & 0x0070) << 1;
typeField |= (static_cast<uint16_t>(this->method) & 0x000f);
typeField |= (static_cast<uint16_t>(this->klass) & 0x02) << 7;
typeField |= (static_cast<uint16_t>(this->klass) & 0x01) << 4;
// Set type field.
Utils::Byte::Set2Bytes(buffer, 0, typeField);
// Set length field.
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size) - 20);
// Set magic cookie.
std::memcpy(buffer + 4, StunPacket::magicCookie, 4);
// Set TransactionId field.
std::memcpy(buffer + 8, this->transactionId, 12);
// Update the transaction ID pointer.
this->transactionId = buffer + 8;
// Add atributes.
size_t pos{ 20 };
// Add USERNAME.
if (usernamePaddedLen != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USERNAME));
Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast<uint16_t>(this->username.length()));
std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length());
pos += 4 + usernamePaddedLen;
}
// Add PRIORITY.
if (this->priority != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::PRIORITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority);
pos += 4 + 4;
}
// Add ICE-CONTROLLING.
if (this->iceControlling != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLING));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling);
pos += 4 + 8;
}
// Add ICE-CONTROLLED.
if (this->iceControlled != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLED));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled);
pos += 4 + 8;
}
// Add USE-CANDIDATE.
if (this->hasUseCandidate)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USE_CANDIDATE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 0);
pos += 4;
}
// Add XOR-MAPPED-ADDRESS
if (addXorMappedAddress)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::XOR_MAPPED_ADDRESS));
Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen);
uint8_t* attrValue = buffer + pos + 4;
switch (this->xorMappedAddress->sa_family)
{
case AF_INET:
{
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x01;
// Set port and XOR it.
std::memcpy(
attrValue + 2,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_port,
2);
attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_addr.s_addr,
4);
attrValue[4] ^= StunPacket::magicCookie[0];
attrValue[5] ^= StunPacket::magicCookie[1];
attrValue[6] ^= StunPacket::magicCookie[2];
attrValue[7] ^= StunPacket::magicCookie[3];
pos += 4 + 8;
break;
}
case AF_INET6:
{
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x02;
// Set port and XOR it.
std::memcpy(
attrValue + 2,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_port,
2);
attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_addr.s6_addr,
16);
attrValue[4] ^= StunPacket::magicCookie[0];
attrValue[5] ^= StunPacket::magicCookie[1];
attrValue[6] ^= StunPacket::magicCookie[2];
attrValue[7] ^= StunPacket::magicCookie[3];
attrValue[8] ^= this->transactionId[0];
attrValue[9] ^= this->transactionId[1];
attrValue[10] ^= this->transactionId[2];
attrValue[11] ^= this->transactionId[3];
attrValue[12] ^= this->transactionId[4];
attrValue[13] ^= this->transactionId[5];
attrValue[14] ^= this->transactionId[6];
attrValue[15] ^= this->transactionId[7];
attrValue[16] ^= this->transactionId[8];
attrValue[17] ^= this->transactionId[9];
attrValue[18] ^= this->transactionId[10];
attrValue[19] ^= this->transactionId[11];
pos += 4 + 20;
break;
}
}
}
// Add ERROR-CODE.
if (addErrorCode)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ERROR_CODE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
auto codeClass = static_cast<uint8_t>(this->errorCode / 100);
uint8_t codeNumber = static_cast<uint8_t>(this->errorCode) - (codeClass * 100);
Utils::Byte::Set2Bytes(buffer, pos + 4, 0);
Utils::Byte::Set1Byte(buffer, pos + 6, codeClass);
Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber);
pos += 4 + 4;
}
// Add MESSAGE-INTEGRITY.
if (addMessageIntegrity)
{
// Ignore FINGERPRINT.
if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules.
// Add PRIORITY.
if (this->priority != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::PRIORITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority);
pos += 4 + 4;
}
// Add ICE-CONTROLLING.
if (this->iceControlling != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLING));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling);
pos += 4 + 8;
}
// Add ICE-CONTROLLED.
if (this->iceControlled != 0u)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ICE_CONTROLLED));
Utils::Byte::Set2Bytes(buffer, pos + 2, 8);
Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled);
pos += 4 + 8;
}
// Add USE-CANDIDATE.
if (this->hasUseCandidate)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::USE_CANDIDATE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 0);
pos += 4;
}
// Add XOR-MAPPED-ADDRESS
if (addXorMappedAddress)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::XOR_MAPPED_ADDRESS));
Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen);
uint8_t* attrValue = buffer + pos + 4;
switch (this->xorMappedAddress->sa_family)
{
case AF_INET:
{
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x01;
// Set port and XOR it.
std::memcpy(
attrValue + 2,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_port,
2);
attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in*>(this->xorMappedAddress))->sin_addr.s_addr,
4);
attrValue[4] ^= StunPacket::magicCookie[0];
attrValue[5] ^= StunPacket::magicCookie[1];
attrValue[6] ^= StunPacket::magicCookie[2];
attrValue[7] ^= StunPacket::magicCookie[3];
pos += 4 + 8;
break;
}
case AF_INET6:
{
// Set first byte to 0.
attrValue[0] = 0;
// Set inet family.
attrValue[1] = 0x02;
// Set port and XOR it.
std::memcpy(
attrValue + 2,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_port,
2);
attrValue[2] ^= StunPacket::magicCookie[0];
attrValue[3] ^= StunPacket::magicCookie[1];
// Set address and XOR it.
std::memcpy(
attrValue + 4,
&(reinterpret_cast<const sockaddr_in6*>(this->xorMappedAddress))->sin6_addr.s6_addr,
16);
attrValue[4] ^= StunPacket::magicCookie[0];
attrValue[5] ^= StunPacket::magicCookie[1];
attrValue[6] ^= StunPacket::magicCookie[2];
attrValue[7] ^= StunPacket::magicCookie[3];
attrValue[8] ^= this->transactionId[0];
attrValue[9] ^= this->transactionId[1];
attrValue[10] ^= this->transactionId[2];
attrValue[11] ^= this->transactionId[3];
attrValue[12] ^= this->transactionId[4];
attrValue[13] ^= this->transactionId[5];
attrValue[14] ^= this->transactionId[6];
attrValue[15] ^= this->transactionId[7];
attrValue[16] ^= this->transactionId[8];
attrValue[17] ^= this->transactionId[9];
attrValue[18] ^= this->transactionId[10];
attrValue[19] ^= this->transactionId[11];
pos += 4 + 20;
break;
}
}
}
// Add ERROR-CODE.
if (addErrorCode)
{
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::ERROR_CODE));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
auto codeClass = static_cast<uint8_t>(this->errorCode / 100);
uint8_t codeNumber = static_cast<uint8_t>(this->errorCode) - (codeClass * 100);
Utils::Byte::Set2Bytes(buffer, pos + 4, 0);
Utils::Byte::Set1Byte(buffer, pos + 6, codeClass);
Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber);
pos += 4 + 4;
}
// Add MESSAGE-INTEGRITY.
if (addMessageIntegrity)
{
// Ignore FINGERPRINT.
if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20 - 8));
// Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules.
auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos);
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::MESSAGE_INTEGRITY));
Utils::Byte::Set2Bytes(buffer, pos + 2, 20);
std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size());
// Update the pointer.
this->messageIntegrity = buffer + pos + 4;
pos += 4 + 20;
// Restore length field.
if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20));
}
else
{
// Unset the pointer (if it was set).
this->messageIntegrity = nullptr;
}
// Add FINGERPRINT.
if (addFingerprint)
{
// Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT
// attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e;
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint);
pos += 4 + 4;
// Set flag.
this->hasFingerprint = true;
}
else
{
this->hasFingerprint = false;
}
MS_ASSERT(pos == this->size, "pos != this->size");
}
Utils::Byte::Set2Bytes(buffer, pos + 2, 20);
std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size());
// Update the pointer.
this->messageIntegrity = buffer + pos + 4;
pos += 4 + 20;
// Restore length field.
if (addFingerprint)
Utils::Byte::Set2Bytes(buffer, 2, static_cast<uint16_t>(this->size - 20));
}
else
{
// Unset the pointer (if it was set).
this->messageIntegrity = nullptr;
}
// Add FINGERPRINT.
if (addFingerprint)
{
// Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT
// attribute and XOR it with 0x5354554e.
uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e;
Utils::Byte::Set2Bytes(buffer, pos, static_cast<uint16_t>(Attribute::FINGERPRINT));
Utils::Byte::Set2Bytes(buffer, pos + 2, 4);
Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint);
pos += 4 + 4;
// Set flag.
this->hasFingerprint = true;
}
else
{
this->hasFingerprint = false;
}
MS_ASSERT(pos == this->size, "pos != this->size");
}
} // namespace RTC
......@@ -26,188 +26,188 @@ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
namespace RTC
{
class StunPacket
{
public:
// STUN message class.
enum class Class : uint16_t
{
REQUEST = 0,
INDICATION = 1,
SUCCESS_RESPONSE = 2,
ERROR_RESPONSE = 3
};
// STUN message method.
enum class Method : uint16_t
{
BINDING = 1
};
// Attribute type.
enum class Attribute : uint16_t
{
MAPPED_ADDRESS = 0x0001,
USERNAME = 0x0006,
MESSAGE_INTEGRITY = 0x0008,
ERROR_CODE = 0x0009,
UNKNOWN_ATTRIBUTES = 0x000A,
REALM = 0x0014,
NONCE = 0x0015,
XOR_MAPPED_ADDRESS = 0x0020,
PRIORITY = 0x0024,
USE_CANDIDATE = 0x0025,
SOFTWARE = 0x8022,
ALTERNATE_SERVER = 0x8023,
FINGERPRINT = 0x8028,
ICE_CONTROLLED = 0x8029,
ICE_CONTROLLING = 0x802A
};
// Authentication result.
enum class Authentication
{
OK = 0,
UNAUTHORIZED = 1,
BAD_REQUEST = 2
};
public:
static bool IsStun(const uint8_t* data, size_t len)
{
// clang-format off
return (
// STUN headers are 20 bytes.
(len >= 20) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] < 3) &&
// Magic cookie must match.
(data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) &&
(data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3])
);
// clang-format on
}
static StunPacket* Parse(const uint8_t* data, size_t len);
private:
static const uint8_t magicCookie[];
public:
StunPacket(
Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size);
~StunPacket();
void Dump() const;
Class GetClass() const
{
return this->klass;
}
Method GetMethod() const
{
return this->method;
}
const uint8_t* GetData() const
{
return this->data;
}
size_t GetSize() const
{
return this->size;
}
void SetUsername(const char* username, size_t len)
{
this->username.assign(username, len);
}
void SetPriority(uint32_t priority)
{
this->priority = priority;
}
void SetIceControlling(uint64_t iceControlling)
{
this->iceControlling = iceControlling;
}
void SetIceControlled(uint64_t iceControlled)
{
this->iceControlled = iceControlled;
}
void SetUseCandidate()
{
this->hasUseCandidate = true;
}
void SetXorMappedAddress(const struct sockaddr* xorMappedAddress)
{
this->xorMappedAddress = xorMappedAddress;
}
void SetErrorCode(uint16_t errorCode)
{
this->errorCode = errorCode;
}
void SetMessageIntegrity(const uint8_t* messageIntegrity)
{
this->messageIntegrity = messageIntegrity;
}
void SetFingerprint()
{
this->hasFingerprint = true;
}
const std::string& GetUsername() const
{
return this->username;
}
uint32_t GetPriority() const
{
return this->priority;
}
uint64_t GetIceControlling() const
{
return this->iceControlling;
}
uint64_t GetIceControlled() const
{
return this->iceControlled;
}
bool HasUseCandidate() const
{
return this->hasUseCandidate;
}
uint16_t GetErrorCode() const
{
return this->errorCode;
}
bool HasMessageIntegrity() const
{
return (this->messageIntegrity ? true : false);
}
bool HasFingerprint() const
{
return this->hasFingerprint;
}
Authentication CheckAuthentication(
const std::string& localUsername, const std::string& localPassword);
StunPacket* CreateSuccessResponse();
StunPacket* CreateErrorResponse(uint16_t errorCode);
void Authenticate(const std::string& password);
void Serialize(uint8_t* buffer);
private:
// Passed by argument.
Class klass; // 2 bytes.
Method method; // 2 bytes.
const uint8_t* transactionId{ nullptr }; // 12 bytes.
uint8_t* data{ nullptr }; // Pointer to binary data.
size_t size{ 0u }; // The full message size (including header).
// STUN attributes.
std::string username; // Less than 513 bytes.
uint32_t priority{ 0u }; // 4 bytes unsigned integer.
uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer.
uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer.
bool hasUseCandidate{ false }; // 0 bytes.
const uint8_t* messageIntegrity{ nullptr }; // 20 bytes.
bool hasFingerprint{ false }; // 4 bytes.
const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes.
uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase).
std::string password;
};
class StunPacket
{
public:
// STUN message class.
enum class Class : uint16_t
{
REQUEST = 0,
INDICATION = 1,
SUCCESS_RESPONSE = 2,
ERROR_RESPONSE = 3
};
// STUN message method.
enum class Method : uint16_t
{
BINDING = 1
};
// Attribute type.
enum class Attribute : uint16_t
{
MAPPED_ADDRESS = 0x0001,
USERNAME = 0x0006,
MESSAGE_INTEGRITY = 0x0008,
ERROR_CODE = 0x0009,
UNKNOWN_ATTRIBUTES = 0x000A,
REALM = 0x0014,
NONCE = 0x0015,
XOR_MAPPED_ADDRESS = 0x0020,
PRIORITY = 0x0024,
USE_CANDIDATE = 0x0025,
SOFTWARE = 0x8022,
ALTERNATE_SERVER = 0x8023,
FINGERPRINT = 0x8028,
ICE_CONTROLLED = 0x8029,
ICE_CONTROLLING = 0x802A
};
// Authentication result.
enum class Authentication
{
OK = 0,
UNAUTHORIZED = 1,
BAD_REQUEST = 2
};
public:
static bool IsStun(const uint8_t* data, size_t len)
{
// clang-format off
return (
// STUN headers are 20 bytes.
(len >= 20) &&
// DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes
(data[0] < 3) &&
// Magic cookie must match.
(data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) &&
(data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3])
);
// clang-format on
}
static StunPacket* Parse(const uint8_t* data, size_t len);
private:
static const uint8_t magicCookie[];
public:
StunPacket(
Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size);
~StunPacket();
void Dump() const;
Class GetClass() const
{
return this->klass;
}
Method GetMethod() const
{
return this->method;
}
const uint8_t* GetData() const
{
return this->data;
}
size_t GetSize() const
{
return this->size;
}
void SetUsername(const char* username, size_t len)
{
this->username.assign(username, len);
}
void SetPriority(uint32_t priority)
{
this->priority = priority;
}
void SetIceControlling(uint64_t iceControlling)
{
this->iceControlling = iceControlling;
}
void SetIceControlled(uint64_t iceControlled)
{
this->iceControlled = iceControlled;
}
void SetUseCandidate()
{
this->hasUseCandidate = true;
}
void SetXorMappedAddress(const struct sockaddr* xorMappedAddress)
{
this->xorMappedAddress = xorMappedAddress;
}
void SetErrorCode(uint16_t errorCode)
{
this->errorCode = errorCode;
}
void SetMessageIntegrity(const uint8_t* messageIntegrity)
{
this->messageIntegrity = messageIntegrity;
}
void SetFingerprint()
{
this->hasFingerprint = true;
}
const std::string& GetUsername() const
{
return this->username;
}
uint32_t GetPriority() const
{
return this->priority;
}
uint64_t GetIceControlling() const
{
return this->iceControlling;
}
uint64_t GetIceControlled() const
{
return this->iceControlled;
}
bool HasUseCandidate() const
{
return this->hasUseCandidate;
}
uint16_t GetErrorCode() const
{
return this->errorCode;
}
bool HasMessageIntegrity() const
{
return (this->messageIntegrity ? true : false);
}
bool HasFingerprint() const
{
return this->hasFingerprint;
}
Authentication CheckAuthentication(
const std::string& localUsername, const std::string& localPassword);
StunPacket* CreateSuccessResponse();
StunPacket* CreateErrorResponse(uint16_t errorCode);
void Authenticate(const std::string& password);
void Serialize(uint8_t* buffer);
private:
// Passed by argument.
Class klass; // 2 bytes.
Method method; // 2 bytes.
const uint8_t* transactionId{ nullptr }; // 12 bytes.
uint8_t* data{ nullptr }; // Pointer to binary data.
size_t size{ 0u }; // The full message size (including header).
// STUN attributes.
std::string username; // Less than 513 bytes.
uint32_t priority{ 0u }; // 4 bytes unsigned integer.
uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer.
uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer.
bool hasUseCandidate{ false }; // 0 bytes.
const uint8_t* messageIntegrity{ nullptr }; // 20 bytes.
bool hasFingerprint{ false }; // 4 bytes.
const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes.
uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase).
std::string password;
};
} // namespace RTC
#endif
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论