inet_diag: Split inet_diag_get_exact into parts
[deliverable/linux.git] / net / ipv4 / inet_diag.c
index 57a1bd97ea35e69264cf120ab499b5163061666d..f50df2ed9af5f911c82864081dc6af2ccfdc0281 100644 (file)
@@ -46,23 +46,9 @@ struct inet_diag_entry {
        u16 userlocks;
 };
 
-static struct sock *sdiagnl;
-
 #define INET_DIAG_PUT(skb, attrtype, attrlen) \
        RTA_DATA(__RTA_PUT(skb, attrtype, attrlen))
 
-static inline int inet_diag_type2proto(int type)
-{
-       switch (type) {
-       case TCPDIAG_GETSOCK:
-               return IPPROTO_TCP;
-       case DCCPDIAG_GETSOCK:
-               return IPPROTO_DCCP;
-       default:
-               return 0;
-       }
-}
-
 static DEFINE_MUTEX(inet_diag_table_mutex);
 
 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
@@ -85,8 +71,8 @@ static inline void inet_diag_unlock_handler(
 }
 
 static int inet_csk_diag_fill(struct sock *sk,
-                             struct sk_buff *skb,
-                             int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+                             struct sk_buff *skb, struct inet_diag_req *req,
+                             u32 pid, u32 seq, u16 nlmsg_flags,
                              const struct nlmsghdr *unlh)
 {
        const struct inet_sock *inet = inet_sk(sk);
@@ -97,8 +83,9 @@ static int inet_csk_diag_fill(struct sock *sk,
        struct inet_diag_meminfo  *minfo = NULL;
        unsigned char    *b = skb_tail_pointer(skb);
        const struct inet_diag_handler *handler;
+       int ext = req->idiag_ext;
 
-       handler = inet_diag_table[inet_diag_type2proto(unlh->nlmsg_type)];
+       handler = inet_diag_table[req->sdiag_protocol];
        BUG_ON(handler == NULL);
 
        nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
@@ -111,8 +98,7 @@ static int inet_csk_diag_fill(struct sock *sk,
                minfo = INET_DIAG_PUT(skb, INET_DIAG_MEMINFO, sizeof(*minfo));
 
        if (ext & (1 << (INET_DIAG_INFO - 1)))
-               info = INET_DIAG_PUT(skb, INET_DIAG_INFO,
-                                    handler->idiag_info_size);
+               info = INET_DIAG_PUT(skb, INET_DIAG_INFO, sizeof(struct tcp_info));
 
        if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) {
                const size_t len = strlen(icsk->icsk_ca_ops->name);
@@ -198,8 +184,8 @@ nlmsg_failure:
 }
 
 static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
-                              struct sk_buff *skb, int ext, u32 pid,
-                              u32 seq, u16 nlmsg_flags,
+                              struct sk_buff *skb, struct inet_diag_req *req,
+                              u32 pid, u32 seq, u16 nlmsg_flags,
                               const struct nlmsghdr *unlh)
 {
        long tmo;
@@ -250,35 +236,36 @@ nlmsg_failure:
 }
 
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
-                       int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+                       struct inet_diag_req *r, u32 pid, u32 seq, u16 nlmsg_flags,
                        const struct nlmsghdr *unlh)
 {
        if (sk->sk_state == TCP_TIME_WAIT)
                return inet_twsk_diag_fill((struct inet_timewait_sock *)sk,
-                                          skb, ext, pid, seq, nlmsg_flags,
+                                          skb, r, pid, seq, nlmsg_flags,
                                           unlh);
-       return inet_csk_diag_fill(sk, skb, ext, pid, seq, nlmsg_flags, unlh);
+       return inet_csk_diag_fill(sk, skb, r, pid, seq, nlmsg_flags, unlh);
 }
 
-static int inet_diag_get_exact(struct sk_buff *in_skb,
-                              const struct nlmsghdr *nlh,
-                              struct inet_diag_req *req)
+int inet_diag_check_cookie(struct sock *sk, struct inet_diag_req *req)
+{
+       if ((req->id.idiag_cookie[0] != INET_DIAG_NOCOOKIE ||
+            req->id.idiag_cookie[1] != INET_DIAG_NOCOOKIE) &&
+           ((u32)(unsigned long)sk != req->id.idiag_cookie[0] ||
+            (u32)((((unsigned long)sk) >> 31) >> 1) != req->id.idiag_cookie[1]))
+               return -ESTALE;
+       else
+               return 0;
+}
+EXPORT_SYMBOL_GPL(inet_diag_check_cookie);
+
+static int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_skb,
+               const struct nlmsghdr *nlh, struct inet_diag_req *req)
 {
        int err;
        struct sock *sk;
        struct sk_buff *rep;
-       struct inet_hashinfo *hashinfo;
-       const struct inet_diag_handler *handler;
-
-       handler = inet_diag_lock_handler(req->sdiag_protocol);
-       if (IS_ERR(handler)) {
-               err = PTR_ERR(handler);
-               goto unlock;
-       }
 
-       hashinfo = handler->idiag_hashinfo;
        err = -EINVAL;
-
        if (req->sdiag_family == AF_INET) {
                sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0],
                                 req->id.idiag_dport, req->id.idiag_src[0],
@@ -295,29 +282,26 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
        }
 #endif
        else {
-               goto unlock;
+               goto out_nosk;
        }
 
        err = -ENOENT;
        if (sk == NULL)
-               goto unlock;
+               goto out_nosk;
 
-       err = -ESTALE;
-       if ((req->id.idiag_cookie[0] != INET_DIAG_NOCOOKIE ||
-            req->id.idiag_cookie[1] != INET_DIAG_NOCOOKIE) &&
-           ((u32)(unsigned long)sk != req->id.idiag_cookie[0] ||
-            (u32)((((unsigned long)sk) >> 31) >> 1) != req->id.idiag_cookie[1]))
+       err = inet_diag_check_cookie(sk, req);
+       if (err)
                goto out;
 
        err = -ENOMEM;
        rep = alloc_skb(NLMSG_SPACE((sizeof(struct inet_diag_msg) +
                                     sizeof(struct inet_diag_meminfo) +
-                                    handler->idiag_info_size + 64)),
+                                    sizeof(struct tcp_info) + 64)),
                        GFP_KERNEL);
        if (!rep)
                goto out;
 
-       err = sk_diag_fill(sk, rep, req->idiag_ext,
+       err = sk_diag_fill(sk, rep, req,
                           NETLINK_CB(in_skb).pid,
                           nlh->nlmsg_seq, 0, nlh);
        if (err < 0) {
@@ -325,7 +309,7 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
                kfree_skb(rep);
                goto out;
        }
-       err = netlink_unicast(sdiagnl, rep, NETLINK_CB(in_skb).pid,
+       err = netlink_unicast(sock_diag_nlsk, rep, NETLINK_CB(in_skb).pid,
                              MSG_DONTWAIT);
        if (err > 0)
                err = 0;
@@ -337,8 +321,25 @@ out:
                else
                        sock_put(sk);
        }
-unlock:
+out_nosk:
+       return err;
+}
+
+static int inet_diag_get_exact(struct sk_buff *in_skb,
+                              const struct nlmsghdr *nlh,
+                              struct inet_diag_req *req)
+{
+       const struct inet_diag_handler *handler;
+       int err;
+
+       handler = inet_diag_lock_handler(req->sdiag_protocol);
+       if (IS_ERR(handler))
+               err = PTR_ERR(handler);
+       else
+               err = inet_diag_dump_one_icsk(handler->idiag_hashinfo,
+                               in_skb, nlh, req);
        inet_diag_unlock_handler(handler);
+
        return err;
 }
 
@@ -369,9 +370,12 @@ static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
 }
 
 
-static int inet_diag_bc_run(const void *bc, int len,
-                           const struct inet_diag_entry *entry)
+static int inet_diag_bc_run(const struct nlattr *_bc,
+               const struct inet_diag_entry *entry)
 {
+       const void *bc = nla_data(_bc);
+       int len = nla_len(_bc);
+
        while (len > 0) {
                int yes = 1;
                const struct inet_diag_bc_op *op = bc;
@@ -526,11 +530,11 @@ static int inet_csk_diag_dump(struct sock *sk,
                entry.dport = ntohs(inet->inet_dport);
                entry.userlocks = sk->sk_userlocks;
 
-               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
+               if (!inet_diag_bc_run(bc, &entry))
                        return 0;
        }
 
-       return inet_csk_diag_fill(sk, skb, r->idiag_ext,
+       return inet_csk_diag_fill(sk, skb, r,
                                  NETLINK_CB(cb->skb).pid,
                                  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
@@ -561,11 +565,11 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                entry.dport = ntohs(tw->tw_dport);
                entry.userlocks = 0;
 
-               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
+               if (!inet_diag_bc_run(bc, &entry))
                        return 0;
        }
 
-       return inet_twsk_diag_fill(tw, skb, r->idiag_ext,
+       return inet_twsk_diag_fill(tw, skb, r,
                                   NETLINK_CB(cb->skb).pid,
                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
@@ -682,8 +686,7 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
                                        &ireq->rmt_addr;
                                entry.dport = ntohs(ireq->rmt_port);
 
-                               if (!inet_diag_bc_run(nla_data(bc),
-                                                     nla_len(bc), &entry))
+                               if (!inet_diag_bc_run(bc, &entry))
                                        continue;
                        }
 
@@ -706,19 +709,11 @@ out:
        return err;
 }
 
-static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
-               struct inet_diag_req *r, struct nlattr *bc)
+static void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
+               struct netlink_callback *cb, struct inet_diag_req *r, struct nlattr *bc)
 {
        int i, num;
        int s_i, s_num;
-       const struct inet_diag_handler *handler;
-       struct inet_hashinfo *hashinfo;
-
-       handler = inet_diag_lock_handler(r->sdiag_protocol);
-       if (IS_ERR(handler))
-               goto unlock;
-
-       hashinfo = handler->idiag_hashinfo;
 
        s_i = cb->args[1];
        s_num = num = cb->args[2];
@@ -743,6 +738,10 @@ static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
                                        continue;
                                }
 
+                               if (r->sdiag_family != AF_UNSPEC &&
+                                               sk->sk_family != r->sdiag_family)
+                                       goto next_listen;
+
                                if (r->id.idiag_sport != inet->inet_sport &&
                                    r->id.idiag_sport)
                                        goto next_listen;
@@ -783,7 +782,7 @@ skip_listen_ht:
        }
 
        if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
-               goto unlock;
+               goto out;
 
        for (i = s_i; i <= hashinfo->ehash_mask; i++) {
                struct inet_ehash_bucket *head = &hashinfo->ehash[i];
@@ -808,6 +807,9 @@ skip_listen_ht:
                                goto next_normal;
                        if (!(r->idiag_states & (1 << sk->sk_state)))
                                goto next_normal;
+                       if (r->sdiag_family != AF_UNSPEC &&
+                                       sk->sk_family != r->sdiag_family)
+                               goto next_normal;
                        if (r->id.idiag_sport != inet->inet_sport &&
                            r->id.idiag_sport)
                                goto next_normal;
@@ -830,6 +832,9 @@ next_normal:
 
                                if (num < s_num)
                                        goto next_dying;
+                               if (r->sdiag_family != AF_UNSPEC &&
+                                               tw->tw_family != r->sdiag_family)
+                                       goto next_dying;
                                if (r->id.idiag_sport != tw->tw_sport &&
                                    r->id.idiag_sport)
                                        goto next_dying;
@@ -850,8 +855,20 @@ next_dying:
 done:
        cb->args[1] = i;
        cb->args[2] = num;
-unlock:
+out:
+       ;
+}
+
+static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
+               struct inet_diag_req *r, struct nlattr *bc)
+{
+       const struct inet_diag_handler *handler;
+
+       handler = inet_diag_lock_handler(r->sdiag_protocol);
+       if (!IS_ERR(handler))
+               inet_diag_dump_icsk(handler->idiag_hashinfo, skb, cb, r, bc);
        inet_diag_unlock_handler(handler);
+
        return skb->len;
 }
 
@@ -866,6 +883,18 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
        return __inet_diag_dump(skb, cb, (struct inet_diag_req *)NLMSG_DATA(cb->nlh), bc);
 }
 
+static inline int inet_diag_type2proto(int type)
+{
+       switch (type) {
+       case TCPDIAG_GETSOCK:
+               return IPPROTO_TCP;
+       case DCCPDIAG_GETSOCK:
+               return IPPROTO_DCCP;
+       default:
+               return 0;
+       }
+}
+
 static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb)
 {
        struct inet_diag_req_compat *rc = NLMSG_DATA(cb->nlh);
@@ -873,7 +902,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *c
        struct nlattr *bc = NULL;
        int hdrlen = sizeof(struct inet_diag_req_compat);
 
-       req.sdiag_family = rc->idiag_family;
+       req.sdiag_family = AF_UNSPEC; /* compatibility */
        req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
        req.idiag_ext = rc->idiag_ext;
        req.idiag_states = rc->idiag_states;
@@ -920,7 +949,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
                                return -EINVAL;
                }
 
-               return netlink_dump_start(sdiagnl, skb, nlh,
+               return netlink_dump_start(sock_diag_nlsk, skb, nlh,
                                          inet_diag_dump_compat, NULL, 0);
        }
 
@@ -945,7 +974,7 @@ static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
                                return -EINVAL;
                }
 
-               return netlink_dump_start(sdiagnl, skb, h,
+               return netlink_dump_start(sock_diag_nlsk, skb, h,
                                          inet_diag_dump, NULL, 0);
        }
 
@@ -962,91 +991,6 @@ static struct sock_diag_handler inet6_diag_handler = {
        .dump = inet_diag_handler_dump,
 };
 
-static struct sock_diag_handler *sock_diag_handlers[AF_MAX];
-static DEFINE_MUTEX(sock_diag_table_mutex);
-
-int sock_diag_register(struct sock_diag_handler *hndl)
-{
-       int err = 0;
-
-       if (hndl->family > AF_MAX)
-               return -EINVAL;
-
-       mutex_lock(&sock_diag_table_mutex);
-       if (sock_diag_handlers[hndl->family])
-               err = -EBUSY;
-       else
-               sock_diag_handlers[hndl->family] = hndl;
-       mutex_unlock(&sock_diag_table_mutex);
-
-       return err;
-}
-
-void sock_diag_unregister(struct sock_diag_handler *hnld)
-{
-       int family = hnld->family;
-
-       if (family > AF_MAX)
-               return;
-
-       mutex_lock(&sock_diag_table_mutex);
-       BUG_ON(sock_diag_handlers[family] != hnld);
-       sock_diag_handlers[family] = NULL;
-       mutex_unlock(&sock_diag_table_mutex);
-}
-
-static inline struct sock_diag_handler *sock_diag_lock_handler(int family)
-{
-       mutex_lock(&sock_diag_table_mutex);
-       return sock_diag_handlers[family];
-}
-
-static inline void sock_diag_unlock_handler(struct sock_diag_handler *h)
-{
-       mutex_unlock(&sock_diag_table_mutex);
-}
-
-static int __sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
-{
-       int err;
-       struct sock_diag_req *req = NLMSG_DATA(nlh);
-       struct sock_diag_handler *hndl;
-
-       if (nlmsg_len(nlh) < sizeof(*req))
-               return -EINVAL;
-
-       hndl = sock_diag_lock_handler(req->sdiag_family);
-       if (hndl == NULL)
-               err = -ENOENT;
-       else
-               err = hndl->dump(skb, nlh);
-       sock_diag_unlock_handler(hndl);
-
-       return err;
-}
-
-static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
-{
-       switch (nlh->nlmsg_type) {
-       case TCPDIAG_GETSOCK:
-       case DCCPDIAG_GETSOCK:
-               return inet_diag_rcv_msg_compat(skb, nlh);
-       case SOCK_DIAG_BY_FAMILY:
-               return __sock_diag_rcv_msg(skb, nlh);
-       default:
-               return -EINVAL;
-       }
-}
-
-static DEFINE_MUTEX(sock_diag_mutex);
-
-static void sock_diag_rcv(struct sk_buff *skb)
-{
-       mutex_lock(&sock_diag_mutex);
-       netlink_rcv_skb(skb, &sock_diag_rcv_msg);
-       mutex_unlock(&sock_diag_mutex);
-}
-
 int inet_diag_register(const struct inet_diag_handler *h)
 {
        const __u16 type = h->idiag_type;
@@ -1090,11 +1034,6 @@ static int __init inet_diag_init(void)
        if (!inet_diag_table)
                goto out;
 
-       sdiagnl = netlink_kernel_create(&init_net, NETLINK_SOCK_DIAG, 0,
-                                       sock_diag_rcv, NULL, THIS_MODULE);
-       if (sdiagnl == NULL)
-               goto out_free_table;
-
        err = sock_diag_register(&inet_diag_handler);
        if (err)
                goto out_free_nl;
@@ -1103,14 +1042,13 @@ static int __init inet_diag_init(void)
        if (err)
                goto out_free_inet;
 
+       sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
 out:
        return err;
 
 out_free_inet:
        sock_diag_unregister(&inet_diag_handler);
 out_free_nl:
-       netlink_kernel_release(sdiagnl);
-out_free_table:
        kfree(inet_diag_table);
        goto out;
 }
@@ -1119,11 +1057,11 @@ static void __exit inet_diag_exit(void)
 {
        sock_diag_unregister(&inet6_diag_handler);
        sock_diag_unregister(&inet_diag_handler);
-       netlink_kernel_release(sdiagnl);
+       sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
        kfree(inet_diag_table);
 }
 
 module_init(inet_diag_init);
 module_exit(inet_diag_exit);
 MODULE_LICENSE("GPL");
-MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_SOCK_DIAG);
+MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 0);
This page took 0.04568 seconds and 5 git commands to generate.