Merge tag 'for-linus' of git://git.kernel.org/pub/scm/virt/kvm/kvm
[deliverable/linux.git] / net / netlink / af_netlink.c
index e6fac7e3db52e5fcb40629a60472ff2c7aa72dcb..c416725d28c49f8b0c1b10bbf35a28594c646bc5 100644 (file)
@@ -58,7 +58,9 @@
 #include <linux/mutex.h>
 #include <linux/vmalloc.h>
 #include <linux/if_arp.h>
+#include <linux/rhashtable.h>
 #include <asm/cacheflush.h>
+#include <linux/hash.h>
 
 #include <net/net_namespace.h>
 #include <net/sock.h>
@@ -100,6 +102,19 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
 
 #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
 
+/* Protects netlink socket hash table mutations */
+DEFINE_MUTEX(nl_sk_hash_lock);
+EXPORT_SYMBOL_GPL(nl_sk_hash_lock);
+
+static int lockdep_nl_sk_hash_is_held(void)
+{
+#ifdef CONFIG_LOCKDEP
+       return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
+#else
+       return 1;
+#endif
+}
+
 static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 
 static DEFINE_SPINLOCK(netlink_tap_lock);
@@ -110,11 +125,6 @@ static inline u32 netlink_group_mask(u32 group)
        return group ? 1 << (group - 1) : 0;
 }
 
-static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid)
-{
-       return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
-}
-
 int netlink_add_tap(struct netlink_tap *nt)
 {
        if (unlikely(nt->dev->type != ARPHRD_NETLINK))
@@ -170,7 +180,6 @@ EXPORT_SYMBOL_GPL(netlink_remove_tap);
 static bool netlink_filter_tap(const struct sk_buff *skb)
 {
        struct sock *sk = skb->sk;
-       bool pass = false;
 
        /* We take the more conservative approach and
         * whitelist socket protocols that may pass.
@@ -184,11 +193,10 @@ static bool netlink_filter_tap(const struct sk_buff *skb)
        case NETLINK_FIB_LOOKUP:
        case NETLINK_NETFILTER:
        case NETLINK_GENERIC:
-               pass = true;
-               break;
+               return true;
        }
 
-       return pass;
+       return false;
 }
 
 static int __netlink_deliver_tap_skb(struct sk_buff *skb,
@@ -205,7 +213,7 @@ static int __netlink_deliver_tap_skb(struct sk_buff *skb,
                nskb->protocol = htons((u16) sk->sk_protocol);
                nskb->pkt_type = netlink_is_kernel(sk) ?
                                 PACKET_KERNEL : PACKET_USER;
-
+               skb_reset_network_header(nskb);
                ret = dev_queue_xmit(nskb);
                if (unlikely(ret > 0))
                        ret = net_xmit_errno(ret);
@@ -376,7 +384,7 @@ static int netlink_set_ring(struct sock *sk, struct nl_mmap_req *req,
 
                if ((int)req->nm_block_size <= 0)
                        return -EINVAL;
-               if (!IS_ALIGNED(req->nm_block_size, PAGE_SIZE))
+               if (!PAGE_ALIGNED(req->nm_block_size))
                        return -EINVAL;
                if (req->nm_frame_size < NL_MMAP_HDRLEN)
                        return -EINVAL;
@@ -985,105 +993,48 @@ netlink_unlock_table(void)
                wake_up(&nl_table_wait);
 }
 
-static bool netlink_compare(struct net *net, struct sock *sk)
+struct netlink_compare_arg
 {
-       return net_eq(sock_net(sk), net);
-}
-
-static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
-{
-       struct netlink_table *table = &nl_table[protocol];
-       struct nl_portid_hash *hash = &table->hash;
-       struct hlist_head *head;
-       struct sock *sk;
-
-       read_lock(&nl_table_lock);
-       head = nl_portid_hashfn(hash, portid);
-       sk_for_each(sk, head) {
-               if (table->compare(net, sk) &&
-                   (nlk_sk(sk)->portid == portid)) {
-                       sock_hold(sk);
-                       goto found;
-               }
-       }
-       sk = NULL;
-found:
-       read_unlock(&nl_table_lock);
-       return sk;
-}
+       struct net *net;
+       u32 portid;
+};
 
-static struct hlist_head *nl_portid_hash_zalloc(size_t size)
+static bool netlink_compare(void *ptr, void *arg)
 {
-       if (size <= PAGE_SIZE)
-               return kzalloc(size, GFP_ATOMIC);
-       else
-               return (struct hlist_head *)
-                       __get_free_pages(GFP_ATOMIC | __GFP_ZERO,
-                                        get_order(size));
-}
+       struct netlink_compare_arg *x = arg;
+       struct sock *sk = ptr;
 
-static void nl_portid_hash_free(struct hlist_head *table, size_t size)
-{
-       if (size <= PAGE_SIZE)
-               kfree(table);
-       else
-               free_pages((unsigned long)table, get_order(size));
+       return nlk_sk(sk)->portid == x->portid &&
+              net_eq(sock_net(sk), x->net);
 }
 
-static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow)
+static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
+                                    struct net *net)
 {
-       unsigned int omask, mask, shift;
-       size_t osize, size;
-       struct hlist_head *otable, *table;
-       int i;
-
-       omask = mask = hash->mask;
-       osize = size = (mask + 1) * sizeof(*table);
-       shift = hash->shift;
-
-       if (grow) {
-               if (++shift > hash->max_shift)
-                       return 0;
-               mask = mask * 2 + 1;
-               size *= 2;
-       }
-
-       table = nl_portid_hash_zalloc(size);
-       if (!table)
-               return 0;
-
-       otable = hash->table;
-       hash->table = table;
-       hash->mask = mask;
-       hash->shift = shift;
-       get_random_bytes(&hash->rnd, sizeof(hash->rnd));
+       struct netlink_compare_arg arg = {
+               .net = net,
+               .portid = portid,
+       };
+       u32 hash;
 
-       for (i = 0; i <= omask; i++) {
-               struct sock *sk;
-               struct hlist_node *tmp;
-
-               sk_for_each_safe(sk, tmp, &otable[i])
-                       __sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid));
-       }
+       hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
 
-       nl_portid_hash_free(otable, osize);
-       hash->rehash_time = jiffies + 10 * 60 * HZ;
-       return 1;
+       return rhashtable_lookup_compare(&table->hash, hash,
+                                        &netlink_compare, &arg);
 }
 
-static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len)
+static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
 {
-       int avg = hash->entries >> hash->shift;
-
-       if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1))
-               return 1;
+       struct netlink_table *table = &nl_table[protocol];
+       struct sock *sk;
 
-       if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) {
-               nl_portid_hash_rehash(hash, 0);
-               return 1;
-       }
+       rcu_read_lock();
+       sk = __netlink_lookup(table, portid, net);
+       if (sk)
+               sock_hold(sk);
+       rcu_read_unlock();
 
-       return 0;
+       return sk;
 }
 
 static const struct proto_ops netlink_ops;
@@ -1115,22 +1066,10 @@ netlink_update_listeners(struct sock *sk)
 static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
 {
        struct netlink_table *table = &nl_table[sk->sk_protocol];
-       struct nl_portid_hash *hash = &table->hash;
-       struct hlist_head *head;
        int err = -EADDRINUSE;
-       struct sock *osk;
-       int len;
 
-       netlink_table_grab();
-       head = nl_portid_hashfn(hash, portid);
-       len = 0;
-       sk_for_each(osk, head) {
-               if (table->compare(net, osk) &&
-                   (nlk_sk(osk)->portid == portid))
-                       break;
-               len++;
-       }
-       if (osk)
+       mutex_lock(&nl_sk_hash_lock);
+       if (__netlink_lookup(table, portid, net))
                goto err;
 
        err = -EBUSY;
@@ -1138,26 +1077,31 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
                goto err;
 
        err = -ENOMEM;
-       if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX))
+       if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
                goto err;
 
-       if (len && nl_portid_hash_dilute(hash, len))
-               head = nl_portid_hashfn(hash, portid);
-       hash->entries++;
        nlk_sk(sk)->portid = portid;
-       sk_add_node(sk, head);
+       sock_hold(sk);
+       rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
        err = 0;
-
 err:
-       netlink_table_ungrab();
+       mutex_unlock(&nl_sk_hash_lock);
        return err;
 }
 
 static void netlink_remove(struct sock *sk)
 {
+       struct netlink_table *table;
+
+       mutex_lock(&nl_sk_hash_lock);
+       table = &nl_table[sk->sk_protocol];
+       if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
+               WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
+               __sock_put(sk);
+       }
+       mutex_unlock(&nl_sk_hash_lock);
+
        netlink_table_grab();
-       if (sk_del_node_init(sk))
-               nl_table[sk->sk_protocol].hash.entries--;
        if (nlk_sk(sk)->subscriptions)
                __sk_del_bind_node(sk);
        netlink_table_ungrab();
@@ -1313,6 +1257,9 @@ static int netlink_release(struct socket *sock)
        }
        netlink_table_ungrab();
 
+       /* Wait for readers to complete */
+       synchronize_net();
+
        kfree(nlk->groups);
        nlk->groups = NULL;
 
@@ -1328,30 +1275,22 @@ static int netlink_autobind(struct socket *sock)
        struct sock *sk = sock->sk;
        struct net *net = sock_net(sk);
        struct netlink_table *table = &nl_table[sk->sk_protocol];
-       struct nl_portid_hash *hash = &table->hash;
-       struct hlist_head *head;
-       struct sock *osk;
        s32 portid = task_tgid_vnr(current);
        int err;
        static s32 rover = -4097;
 
 retry:
        cond_resched();
-       netlink_table_grab();
-       head = nl_portid_hashfn(hash, portid);
-       sk_for_each(osk, head) {
-               if (!table->compare(net, osk))
-                       continue;
-               if (nlk_sk(osk)->portid == portid) {
-                       /* Bind collision, search negative portid values. */
-                       portid = rover--;
-                       if (rover > -4097)
-                               rover = -4097;
-                       netlink_table_ungrab();
-                       goto retry;
-               }
+       rcu_read_lock();
+       if (__netlink_lookup(table, portid, net)) {
+               /* Bind collision, search negative portid values. */
+               portid = rover--;
+               if (rover > -4097)
+                       rover = -4097;
+               rcu_read_unlock();
+               goto retry;
        }
-       netlink_table_ungrab();
+       rcu_read_unlock();
 
        err = netlink_insert(sk, net, portid);
        if (err == -EADDRINUSE)
@@ -1961,25 +1900,25 @@ struct netlink_broadcast_data {
        void *tx_data;
 };
 
-static int do_one_broadcast(struct sock *sk,
-                                  struct netlink_broadcast_data *p)
+static void do_one_broadcast(struct sock *sk,
+                                   struct netlink_broadcast_data *p)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
        int val;
 
        if (p->exclude_sk == sk)
-               goto out;
+               return;
 
        if (nlk->portid == p->portid || p->group - 1 >= nlk->ngroups ||
            !test_bit(p->group - 1, nlk->groups))
-               goto out;
+               return;
 
        if (!net_eq(sock_net(sk), p->net))
-               goto out;
+               return;
 
        if (p->failure) {
                netlink_overrun(sk);
-               goto out;
+               return;
        }
 
        sock_hold(sk);
@@ -2017,9 +1956,6 @@ static int do_one_broadcast(struct sock *sk,
                p->skb2 = NULL;
        }
        sock_put(sk);
-
-out:
-       return 0;
 }
 
 int netlink_broadcast_filtered(struct sock *ssk, struct sk_buff *skb, u32 portid,
@@ -2958,14 +2894,18 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
 {
        struct nl_seq_iter *iter = seq->private;
        int i, j;
+       struct netlink_sock *nlk;
        struct sock *s;
        loff_t off = 0;
 
        for (i = 0; i < MAX_LINKS; i++) {
-               struct nl_portid_hash *hash = &nl_table[i].hash;
+               struct rhashtable *ht = &nl_table[i].hash;
+               const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
+
+               for (j = 0; j < tbl->size; j++) {
+                       rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+                               s = (struct sock *)nlk;
 
-               for (j = 0; j <= hash->mask; j++) {
-                       sk_for_each(s, &hash->table[j]) {
                                if (sock_net(s) != seq_file_net(seq))
                                        continue;
                                if (off == pos) {
@@ -2981,15 +2921,15 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
 }
 
 static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
-       __acquires(nl_table_lock)
+       __acquires(RCU)
 {
-       read_lock(&nl_table_lock);
+       rcu_read_lock();
        return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
 }
 
 static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 {
-       struct sock *s;
+       struct netlink_sock *nlk;
        struct nl_seq_iter *iter;
        struct net *net;
        int i, j;
@@ -3001,28 +2941,26 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 
        net = seq_file_net(seq);
        iter = seq->private;
-       s = v;
-       do {
-               s = sk_next(s);
-       } while (s && !nl_table[s->sk_protocol].compare(net, s));
-       if (s)
-               return s;
+       nlk = v;
+
+       rht_for_each_entry_rcu(nlk, nlk->node.next, node)
+               if (net_eq(sock_net((struct sock *)nlk), net))
+                       return nlk;
 
        i = iter->link;
        j = iter->hash_idx + 1;
 
        do {
-               struct nl_portid_hash *hash = &nl_table[i].hash;
-
-               for (; j <= hash->mask; j++) {
-                       s = sk_head(&hash->table[j]);
+               struct rhashtable *ht = &nl_table[i].hash;
+               const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
 
-                       while (s && !nl_table[s->sk_protocol].compare(net, s))
-                               s = sk_next(s);
-                       if (s) {
-                               iter->link = i;
-                               iter->hash_idx = j;
-                               return s;
+               for (; j < tbl->size; j++) {
+                       rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
+                               if (net_eq(sock_net((struct sock *)nlk), net)) {
+                                       iter->link = i;
+                                       iter->hash_idx = j;
+                                       return nlk;
+                               }
                        }
                }
 
@@ -3033,9 +2971,9 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 }
 
 static void netlink_seq_stop(struct seq_file *seq, void *v)
-       __releases(nl_table_lock)
+       __releases(RCU)
 {
-       read_unlock(&nl_table_lock);
+       rcu_read_unlock();
 }
 
 
@@ -3173,9 +3111,17 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
 static int __init netlink_proto_init(void)
 {
        int i;
-       unsigned long limit;
-       unsigned int order;
        int err = proto_register(&netlink_proto, 0);
+       struct rhashtable_params ht_params = {
+               .head_offset = offsetof(struct netlink_sock, node),
+               .key_offset = offsetof(struct netlink_sock, portid),
+               .key_len = sizeof(u32), /* portid */
+               .hashfn = arch_fast_hash,
+               .max_shift = 16, /* 64K */
+               .grow_decision = rht_grow_above_75,
+               .shrink_decision = rht_shrink_below_30,
+               .mutex_is_held = lockdep_nl_sk_hash_is_held,
+       };
 
        if (err != 0)
                goto out;
@@ -3186,32 +3132,13 @@ static int __init netlink_proto_init(void)
        if (!nl_table)
                goto panic;
 
-       if (totalram_pages >= (128 * 1024))
-               limit = totalram_pages >> (21 - PAGE_SHIFT);
-       else
-               limit = totalram_pages >> (23 - PAGE_SHIFT);
-
-       order = get_bitmask_order(limit) - 1 + PAGE_SHIFT;
-       limit = (1UL << order) / sizeof(struct hlist_head);
-       order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1;
-
        for (i = 0; i < MAX_LINKS; i++) {
-               struct nl_portid_hash *hash = &nl_table[i].hash;
-
-               hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table));
-               if (!hash->table) {
-                       while (i-- > 0)
-                               nl_portid_hash_free(nl_table[i].hash.table,
-                                                1 * sizeof(*hash->table));
+               if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
+                       while (--i > 0)
+                               rhashtable_destroy(&nl_table[i].hash);
                        kfree(nl_table);
                        goto panic;
                }
-               hash->max_shift = order;
-               hash->shift = 0;
-               hash->mask = 0;
-               hash->rehash_time = jiffies;
-
-               nl_table[i].compare = netlink_compare;
        }
 
        INIT_LIST_HEAD(&netlink_tap_all);
This page took 0.039436 seconds and 5 git commands to generate.