diff --git a/include/net/sctp/sctp.h b/include/net/sctp/sctp.h index 06b4f515e157..d7d8cba01469 100644 --- a/include/net/sctp/sctp.h +++ b/include/net/sctp/sctp.h @@ -127,7 +127,8 @@ int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *), const union sctp_addr *laddr, const union sctp_addr *paddr, void *p); int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *), - struct net *net, int pos, void *p); + int (*cb_done)(struct sctp_transport *, void *), + struct net *net, int *pos, void *p); int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p); int sctp_get_sctp_info(struct sock *sk, struct sctp_association *asoc, struct sctp_info *info); diff --git a/net/sctp/sctp_diag.c b/net/sctp/sctp_diag.c index e99518e79b52..7008a992749b 100644 --- a/net/sctp/sctp_diag.c +++ b/net/sctp/sctp_diag.c @@ -279,9 +279,11 @@ out: return err; } -static int sctp_sock_dump(struct sock *sk, void *p) +static int sctp_sock_dump(struct sctp_transport *tsp, void *p) { + struct sctp_endpoint *ep = tsp->asoc->ep; struct sctp_comm_param *commp = p; + struct sock *sk = ep->base.sk; struct sk_buff *skb = commp->skb; struct netlink_callback *cb = commp->cb; const struct inet_diag_req_v2 *r = commp->r; @@ -289,9 +291,7 @@ static int sctp_sock_dump(struct sock *sk, void *p) int err = 0; lock_sock(sk); - if (!sctp_sk(sk)->ep) - goto release; - list_for_each_entry(assoc, &sctp_sk(sk)->ep->asocs, asocs) { + list_for_each_entry(assoc, &ep->asocs, asocs) { if (cb->args[4] < cb->args[1]) goto next; @@ -327,40 +327,30 @@ next: cb->args[4]++; } cb->args[1] = 0; - cb->args[2]++; cb->args[3] = 0; cb->args[4] = 0; release: release_sock(sk); - sock_put(sk); return err; } -static int sctp_get_sock(struct sctp_transport *tsp, void *p) +static int sctp_sock_filter(struct sctp_transport *tsp, void *p) { struct sctp_endpoint *ep = tsp->asoc->ep; struct sctp_comm_param *commp = p; struct sock *sk = ep->base.sk; - struct netlink_callback *cb = commp->cb; const struct inet_diag_req_v2 *r = commp->r; struct sctp_association *assoc = list_entry(ep->asocs.next, struct sctp_association, asocs); /* find the ep only once through the transports by this condition */ if (tsp->asoc != assoc) - goto out; + return 0; if (r->sdiag_family != AF_UNSPEC && sk->sk_family != r->sdiag_family) - goto out; - - sock_hold(sk); - cb->args[5] = (long)sk; + return 0; return 1; - -out: - cb->args[2]++; - return 0; } static int sctp_ep_dump(struct sctp_endpoint *ep, void *p) @@ -503,12 +493,8 @@ skip: if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE))) goto done; -next: - cb->args[5] = 0; - sctp_for_each_transport(sctp_get_sock, net, cb->args[2], &commp); - - if (cb->args[5] && !sctp_sock_dump((struct sock *)cb->args[5], &commp)) - goto next; + sctp_for_each_transport(sctp_sock_filter, sctp_sock_dump, + net, (int *)&cb->args[2], &commp); done: cb->args[1] = cb->args[4]; diff --git a/net/sctp/socket.c b/net/sctp/socket.c index 1b00a1e09b93..d4730ada7f32 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -4658,29 +4658,39 @@ int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *), EXPORT_SYMBOL_GPL(sctp_transport_lookup_process); int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *), - struct net *net, int pos, void *p) { + int (*cb_done)(struct sctp_transport *, void *), + struct net *net, int *pos, void *p) { struct rhashtable_iter hti; - void *obj; - int err; + struct sctp_transport *tsp; + int ret; - err = sctp_transport_walk_start(&hti); - if (err) - return err; +again: + ret = sctp_transport_walk_start(&hti); + if (ret) + return ret; - obj = sctp_transport_get_idx(net, &hti, pos + 1); - for (; !IS_ERR_OR_NULL(obj); obj = sctp_transport_get_next(net, &hti)) { - struct sctp_transport *transport = obj; - - if (!sctp_transport_hold(transport)) + tsp = sctp_transport_get_idx(net, &hti, *pos + 1); + for (; !IS_ERR_OR_NULL(tsp); tsp = sctp_transport_get_next(net, &hti)) { + if (!sctp_transport_hold(tsp)) continue; - err = cb(transport, p); - sctp_transport_put(transport); - if (err) + ret = cb(tsp, p); + if (ret) break; + (*pos)++; + sctp_transport_put(tsp); } sctp_transport_walk_stop(&hti); - return err; + if (ret) { + if (cb_done && !cb_done(tsp, p)) { + (*pos)++; + sctp_transport_put(tsp); + goto again; + } + sctp_transport_put(tsp); + } + + return ret; } EXPORT_SYMBOL_GPL(sctp_for_each_transport);