netlink: implement nla_get_in_addr and nla_get_in6_addr
[deliverable/linux.git] / net / ipv4 / tcp_metrics.c
index e5f41bd5ec1bcfe88199ec077f1558917b1be61b..71ec14c87579337a2cb9dc4752ebe37c189e1e20 100644 (file)
@@ -40,6 +40,7 @@ struct tcp_fastopen_metrics {
 
 struct tcp_metrics_block {
        struct tcp_metrics_block __rcu  *tcpm_next;
+       possible_net_t                  tcpm_net;
        struct inetpeer_addr            tcpm_saddr;
        struct inetpeer_addr            tcpm_daddr;
        unsigned long                   tcpm_stamp;
@@ -52,6 +53,11 @@ struct tcp_metrics_block {
        struct rcu_head                 rcu_head;
 };
 
+static inline struct net *tm_net(struct tcp_metrics_block *tm)
+{
+       return read_pnet(&tm->tcpm_net);
+}
+
 static bool tcp_metric_locked(struct tcp_metrics_block *tm,
                              enum tcp_metric_index idx)
 {
@@ -74,23 +80,20 @@ static void tcp_metric_set(struct tcp_metrics_block *tm,
 static bool addr_same(const struct inetpeer_addr *a,
                      const struct inetpeer_addr *b)
 {
-       const struct in6_addr *a6, *b6;
-
        if (a->family != b->family)
                return false;
        if (a->family == AF_INET)
                return a->addr.a4 == b->addr.a4;
-
-       a6 = (const struct in6_addr *) &a->addr.a6[0];
-       b6 = (const struct in6_addr *) &b->addr.a6[0];
-
-       return ipv6_addr_equal(a6, b6);
+       return ipv6_addr_equal(&a->addr.in6, &b->addr.in6);
 }
 
 struct tcpm_hash_bucket {
        struct tcp_metrics_block __rcu  *chain;
 };
 
+static struct tcpm_hash_bucket *tcp_metrics_hash __read_mostly;
+static unsigned int            tcp_metrics_hash_log __read_mostly;
+
 static DEFINE_SPINLOCK(tcp_metrics_lock);
 
 static void tcpm_suck_dst(struct tcp_metrics_block *tm,
@@ -143,6 +146,9 @@ static void tcpm_check_stamp(struct tcp_metrics_block *tm, struct dst_entry *dst
 #define TCP_METRICS_RECLAIM_DEPTH      5
 #define TCP_METRICS_RECLAIM_PTR                (struct tcp_metrics_block *) 0x1UL
 
+#define deref_locked(p)        \
+       rcu_dereference_protected(p, lockdep_is_held(&tcp_metrics_lock))
+
 static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst,
                                          struct inetpeer_addr *saddr,
                                          struct inetpeer_addr *daddr,
@@ -171,9 +177,9 @@ static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst,
        if (unlikely(reclaim)) {
                struct tcp_metrics_block *oldest;
 
-               oldest = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain);
-               for (tm = rcu_dereference(oldest->tcpm_next); tm;
-                    tm = rcu_dereference(tm->tcpm_next)) {
+               oldest = deref_locked(tcp_metrics_hash[hash].chain);
+               for (tm = deref_locked(oldest->tcpm_next); tm;
+                    tm = deref_locked(tm->tcpm_next)) {
                        if (time_before(tm->tcpm_stamp, oldest->tcpm_stamp))
                                oldest = tm;
                }
@@ -183,14 +189,15 @@ static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst,
                if (!tm)
                        goto out_unlock;
        }
+       write_pnet(&tm->tcpm_net, net);
        tm->tcpm_saddr = *saddr;
        tm->tcpm_daddr = *daddr;
 
        tcpm_suck_dst(tm, dst, true);
 
        if (likely(!reclaim)) {
-               tm->tcpm_next = net->ipv4.tcp_metrics_hash[hash].chain;
-               rcu_assign_pointer(net->ipv4.tcp_metrics_hash[hash].chain, tm);
+               tm->tcpm_next = tcp_metrics_hash[hash].chain;
+               rcu_assign_pointer(tcp_metrics_hash[hash].chain, tm);
        }
 
 out_unlock:
@@ -214,10 +221,11 @@ static struct tcp_metrics_block *__tcp_get_metrics(const struct inetpeer_addr *s
        struct tcp_metrics_block *tm;
        int depth = 0;
 
-       for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
+       for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
                if (addr_same(&tm->tcpm_saddr, saddr) &&
-                   addr_same(&tm->tcpm_daddr, daddr))
+                   addr_same(&tm->tcpm_daddr, daddr) &&
+                   net_eq(tm_net(tm), net))
                        break;
                depth++;
        }
@@ -242,8 +250,8 @@ static struct tcp_metrics_block *__tcp_get_metrics_req(struct request_sock *req,
                break;
 #if IS_ENABLED(CONFIG_IPV6)
        case AF_INET6:
-               *(struct in6_addr *)saddr.addr.a6 = inet_rsk(req)->ir_v6_loc_addr;
-               *(struct in6_addr *)daddr.addr.a6 = inet_rsk(req)->ir_v6_rmt_addr;
+               saddr.addr.in6 = inet_rsk(req)->ir_v6_loc_addr;
+               daddr.addr.in6 = inet_rsk(req)->ir_v6_rmt_addr;
                hash = ipv6_addr_hash(&inet_rsk(req)->ir_v6_rmt_addr);
                break;
 #endif
@@ -252,12 +260,14 @@ static struct tcp_metrics_block *__tcp_get_metrics_req(struct request_sock *req,
        }
 
        net = dev_net(dst->dev);
-       hash = hash_32(hash, net->ipv4.tcp_metrics_hash_log);
+       hash ^= net_hash_mix(net);
+       hash = hash_32(hash, tcp_metrics_hash_log);
 
-       for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
+       for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
                if (addr_same(&tm->tcpm_saddr, &saddr) &&
-                   addr_same(&tm->tcpm_daddr, &daddr))
+                   addr_same(&tm->tcpm_daddr, &daddr) &&
+                   net_eq(tm_net(tm), net))
                        break;
        }
        tcpm_check_stamp(tm, dst);
@@ -288,9 +298,9 @@ static struct tcp_metrics_block *__tcp_get_metrics_tw(struct inet_timewait_sock
                        hash = (__force unsigned int) daddr.addr.a4;
                } else {
                        saddr.family = AF_INET6;
-                       *(struct in6_addr *)saddr.addr.a6 = tw->tw_v6_rcv_saddr;
+                       saddr.addr.in6 = tw->tw_v6_rcv_saddr;
                        daddr.family = AF_INET6;
-                       *(struct in6_addr *)daddr.addr.a6 = tw->tw_v6_daddr;
+                       daddr.addr.in6 = tw->tw_v6_daddr;
                        hash = ipv6_addr_hash(&tw->tw_v6_daddr);
                }
        }
@@ -299,12 +309,14 @@ static struct tcp_metrics_block *__tcp_get_metrics_tw(struct inet_timewait_sock
                return NULL;
 
        net = twsk_net(tw);
-       hash = hash_32(hash, net->ipv4.tcp_metrics_hash_log);
+       hash ^= net_hash_mix(net);
+       hash = hash_32(hash, tcp_metrics_hash_log);
 
-       for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
+       for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
                if (addr_same(&tm->tcpm_saddr, &saddr) &&
-                   addr_same(&tm->tcpm_daddr, &daddr))
+                   addr_same(&tm->tcpm_daddr, &daddr) &&
+                   net_eq(tm_net(tm), net))
                        break;
        }
        return tm;
@@ -336,9 +348,9 @@ static struct tcp_metrics_block *tcp_get_metrics(struct sock *sk,
                        hash = (__force unsigned int) daddr.addr.a4;
                } else {
                        saddr.family = AF_INET6;
-                       *(struct in6_addr *)saddr.addr.a6 = sk->sk_v6_rcv_saddr;
+                       saddr.addr.in6 = sk->sk_v6_rcv_saddr;
                        daddr.family = AF_INET6;
-                       *(struct in6_addr *)daddr.addr.a6 = sk->sk_v6_daddr;
+                       daddr.addr.in6 = sk->sk_v6_daddr;
                        hash = ipv6_addr_hash(&sk->sk_v6_daddr);
                }
        }
@@ -347,7 +359,8 @@ static struct tcp_metrics_block *tcp_get_metrics(struct sock *sk,
                return NULL;
 
        net = dev_net(dst->dev);
-       hash = hash_32(hash, net->ipv4.tcp_metrics_hash_log);
+       hash ^= net_hash_mix(net);
+       hash = hash_32(hash, tcp_metrics_hash_log);
 
        tm = __tcp_get_metrics(&saddr, &daddr, net, hash);
        if (tm == TCP_METRICS_RECLAIM_PTR)
@@ -773,19 +786,19 @@ static int tcp_metrics_fill_info(struct sk_buff *msg,
 
        switch (tm->tcpm_daddr.family) {
        case AF_INET:
-               if (nla_put_be32(msg, TCP_METRICS_ATTR_ADDR_IPV4,
-                               tm->tcpm_daddr.addr.a4) < 0)
+               if (nla_put_in_addr(msg, TCP_METRICS_ATTR_ADDR_IPV4,
+                                   tm->tcpm_daddr.addr.a4) < 0)
                        goto nla_put_failure;
-               if (nla_put_be32(msg, TCP_METRICS_ATTR_SADDR_IPV4,
-                               tm->tcpm_saddr.addr.a4) < 0)
+               if (nla_put_in_addr(msg, TCP_METRICS_ATTR_SADDR_IPV4,
+                                   tm->tcpm_saddr.addr.a4) < 0)
                        goto nla_put_failure;
                break;
        case AF_INET6:
-               if (nla_put(msg, TCP_METRICS_ATTR_ADDR_IPV6, 16,
-                           tm->tcpm_daddr.addr.a6) < 0)
+               if (nla_put_in6_addr(msg, TCP_METRICS_ATTR_ADDR_IPV6,
+                                    &tm->tcpm_daddr.addr.in6) < 0)
                        goto nla_put_failure;
-               if (nla_put(msg, TCP_METRICS_ATTR_SADDR_IPV6, 16,
-                           tm->tcpm_saddr.addr.a6) < 0)
+               if (nla_put_in6_addr(msg, TCP_METRICS_ATTR_SADDR_IPV6,
+                                    &tm->tcpm_saddr.addr.in6) < 0)
                        goto nla_put_failure;
                break;
        default:
@@ -898,17 +911,19 @@ static int tcp_metrics_nl_dump(struct sk_buff *skb,
                               struct netlink_callback *cb)
 {
        struct net *net = sock_net(skb->sk);
-       unsigned int max_rows = 1U << net->ipv4.tcp_metrics_hash_log;
+       unsigned int max_rows = 1U << tcp_metrics_hash_log;
        unsigned int row, s_row = cb->args[0];
        int s_col = cb->args[1], col = s_col;
 
        for (row = s_row; row < max_rows; row++, s_col = 0) {
                struct tcp_metrics_block *tm;
-               struct tcpm_hash_bucket *hb = net->ipv4.tcp_metrics_hash + row;
+               struct tcpm_hash_bucket *hb = tcp_metrics_hash + row;
 
                rcu_read_lock();
                for (col = 0, tm = rcu_dereference(hb->chain); tm;
                     tm = rcu_dereference(tm->tcpm_next), col++) {
+                       if (!net_eq(tm_net(tm), net))
+                               continue;
                        if (col < s_col)
                                continue;
                        if (tcp_metrics_dump_info(skb, cb, tm) < 0) {
@@ -933,7 +948,7 @@ static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
        a = info->attrs[v4];
        if (a) {
                addr->family = AF_INET;
-               addr->addr.a4 = nla_get_be32(a);
+               addr->addr.a4 = nla_get_in_addr(a);
                if (hash)
                        *hash = (__force unsigned int) addr->addr.a4;
                return 0;
@@ -943,9 +958,9 @@ static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
                if (nla_len(a) != sizeof(struct in6_addr))
                        return -EINVAL;
                addr->family = AF_INET6;
-               memcpy(addr->addr.a6, nla_data(a), sizeof(addr->addr.a6));
+               addr->addr.in6 = nla_get_in6_addr(a);
                if (hash)
-                       *hash = ipv6_addr_hash((struct in6_addr *) addr->addr.a6);
+                       *hash = ipv6_addr_hash(&addr->addr.in6);
                return 0;
        }
        return optional ? 1 : -EAFNOSUPPORT;
@@ -994,13 +1009,15 @@ static int tcp_metrics_nl_cmd_get(struct sk_buff *skb, struct genl_info *info)
        if (!reply)
                goto nla_put_failure;
 
-       hash = hash_32(hash, net->ipv4.tcp_metrics_hash_log);
+       hash ^= net_hash_mix(net);
+       hash = hash_32(hash, tcp_metrics_hash_log);
        ret = -ESRCH;
        rcu_read_lock();
-       for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
+       for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
                if (addr_same(&tm->tcpm_daddr, &daddr) &&
-                   (!src || addr_same(&tm->tcpm_saddr, &saddr))) {
+                   (!src || addr_same(&tm->tcpm_saddr, &saddr)) &&
+                   net_eq(tm_net(tm), net)) {
                        ret = tcp_metrics_fill_info(msg, tm);
                        break;
                }
@@ -1020,34 +1037,27 @@ out_free:
        return ret;
 }
 
-#define deref_locked_genl(p)   \
-       rcu_dereference_protected(p, lockdep_genl_is_held() && \
-                                    lockdep_is_held(&tcp_metrics_lock))
-
-#define deref_genl(p)  rcu_dereference_protected(p, lockdep_genl_is_held())
-
-static int tcp_metrics_flush_all(struct net *net)
+static void tcp_metrics_flush_all(struct net *net)
 {
-       unsigned int max_rows = 1U << net->ipv4.tcp_metrics_hash_log;
-       struct tcpm_hash_bucket *hb = net->ipv4.tcp_metrics_hash;
+       unsigned int max_rows = 1U << tcp_metrics_hash_log;
+       struct tcpm_hash_bucket *hb = tcp_metrics_hash;
        struct tcp_metrics_block *tm;
        unsigned int row;
 
        for (row = 0; row < max_rows; row++, hb++) {
+               struct tcp_metrics_block __rcu **pp;
                spin_lock_bh(&tcp_metrics_lock);
-               tm = deref_locked_genl(hb->chain);
-               if (tm)
-                       hb->chain = NULL;
-               spin_unlock_bh(&tcp_metrics_lock);
-               while (tm) {
-                       struct tcp_metrics_block *next;
-
-                       next = deref_genl(tm->tcpm_next);
-                       kfree_rcu(tm, rcu_head);
-                       tm = next;
+               pp = &hb->chain;
+               for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) {
+                       if (net_eq(tm_net(tm), net)) {
+                               *pp = tm->tcpm_next;
+                               kfree_rcu(tm, rcu_head);
+                       } else {
+                               pp = &tm->tcpm_next;
+                       }
                }
+               spin_unlock_bh(&tcp_metrics_lock);
        }
-       return 0;
 }
 
 static int tcp_metrics_nl_cmd_del(struct sk_buff *skb, struct genl_info *info)
@@ -1064,19 +1074,23 @@ static int tcp_metrics_nl_cmd_del(struct sk_buff *skb, struct genl_info *info)
        ret = parse_nl_addr(info, &daddr, &hash, 1);
        if (ret < 0)
                return ret;
-       if (ret > 0)
-               return tcp_metrics_flush_all(net);
+       if (ret > 0) {
+               tcp_metrics_flush_all(net);
+               return 0;
+       }
        ret = parse_nl_saddr(info, &saddr);
        if (ret < 0)
                src = false;
 
-       hash = hash_32(hash, net->ipv4.tcp_metrics_hash_log);
-       hb = net->ipv4.tcp_metrics_hash + hash;
+       hash ^= net_hash_mix(net);
+       hash = hash_32(hash, tcp_metrics_hash_log);
+       hb = tcp_metrics_hash + hash;
        pp = &hb->chain;
        spin_lock_bh(&tcp_metrics_lock);
-       for (tm = deref_locked_genl(*pp); tm; tm = deref_locked_genl(*pp)) {
+       for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) {
                if (addr_same(&tm->tcpm_daddr, &daddr) &&
-                   (!src || addr_same(&tm->tcpm_saddr, &saddr))) {
+                   (!src || addr_same(&tm->tcpm_saddr, &saddr)) &&
+                   net_eq(tm_net(tm), net)) {
                        *pp = tm->tcpm_next;
                        kfree_rcu(tm, rcu_head);
                        found = true;
@@ -1126,6 +1140,9 @@ static int __net_init tcp_net_metrics_init(struct net *net)
        size_t size;
        unsigned int slots;
 
+       if (!net_eq(net, &init_net))
+               return 0;
+
        slots = tcpmhash_entries;
        if (!slots) {
                if (totalram_pages >= 128 * 1024)
@@ -1134,14 +1151,14 @@ static int __net_init tcp_net_metrics_init(struct net *net)
                        slots = 8 * 1024;
        }
 
-       net->ipv4.tcp_metrics_hash_log = order_base_2(slots);
-       size = sizeof(struct tcpm_hash_bucket) << net->ipv4.tcp_metrics_hash_log;
+       tcp_metrics_hash_log = order_base_2(slots);
+       size = sizeof(struct tcpm_hash_bucket) << tcp_metrics_hash_log;
 
-       net->ipv4.tcp_metrics_hash = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
-       if (!net->ipv4.tcp_metrics_hash)
-               net->ipv4.tcp_metrics_hash = vzalloc(size);
+       tcp_metrics_hash = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
+       if (!tcp_metrics_hash)
+               tcp_metrics_hash = vzalloc(size);
 
-       if (!net->ipv4.tcp_metrics_hash)
+       if (!tcp_metrics_hash)
                return -ENOMEM;
 
        return 0;
@@ -1149,19 +1166,7 @@ static int __net_init tcp_net_metrics_init(struct net *net)
 
 static void __net_exit tcp_net_metrics_exit(struct net *net)
 {
-       unsigned int i;
-
-       for (i = 0; i < (1U << net->ipv4.tcp_metrics_hash_log) ; i++) {
-               struct tcp_metrics_block *tm, *next;
-
-               tm = rcu_dereference_protected(net->ipv4.tcp_metrics_hash[i].chain, 1);
-               while (tm) {
-                       next = rcu_dereference_protected(tm->tcpm_next, 1);
-                       kfree(tm);
-                       tm = next;
-               }
-       }
-       kvfree(net->ipv4.tcp_metrics_hash);
+       tcp_metrics_flush_all(net);
 }
 
 static __net_initdata struct pernet_operations tcp_net_metrics_ops = {
@@ -1175,16 +1180,10 @@ void __init tcp_metrics_init(void)
 
        ret = register_pernet_subsys(&tcp_net_metrics_ops);
        if (ret < 0)
-               goto cleanup;
+               panic("Could not allocate the tcp_metrics hash table\n");
+
        ret = genl_register_family_with_ops(&tcp_metrics_nl_family,
                                            tcp_metrics_nl_ops);
        if (ret < 0)
-               goto cleanup_subsys;
-       return;
-
-cleanup_subsys:
-       unregister_pernet_subsys(&tcp_net_metrics_ops);
-
-cleanup:
-       return;
+               panic("Could not register tcp_metrics generic netlink\n");
 }
This page took 0.041933 seconds and 5 git commands to generate.