inet: Avoid potential NULL peer dereference.
[deliverable/linux.git] / net / ipv6 / route.c
index 999a982ad3fd7d7abac40211b50320fc4c038109..58a3ec23da2f7fbb45eb3306e2cbb8187267f203 100644 (file)
@@ -99,10 +99,7 @@ static u32 *ipv6_cow_metrics(struct dst_entry *dst, unsigned long old)
        if (!(rt->dst.flags & DST_HOST))
                return NULL;
 
-       if (!rt->rt6i_peer)
-               rt6_bind_peer(rt, 1);
-
-       peer = rt->rt6i_peer;
+       peer = rt6_get_peer_create(rt);
        if (peer) {
                u32 *old_p = __DST_METRICS_PTR(old);
                unsigned long prev, new;
@@ -261,16 +258,19 @@ static struct rt6_info ip6_blk_hole_entry_template = {
 #endif
 
 /* allocate dst with ip6_dst_ops */
-static inline struct rt6_info *ip6_dst_alloc(struct dst_ops *ops,
+static inline struct rt6_info *ip6_dst_alloc(struct net *net,
                                             struct net_device *dev,
-                                            int flags)
+                                            int flags,
+                                            struct fib6_table *table)
 {
-       struct rt6_info *rt = dst_alloc(ops, dev, 0, 0, flags);
+       struct rt6_info *rt = dst_alloc(&net->ipv6.ip6_dst_ops, dev,
+                                       0, 0, flags);
 
-       if (rt)
+       if (rt) {
                memset(&rt->rt6i_table, 0,
                       sizeof(*rt) - sizeof(struct dst_entry));
-
+               rt6_init_peer(rt, table ? &table->tb6_peers : net->ipv6.peers);
+       }
        return rt;
 }
 
@@ -278,7 +278,6 @@ static void ip6_dst_destroy(struct dst_entry *dst)
 {
        struct rt6_info *rt = (struct rt6_info *)dst;
        struct inet6_dev *idev = rt->rt6i_idev;
-       struct inet_peer *peer = rt->rt6i_peer;
 
        if (!(rt->dst.flags & DST_HOST))
                dst_destroy_metrics_generic(dst);
@@ -291,8 +290,8 @@ static void ip6_dst_destroy(struct dst_entry *dst)
        if (!(rt->rt6i_flags & RTF_EXPIRES) && dst->from)
                dst_release(dst->from);
 
-       if (peer) {
-               rt->rt6i_peer = NULL;
+       if (rt6_has_peer(rt)) {
+               struct inet_peer *peer = rt6_peer_ptr(rt);
                inet_putpeer(peer);
        }
 }
@@ -306,13 +305,20 @@ static u32 rt6_peer_genid(void)
 
 void rt6_bind_peer(struct rt6_info *rt, int create)
 {
+       struct inet_peer_base *base;
        struct inet_peer *peer;
 
-       peer = inet_getpeer_v6(&rt->rt6i_dst.addr, create);
-       if (peer && cmpxchg(&rt->rt6i_peer, NULL, peer) != NULL)
-               inet_putpeer(peer);
-       else
-               rt->rt6i_peer_genid = rt6_peer_genid();
+       base = inetpeer_base_ptr(rt->_rt6i_peer);
+       if (!base)
+               return;
+
+       peer = inet_getpeer_v6(base, &rt->rt6i_dst.addr, create);
+       if (peer) {
+               if (!rt6_set_peer(rt, peer))
+                       inet_putpeer(peer);
+               else
+                       rt->rt6i_peer_genid = rt6_peer_genid();
+       }
 }
 
 static void ip6_dst_ifdown(struct dst_entry *dst, struct net_device *dev,
@@ -952,6 +958,7 @@ struct dst_entry *ip6_blackhole_route(struct net *net, struct dst_entry *dst_ori
        rt = dst_alloc(&ip6_dst_blackhole_ops, ort->dst.dev, 1, 0, 0);
        if (rt) {
                memset(&rt->rt6i_table, 0, sizeof(*rt) - sizeof(struct dst_entry));
+               rt6_init_peer(rt, net->ipv6.peers);
 
                new = &rt->dst;
 
@@ -996,7 +1003,7 @@ static struct dst_entry *ip6_dst_check(struct dst_entry *dst, u32 cookie)
 
        if (rt->rt6i_node && (rt->rt6i_node->fn_sernum == cookie)) {
                if (rt->rt6i_peer_genid != rt6_peer_genid()) {
-                       if (!rt->rt6i_peer)
+                       if (!rt6_has_peer(rt))
                                rt6_bind_peer(rt, 0);
                        rt->rt6i_peer_genid = rt6_peer_genid();
                }
@@ -1110,7 +1117,7 @@ struct dst_entry *icmp6_dst_alloc(struct net_device *dev,
        if (unlikely(!idev))
                return ERR_PTR(-ENODEV);
 
-       rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops, dev, 0);
+       rt = ip6_dst_alloc(net, dev, 0, NULL);
        if (unlikely(!rt)) {
                in6_dev_put(idev);
                dst = ERR_PTR(-ENOMEM);
@@ -1292,7 +1299,7 @@ int ip6_route_add(struct fib6_config *cfg)
        if (!table)
                goto out;
 
-       rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops, NULL, DST_NOCOUNT);
+       rt = ip6_dst_alloc(net, NULL, DST_NOCOUNT, table);
 
        if (!rt) {
                err = -ENOMEM;
@@ -1814,8 +1821,8 @@ static struct rt6_info *ip6_rt_copy(struct rt6_info *ort,
                                    const struct in6_addr *dest)
 {
        struct net *net = dev_net(ort->dst.dev);
-       struct rt6_info *rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops,
-                                           ort->dst.dev, 0);
+       struct rt6_info *rt = ip6_dst_alloc(net, ort->dst.dev, 0,
+                                           ort->rt6i_table);
 
        if (rt) {
                rt->dst.input = ort->dst.input;
@@ -2099,8 +2106,7 @@ struct rt6_info *addrconf_dst_alloc(struct inet6_dev *idev,
                                    bool anycast)
 {
        struct net *net = dev_net(idev->dev);
-       struct rt6_info *rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops,
-                                           net->loopback_dev, 0);
+       struct rt6_info *rt = ip6_dst_alloc(net, net->loopback_dev, 0, NULL);
        int err;
 
        if (!rt) {
@@ -2521,7 +2527,9 @@ static int rt6_fill_node(struct net *net,
        else
                expires = INT_MAX;
 
-       peer = rt->rt6i_peer;
+       peer = NULL;
+       if (rt6_has_peer(rt))
+               peer = rt6_peer_ptr(rt);
        ts = tsage = 0;
        if (peer && peer->tcp_ts_stamp) {
                ts = peer->tcp_ts;
@@ -2998,6 +3006,31 @@ static struct pernet_operations ip6_route_net_ops = {
        .exit = ip6_route_net_exit,
 };
 
+static int __net_init ipv6_inetpeer_init(struct net *net)
+{
+       struct inet_peer_base *bp = kmalloc(sizeof(*bp), GFP_KERNEL);
+
+       if (!bp)
+               return -ENOMEM;
+       inet_peer_base_init(bp);
+       net->ipv6.peers = bp;
+       return 0;
+}
+
+static void __net_exit ipv6_inetpeer_exit(struct net *net)
+{
+       struct inet_peer_base *bp = net->ipv6.peers;
+
+       net->ipv6.peers = NULL;
+       inetpeer_invalidate_tree(bp);
+       kfree(bp);
+}
+
+static struct pernet_operations ipv6_inetpeer_ops = {
+       .init   =       ipv6_inetpeer_init,
+       .exit   =       ipv6_inetpeer_exit,
+};
+
 static struct notifier_block ip6_route_dev_notifier = {
        .notifier_call = ip6_route_dev_notify,
        .priority = 0,
@@ -3022,6 +3055,10 @@ int __init ip6_route_init(void)
        if (ret)
                goto out_dst_entries;
 
+       ret = register_pernet_subsys(&ipv6_inetpeer_ops);
+       if (ret)
+               goto out_register_subsys;
+
        ip6_dst_blackhole_ops.kmem_cachep = ip6_dst_ops_template.kmem_cachep;
 
        /* Registering of the loopback is done before this portion of code,
@@ -3037,7 +3074,7 @@ int __init ip6_route_init(void)
   #endif
        ret = fib6_init();
        if (ret)
-               goto out_register_subsys;
+               goto out_register_inetpeer;
 
        ret = xfrm6_init();
        if (ret)
@@ -3066,6 +3103,8 @@ xfrm6_init:
        xfrm6_fini();
 out_fib6_init:
        fib6_gc_cleanup();
+out_register_inetpeer:
+       unregister_pernet_subsys(&ipv6_inetpeer_ops);
 out_register_subsys:
        unregister_pernet_subsys(&ip6_route_net_ops);
 out_dst_entries:
@@ -3081,6 +3120,7 @@ void ip6_route_cleanup(void)
        fib6_rules_cleanup();
        xfrm6_fini();
        fib6_gc_cleanup();
+       unregister_pernet_subsys(&ipv6_inetpeer_ops);
        unregister_pernet_subsys(&ip6_route_net_ops);
        dst_entries_destroy(&ip6_dst_blackhole_ops);
        kmem_cache_destroy(ip6_dst_ops_template.kmem_cachep);
This page took 0.029164 seconds and 5 git commands to generate.