diff --git a/src/srtp/srtcp.c b/src/srtp/srtcp.c index 23f48f4..94155e0 100644 --- a/src/srtp/srtcp.c +++ b/src/srtp/srtcp.c @@ -47,9 +47,9 @@ int srtcp_encrypt(struct srtp *srtp, struct mbuf *mb) if (err) return err; - strm = stream_get(srtp, ssrc); - if (!strm) - return ENOSR; + err = stream_get(&strm, srtp, ssrc); + if (err) + return err; strm->rtcp_index = (strm->rtcp_index+1) & 0x7fffffff; @@ -116,9 +116,9 @@ int srtcp_decrypt(struct srtp *srtp, struct mbuf *mb) if (err) return err; - strm = stream_get(srtp, ssrc); - if (!strm) - return ENOSR; + err = stream_get(&strm, srtp, ssrc); + if (err) + return err; pld_start = mb->pos; diff --git a/src/srtp/srtp.c b/src/srtp/srtp.c index 3cfc1c7..1e99e0e 100644 --- a/src/srtp/srtp.c +++ b/src/srtp/srtp.c @@ -163,9 +163,9 @@ int srtp_encrypt(struct srtp *srtp, struct mbuf *mb) if (err) return err; - strm = stream_get_seq(srtp, hdr.ssrc, hdr.seq); - if (!strm) - return ENOSR; + err = stream_get_seq(&strm, srtp, hdr.ssrc, hdr.seq); + if (err) + return err; /* Roll-Over Counter (ROC) */ if (seq_diff(strm->s_l, hdr.seq) <= -32768) { @@ -241,9 +241,9 @@ int srtp_decrypt(struct srtp *srtp, struct mbuf *mb) if (err) return err; - strm = stream_get_seq(srtp, hdr.ssrc, hdr.seq); - if (!strm) - return ENOSR; + err = stream_get_seq(&strm, srtp, hdr.ssrc, hdr.seq); + if (err) + return err; diff = seq_diff(strm->s_l, hdr.seq); if (diff > 32768) diff --git a/src/srtp/srtp.h b/src/srtp/srtp.h index 9cb94a3..a937337 100644 --- a/src/srtp/srtp.h +++ b/src/srtp/srtp.h @@ -49,9 +49,9 @@ struct srtp { }; -struct srtp_stream *stream_get(struct srtp *srtp, uint32_t ssrc); -struct srtp_stream *stream_get_seq(struct srtp *srtp, uint32_t ssrc, - uint16_t seq); +int stream_get(struct srtp_stream **strmp, struct srtp *srtp, uint32_t ssrc); +int stream_get_seq(struct srtp_stream **strmp, struct srtp *srtp, + uint32_t ssrc, uint16_t seq); int srtp_derive(uint8_t *out, size_t out_len, uint8_t label, diff --git a/src/srtp/stream.c b/src/srtp/stream.c index 78524cf..abc6c48 100644 --- a/src/srtp/stream.c +++ b/src/srtp/stream.c @@ -12,9 +12,9 @@ /** SRTP protocol values */ -enum { - MAX_STREAMS = 8, /**< Maximum number of SRTP streams */ -}; +#ifndef SRTP_MAX_STREAMS +#define SRTP_MAX_STREAMS (8) /**< Maximum number of SRTP streams */ +#endif static void stream_destructor(void *arg) @@ -41,16 +41,17 @@ static struct srtp_stream *stream_find(struct srtp *srtp, uint32_t ssrc) } -static struct srtp_stream *stream_new(struct srtp *srtp, uint32_t ssrc) +static int stream_new(struct srtp_stream **strmp, struct srtp *srtp, + uint32_t ssrc) { struct srtp_stream *strm; - if (list_count(&srtp->streaml) >= MAX_STREAMS) - return NULL; + if (list_count(&srtp->streaml) >= SRTP_MAX_STREAMS) + return ENOSR; strm = mem_zalloc(sizeof(*strm), stream_destructor); if (!strm) - return NULL; + return ENOMEM; strm->ssrc = ssrc; srtp_replay_init(&strm->replay_rtp); @@ -58,33 +59,42 @@ static struct srtp_stream *stream_new(struct srtp *srtp, uint32_t ssrc) list_append(&srtp->streaml, &strm->le, strm); - return strm; + if (strmp) + *strmp = strm; + + return 0; } -struct srtp_stream *stream_get(struct srtp *srtp, uint32_t ssrc) +int stream_get(struct srtp_stream **strmp, struct srtp *srtp, uint32_t ssrc) { struct srtp_stream *strm; - if (!srtp) - return NULL; + if (!strmp || !srtp) + return EINVAL; strm = stream_find(srtp, ssrc); - if (strm) - return strm; + if (strm) { + *strmp = strm; + return 0; + } - return stream_new(srtp, ssrc); + return stream_new(strmp, srtp, ssrc); } -struct srtp_stream *stream_get_seq(struct srtp *srtp, uint32_t ssrc, - uint16_t seq) +int stream_get_seq(struct srtp_stream **strmp, struct srtp *srtp, + uint32_t ssrc, uint16_t seq) { struct srtp_stream *strm; + int err; - strm = stream_get(srtp, ssrc); - if (!strm) - return NULL; + if (!strmp || !srtp) + return EINVAL; + + err = stream_get(&strm, srtp, ssrc); + if (err) + return err; /* Set the initial sequence number once only */ if (!strm->s_l_set) { @@ -92,5 +102,7 @@ struct srtp_stream *stream_get_seq(struct srtp *srtp, uint32_t ssrc, strm->s_l_set = true; } - return strm; + *strmp = strm; + + return 0; }