rhashtable: Disable automatic shrinking by default
[deliverable/linux.git] / net / netlink / af_netlink.c
index 05919bf3f670ed1267e01f14c1de61e78e4d80c4..4caa809dbbe0332ca32b7b0279e61289b01574e3 100644 (file)
@@ -116,6 +116,8 @@ static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 static DEFINE_SPINLOCK(netlink_tap_lock);
 static struct list_head netlink_tap_all __read_mostly;
 
+static const struct rhashtable_params netlink_rhashtable_params;
+
 static inline u32 netlink_group_mask(u32 group)
 {
        return group ? 1 << (group - 1) : 0;
@@ -970,41 +972,50 @@ netlink_unlock_table(void)
 
 struct netlink_compare_arg
 {
-       struct net *net;
+       possible_net_t pnet;
        u32 portid;
 };
 
-static bool netlink_compare(void *ptr, void *arg)
+/* Doing sizeof directly may yield 4 extra bytes on 64-bit. */
+#define netlink_compare_arg_len \
+       (offsetof(struct netlink_compare_arg, portid) + sizeof(u32))
+
+static inline int netlink_compare(struct rhashtable_compare_arg *arg,
+                                 const void *ptr)
 {
-       struct netlink_compare_arg *x = arg;
-       struct sock *sk = ptr;
+       const struct netlink_compare_arg *x = arg->key;
+       const struct netlink_sock *nlk = ptr;
 
-       return nlk_sk(sk)->portid == x->portid &&
-              net_eq(sock_net(sk), x->net);
+       return nlk->portid != x->portid ||
+              !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet));
+}
+
+static void netlink_compare_arg_init(struct netlink_compare_arg *arg,
+                                    struct net *net, u32 portid)
+{
+       memset(arg, 0, sizeof(*arg));
+       write_pnet(&arg->pnet, net);
+       arg->portid = portid;
 }
 
 static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
                                     struct net *net)
 {
-       struct netlink_compare_arg arg = {
-               .net = net,
-               .portid = portid,
-       };
+       struct netlink_compare_arg arg;
 
-       return rhashtable_lookup_compare(&table->hash, &portid,
-                                        &netlink_compare, &arg);
+       netlink_compare_arg_init(&arg, net, portid);
+       return rhashtable_lookup_fast(&table->hash, &arg,
+                                     netlink_rhashtable_params);
 }
 
-static bool __netlink_insert(struct netlink_table *table, struct sock *sk)
+static int __netlink_insert(struct netlink_table *table, struct sock *sk)
 {
-       struct netlink_compare_arg arg = {
-               .net = sock_net(sk),
-               .portid = nlk_sk(sk)->portid,
-       };
+       struct netlink_compare_arg arg;
 
-       return rhashtable_lookup_compare_insert(&table->hash,
-                                               &nlk_sk(sk)->node,
-                                               &netlink_compare, &arg);
+       netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid);
+       return rhashtable_lookup_insert_key(&table->hash, &arg,
+                                           &nlk_sk(sk)->node,
+                                           netlink_rhashtable_params);
 }
 
 static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
@@ -1066,9 +1077,10 @@ static int netlink_insert(struct sock *sk, u32 portid)
        nlk_sk(sk)->portid = portid;
        sock_hold(sk);
 
-       err = 0;
-       if (!__netlink_insert(table, sk)) {
-               err = -EADDRINUSE;
+       err = __netlink_insert(table, sk);
+       if (err) {
+               if (err == -EEXIST)
+                       err = -EADDRINUSE;
                sock_put(sk);
        }
 
@@ -1082,7 +1094,8 @@ static void netlink_remove(struct sock *sk)
        struct netlink_table *table;
 
        table = &nl_table[sk->sk_protocol];
-       if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) {
+       if (!rhashtable_remove_fast(&table->hash, &nlk_sk(sk)->node,
+                                   netlink_rhashtable_params)) {
                WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
                __sock_put(sk);
        }
@@ -2256,8 +2269,7 @@ static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb)
        put_cmsg(msg, SOL_NETLINK, NETLINK_PKTINFO, sizeof(info), &info);
 }
 
-static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
-                          struct msghdr *msg, size_t len)
+static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 {
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
@@ -2346,8 +2358,7 @@ out:
        return err;
 }
 
-static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
-                          struct msghdr *msg, size_t len,
+static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
                           int flags)
 {
        struct scm_cookie scm;
@@ -3116,17 +3127,28 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
        .exit = netlink_net_exit,
 };
 
+static inline u32 netlink_hash(const void *data, u32 seed)
+{
+       const struct netlink_sock *nlk = data;
+       struct netlink_compare_arg arg;
+
+       netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid);
+       return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed);
+}
+
+static const struct rhashtable_params netlink_rhashtable_params = {
+       .head_offset = offsetof(struct netlink_sock, node),
+       .key_len = netlink_compare_arg_len,
+       .obj_hashfn = netlink_hash,
+       .obj_cmpfn = netlink_compare,
+       .max_size = 65536,
+       .automatic_shrinking = true,
+};
+
 static int __init netlink_proto_init(void)
 {
        int i;
        int err = proto_register(&netlink_proto, 0);
-       struct rhashtable_params ht_params = {
-               .head_offset = offsetof(struct netlink_sock, node),
-               .key_offset = offsetof(struct netlink_sock, portid),
-               .key_len = sizeof(u32), /* portid */
-               .hashfn = jhash,
-               .max_shift = 16, /* 64K */
-       };
 
        if (err != 0)
                goto out;
@@ -3138,7 +3160,8 @@ static int __init netlink_proto_init(void)
                goto panic;
 
        for (i = 0; i < MAX_LINKS; i++) {
-               if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
+               if (rhashtable_init(&nl_table[i].hash,
+                                   &netlink_rhashtable_params) < 0) {
                        while (--i > 0)
                                rhashtable_destroy(&nl_table[i].hash);
                        kfree(nl_table);
This page took 0.040748 seconds and 5 git commands to generate.