net/ipv4: Move && and || to end of previous line
[deliverable/linux.git] / net / ipv4 / udp.c
index d73e9170536be12bee0efc3bdd3aad4150c448cf..1f9534846ca9a3b457e00266cd84ba7b2cb42450 100644 (file)
@@ -136,12 +136,12 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
        struct hlist_nulls_node *node;
 
        sk_nulls_for_each(sk2, node, &hslot->head)
-               if (net_eq(sock_net(sk2), net)                  &&
-                   sk2 != sk                                   &&
+               if (net_eq(sock_net(sk2), net) &&
+                   sk2 != sk &&
                    (bitmap || udp_sk(sk2)->udp_port_hash == num) &&
-                   (!sk2->sk_reuse || !sk->sk_reuse)           &&
-                   (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
-                       || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
+                   (!sk2->sk_reuse || !sk->sk_reuse) &&
+                   (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
+                    sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
                    (*saddr_comp)(sk, sk2)) {
                        if (bitmap)
                                __set_bit(udp_sk(sk2)->udp_port_hash >> log,
@@ -152,16 +152,49 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
        return 0;
 }
 
+/*
+ * Note: we still hold spinlock of primary hash chain, so no other writer
+ * can insert/delete a socket with local_port == num
+ */
+static int udp_lib_lport_inuse2(struct net *net, __u16 num,
+                              struct udp_hslot *hslot2,
+                              struct sock *sk,
+                              int (*saddr_comp)(const struct sock *sk1,
+                                                const struct sock *sk2))
+{
+       struct sock *sk2;
+       struct hlist_nulls_node *node;
+       int res = 0;
+
+       spin_lock(&hslot2->lock);
+       udp_portaddr_for_each_entry(sk2, node, &hslot2->head)
+               if (net_eq(sock_net(sk2), net) &&
+                   sk2 != sk &&
+                   (udp_sk(sk2)->udp_port_hash == num) &&
+                   (!sk2->sk_reuse || !sk->sk_reuse) &&
+                   (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
+                    sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
+                   (*saddr_comp)(sk, sk2)) {
+                       res = 1;
+                       break;
+               }
+       spin_unlock(&hslot2->lock);
+       return res;
+}
+
 /**
  *  udp_lib_get_port  -  UDP/-Lite port lookup for IPv4 and IPv6
  *
  *  @sk:          socket struct in question
  *  @snum:        port number to look up
  *  @saddr_comp:  AF-dependent comparison of bound local IP addresses
+ *  @hash2_nulladdr: AF-dependant hash value in secondary hash chains,
+ *                   with NULL address
  */
 int udp_lib_get_port(struct sock *sk, unsigned short snum,
                       int (*saddr_comp)(const struct sock *sk1,
-                                        const struct sock *sk2))
+                                        const struct sock *sk2),
+                    unsigned int hash2_nulladdr)
 {
        struct udp_hslot *hslot, *hslot2;
        struct udp_table *udptable = sk->sk_prot->h.udp_table;
@@ -210,6 +243,30 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
        } else {
                hslot = udp_hashslot(udptable, net, snum);
                spin_lock_bh(&hslot->lock);
+               if (hslot->count > 10) {
+                       int exist;
+                       unsigned int slot2 = udp_sk(sk)->udp_portaddr_hash ^ snum;
+
+                       slot2          &= udptable->mask;
+                       hash2_nulladdr &= udptable->mask;
+
+                       hslot2 = udp_hashslot2(udptable, slot2);
+                       if (hslot->count < hslot2->count)
+                               goto scan_primary_hash;
+
+                       exist = udp_lib_lport_inuse2(net, snum, hslot2,
+                                                    sk, saddr_comp);
+                       if (!exist && (hash2_nulladdr != slot2)) {
+                               hslot2 = udp_hashslot2(udptable, hash2_nulladdr);
+                               exist = udp_lib_lport_inuse2(net, snum, hslot2,
+                                                            sk, saddr_comp);
+                       }
+                       if (exist)
+                               goto fail_unlock;
+                       else
+                               goto found;
+               }
+scan_primary_hash:
                if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk,
                                        saddr_comp, 0))
                        goto fail_unlock;
@@ -255,12 +312,14 @@ static unsigned int udp4_portaddr_hash(struct net *net, __be32 saddr,
 
 int udp_v4_get_port(struct sock *sk, unsigned short snum)
 {
+       unsigned int hash2_nulladdr =
+               udp4_portaddr_hash(sock_net(sk), INADDR_ANY, snum);
+       unsigned int hash2_partial =
+               udp4_portaddr_hash(sock_net(sk), inet_sk(sk)->inet_rcv_saddr, 0);
+
        /* precompute partial secondary hash */
-       udp_sk(sk)->udp_portaddr_hash =
-               udp4_portaddr_hash(sock_net(sk),
-                                  inet_sk(sk)->inet_rcv_saddr,
-                                  0);
-       return udp_lib_get_port(sk, snum, ipv4_rcv_saddr_equal);
+       udp_sk(sk)->udp_portaddr_hash = hash2_partial;
+       return udp_lib_get_port(sk, snum, ipv4_rcv_saddr_equal, hash2_nulladdr);
 }
 
 static inline int compute_score(struct sock *sk, struct net *net, __be32 saddr,
@@ -336,8 +395,6 @@ static inline int compute_score2(struct sock *sk, struct net *net,
        return score;
 }
 
-#define udp_portaddr_for_each_entry_rcu(__sk, node, list) \
-       hlist_nulls_for_each_entry_rcu(__sk, node, list, __sk_common.skc_portaddr_node)
 
 /* called with read_rcu_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
@@ -488,13 +545,13 @@ static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
        sk_nulls_for_each_from(s, node) {
                struct inet_sock *inet = inet_sk(s);
 
-               if (!net_eq(sock_net(s), net)                           ||
-                   udp_sk(s)->udp_port_hash != hnum                    ||
-                   (inet->inet_daddr && inet->inet_daddr != rmt_addr)  ||
-                   (inet->inet_dport != rmt_port && inet->inet_dport)  ||
-                   (inet->inet_rcv_saddr       &&
-                    inet->inet_rcv_saddr != loc_addr)                  ||
-                   ipv6_only_sock(s)                                   ||
+               if (!net_eq(sock_net(s), net) ||
+                   udp_sk(s)->udp_port_hash != hnum ||
+                   (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
+                   (inet->inet_dport != rmt_port && inet->inet_dport) ||
+                   (inet->inet_rcv_saddr &&
+                    inet->inet_rcv_saddr != loc_addr) ||
+                   ipv6_only_sock(s) ||
                    (s->sk_bound_dev_if && s->sk_bound_dev_if != dif))
                        continue;
                if (!ip_mc_sf_allow(s, loc_addr, rmt_addr, dif))
This page took 0.026792 seconds and 5 git commands to generate.