inet: Avoid potential NULL peer dereference.
[deliverable/linux.git] / net / ipv6 / route.c
index 8a986be4aedad358b417db5217fffc20572bb248..58a3ec23da2f7fbb45eb3306e2cbb8187267f203 100644 (file)
@@ -258,16 +258,19 @@ static struct rt6_info ip6_blk_hole_entry_template = {
 #endif
 
 /* allocate dst with ip6_dst_ops */
-static inline struct rt6_info *ip6_dst_alloc(struct dst_ops *ops,
+static inline struct rt6_info *ip6_dst_alloc(struct net *net,
                                             struct net_device *dev,
-                                            int flags)
+                                            int flags,
+                                            struct fib6_table *table)
 {
-       struct rt6_info *rt = dst_alloc(ops, dev, 0, 0, flags);
+       struct rt6_info *rt = dst_alloc(&net->ipv6.ip6_dst_ops, dev,
+                                       0, 0, flags);
 
-       if (rt)
+       if (rt) {
                memset(&rt->rt6i_table, 0,
                       sizeof(*rt) - sizeof(struct dst_entry));
-
+               rt6_init_peer(rt, table ? &table->tb6_peers : net->ipv6.peers);
+       }
        return rt;
 }
 
@@ -275,7 +278,6 @@ static void ip6_dst_destroy(struct dst_entry *dst)
 {
        struct rt6_info *rt = (struct rt6_info *)dst;
        struct inet6_dev *idev = rt->rt6i_idev;
-       struct inet_peer *peer = rt->rt6i_peer;
 
        if (!(rt->dst.flags & DST_HOST))
                dst_destroy_metrics_generic(dst);
@@ -288,8 +290,8 @@ static void ip6_dst_destroy(struct dst_entry *dst)
        if (!(rt->rt6i_flags & RTF_EXPIRES) && dst->from)
                dst_release(dst->from);
 
-       if (peer) {
-               rt->rt6i_peer = NULL;
+       if (rt6_has_peer(rt)) {
+               struct inet_peer *peer = rt6_peer_ptr(rt);
                inet_putpeer(peer);
        }
 }
@@ -303,14 +305,20 @@ static u32 rt6_peer_genid(void)
 
 void rt6_bind_peer(struct rt6_info *rt, int create)
 {
-       struct net *net = dev_net(rt->dst.dev);
+       struct inet_peer_base *base;
        struct inet_peer *peer;
 
-       peer = inet_getpeer_v6(net, &rt->rt6i_dst.addr, create);
-       if (peer && cmpxchg(&rt->rt6i_peer, NULL, peer) != NULL)
-               inet_putpeer(peer);
-       else
-               rt->rt6i_peer_genid = rt6_peer_genid();
+       base = inetpeer_base_ptr(rt->_rt6i_peer);
+       if (!base)
+               return;
+
+       peer = inet_getpeer_v6(base, &rt->rt6i_dst.addr, create);
+       if (peer) {
+               if (!rt6_set_peer(rt, peer))
+                       inet_putpeer(peer);
+               else
+                       rt->rt6i_peer_genid = rt6_peer_genid();
+       }
 }
 
 static void ip6_dst_ifdown(struct dst_entry *dst, struct net_device *dev,
@@ -950,6 +958,7 @@ struct dst_entry *ip6_blackhole_route(struct net *net, struct dst_entry *dst_ori
        rt = dst_alloc(&ip6_dst_blackhole_ops, ort->dst.dev, 1, 0, 0);
        if (rt) {
                memset(&rt->rt6i_table, 0, sizeof(*rt) - sizeof(struct dst_entry));
+               rt6_init_peer(rt, net->ipv6.peers);
 
                new = &rt->dst;
 
@@ -994,7 +1003,7 @@ static struct dst_entry *ip6_dst_check(struct dst_entry *dst, u32 cookie)
 
        if (rt->rt6i_node && (rt->rt6i_node->fn_sernum == cookie)) {
                if (rt->rt6i_peer_genid != rt6_peer_genid()) {
-                       if (!rt->rt6i_peer)
+                       if (!rt6_has_peer(rt))
                                rt6_bind_peer(rt, 0);
                        rt->rt6i_peer_genid = rt6_peer_genid();
                }
@@ -1108,7 +1117,7 @@ struct dst_entry *icmp6_dst_alloc(struct net_device *dev,
        if (unlikely(!idev))
                return ERR_PTR(-ENODEV);
 
-       rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops, dev, 0);
+       rt = ip6_dst_alloc(net, dev, 0, NULL);
        if (unlikely(!rt)) {
                in6_dev_put(idev);
                dst = ERR_PTR(-ENOMEM);
@@ -1290,7 +1299,7 @@ int ip6_route_add(struct fib6_config *cfg)
        if (!table)
                goto out;
 
-       rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops, NULL, DST_NOCOUNT);
+       rt = ip6_dst_alloc(net, NULL, DST_NOCOUNT, table);
 
        if (!rt) {
                err = -ENOMEM;
@@ -1812,8 +1821,8 @@ static struct rt6_info *ip6_rt_copy(struct rt6_info *ort,
                                    const struct in6_addr *dest)
 {
        struct net *net = dev_net(ort->dst.dev);
-       struct rt6_info *rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops,
-                                           ort->dst.dev, 0);
+       struct rt6_info *rt = ip6_dst_alloc(net, ort->dst.dev, 0,
+                                           ort->rt6i_table);
 
        if (rt) {
                rt->dst.input = ort->dst.input;
@@ -2097,8 +2106,7 @@ struct rt6_info *addrconf_dst_alloc(struct inet6_dev *idev,
                                    bool anycast)
 {
        struct net *net = dev_net(idev->dev);
-       struct rt6_info *rt = ip6_dst_alloc(&net->ipv6.ip6_dst_ops,
-                                           net->loopback_dev, 0);
+       struct rt6_info *rt = ip6_dst_alloc(net, net->loopback_dev, 0, NULL);
        int err;
 
        if (!rt) {
@@ -2519,7 +2527,9 @@ static int rt6_fill_node(struct net *net,
        else
                expires = INT_MAX;
 
-       peer = rt->rt6i_peer;
+       peer = NULL;
+       if (rt6_has_peer(rt))
+               peer = rt6_peer_ptr(rt);
        ts = tsage = 0;
        if (peer && peer->tcp_ts_stamp) {
                ts = peer->tcp_ts;
@@ -2996,6 +3006,31 @@ static struct pernet_operations ip6_route_net_ops = {
        .exit = ip6_route_net_exit,
 };
 
+static int __net_init ipv6_inetpeer_init(struct net *net)
+{
+       struct inet_peer_base *bp = kmalloc(sizeof(*bp), GFP_KERNEL);
+
+       if (!bp)
+               return -ENOMEM;
+       inet_peer_base_init(bp);
+       net->ipv6.peers = bp;
+       return 0;
+}
+
+static void __net_exit ipv6_inetpeer_exit(struct net *net)
+{
+       struct inet_peer_base *bp = net->ipv6.peers;
+
+       net->ipv6.peers = NULL;
+       inetpeer_invalidate_tree(bp);
+       kfree(bp);
+}
+
+static struct pernet_operations ipv6_inetpeer_ops = {
+       .init   =       ipv6_inetpeer_init,
+       .exit   =       ipv6_inetpeer_exit,
+};
+
 static struct notifier_block ip6_route_dev_notifier = {
        .notifier_call = ip6_route_dev_notify,
        .priority = 0,
@@ -3020,6 +3055,10 @@ int __init ip6_route_init(void)
        if (ret)
                goto out_dst_entries;
 
+       ret = register_pernet_subsys(&ipv6_inetpeer_ops);
+       if (ret)
+               goto out_register_subsys;
+
        ip6_dst_blackhole_ops.kmem_cachep = ip6_dst_ops_template.kmem_cachep;
 
        /* Registering of the loopback is done before this portion of code,
@@ -3035,7 +3074,7 @@ int __init ip6_route_init(void)
   #endif
        ret = fib6_init();
        if (ret)
-               goto out_register_subsys;
+               goto out_register_inetpeer;
 
        ret = xfrm6_init();
        if (ret)
@@ -3064,6 +3103,8 @@ xfrm6_init:
        xfrm6_fini();
 out_fib6_init:
        fib6_gc_cleanup();
+out_register_inetpeer:
+       unregister_pernet_subsys(&ipv6_inetpeer_ops);
 out_register_subsys:
        unregister_pernet_subsys(&ip6_route_net_ops);
 out_dst_entries:
@@ -3079,6 +3120,7 @@ void ip6_route_cleanup(void)
        fib6_rules_cleanup();
        xfrm6_fini();
        fib6_gc_cleanup();
+       unregister_pernet_subsys(&ipv6_inetpeer_ops);
        unregister_pernet_subsys(&ip6_route_net_ops);
        dst_entries_destroy(&ip6_dst_blackhole_ops);
        kmem_cache_destroy(ip6_dst_ops_template.kmem_cachep);
This page took 0.028812 seconds and 5 git commands to generate.