netns xfrm: flush SA/SPDs on netns stop
[deliverable/linux.git] / net / xfrm / xfrm_state.c
index 5f4c5340ba309b6551a84157fc31d73ed819aec6..662e47b0bcc3c576a529f636c4ffe9ad52896259 100644 (file)
@@ -24,9 +24,6 @@
 
 #include "xfrm_hash.h"
 
-struct sock *xfrm_nl;
-EXPORT_SYMBOL(xfrm_nl);
-
 u32 sysctl_xfrm_aevent_etime __read_mostly = XFRM_AE_ETIME;
 EXPORT_SYMBOL(sysctl_xfrm_aevent_etime);
 
@@ -670,13 +667,13 @@ xfrm_init_tempsel(struct xfrm_state *x, struct flowi *fl,
        return 0;
 }
 
-static struct xfrm_state *__xfrm_state_lookup(xfrm_address_t *daddr, __be32 spi, u8 proto, unsigned short family)
+static struct xfrm_state *__xfrm_state_lookup(struct net *net, xfrm_address_t *daddr, __be32 spi, u8 proto, unsigned short family)
 {
-       unsigned int h = xfrm_spi_hash(&init_net, daddr, spi, proto, family);
+       unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
        struct xfrm_state *x;
        struct hlist_node *entry;
 
-       hlist_for_each_entry(x, entry, init_net.xfrm.state_byspi+h, byspi) {
+       hlist_for_each_entry(x, entry, net->xfrm.state_byspi+h, byspi) {
                if (x->props.family != family ||
                    x->id.spi       != spi ||
                    x->id.proto     != proto)
@@ -702,13 +699,13 @@ static struct xfrm_state *__xfrm_state_lookup(xfrm_address_t *daddr, __be32 spi,
        return NULL;
 }
 
-static struct xfrm_state *__xfrm_state_lookup_byaddr(xfrm_address_t *daddr, xfrm_address_t *saddr, u8 proto, unsigned short family)
+static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, xfrm_address_t *daddr, xfrm_address_t *saddr, u8 proto, unsigned short family)
 {
-       unsigned int h = xfrm_src_hash(&init_net, daddr, saddr, family);
+       unsigned int h = xfrm_src_hash(net, daddr, saddr, family);
        struct xfrm_state *x;
        struct hlist_node *entry;
 
-       hlist_for_each_entry(x, entry, init_net.xfrm.state_bysrc+h, bysrc) {
+       hlist_for_each_entry(x, entry, net->xfrm.state_bysrc+h, bysrc) {
                if (x->props.family != family ||
                    x->id.proto     != proto)
                        continue;
@@ -740,11 +737,13 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(xfrm_address_t *daddr, xfrm
 static inline struct xfrm_state *
 __xfrm_state_locate(struct xfrm_state *x, int use_spi, int family)
 {
+       struct net *net = xs_net(x);
+
        if (use_spi)
-               return __xfrm_state_lookup(&x->id.daddr, x->id.spi,
+               return __xfrm_state_lookup(net, &x->id.daddr, x->id.spi,
                                           x->id.proto, family);
        else
-               return __xfrm_state_lookup_byaddr(&x->id.daddr,
+               return __xfrm_state_lookup_byaddr(net, &x->id.daddr,
                                                  &x->props.saddr,
                                                  x->id.proto, family);
 }
@@ -763,6 +762,7 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
                struct xfrm_policy *pol, int *err,
                unsigned short family)
 {
+       struct net *net = xp_net(pol);
        unsigned int h;
        struct hlist_node *entry;
        struct xfrm_state *x, *x0, *to_put;
@@ -773,8 +773,8 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
        to_put = NULL;
 
        spin_lock_bh(&xfrm_state_lock);
-       h = xfrm_dst_hash(&init_net, daddr, saddr, tmpl->reqid, family);
-       hlist_for_each_entry(x, entry, init_net.xfrm.state_bydst+h, bydst) {
+       h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, family);
+       hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
                if (x->props.family == family &&
                    x->props.reqid == tmpl->reqid &&
                    !(x->props.flags & XFRM_STATE_WILDRECV) &&
@@ -818,13 +818,13 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
        x = best;
        if (!x && !error && !acquire_in_progress) {
                if (tmpl->id.spi &&
-                   (x0 = __xfrm_state_lookup(daddr, tmpl->id.spi,
+                   (x0 = __xfrm_state_lookup(net, daddr, tmpl->id.spi,
                                              tmpl->id.proto, family)) != NULL) {
                        to_put = x0;
                        error = -EEXIST;
                        goto out;
                }
-               x = xfrm_state_alloc(&init_net);
+               x = xfrm_state_alloc(net);
                if (x == NULL) {
                        error = -ENOMEM;
                        goto out;
@@ -843,19 +843,19 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
 
                if (km_query(x, tmpl, pol) == 0) {
                        x->km.state = XFRM_STATE_ACQ;
-                       list_add(&x->km.all, &init_net.xfrm.state_all);
-                       hlist_add_head(&x->bydst, init_net.xfrm.state_bydst+h);
-                       h = xfrm_src_hash(&init_net, daddr, saddr, family);
-                       hlist_add_head(&x->bysrc, init_net.xfrm.state_bysrc+h);
+                       list_add(&x->km.all, &net->xfrm.state_all);
+                       hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
+                       h = xfrm_src_hash(net, daddr, saddr, family);
+                       hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
                        if (x->id.spi) {
-                               h = xfrm_spi_hash(&init_net, &x->id.daddr, x->id.spi, x->id.proto, family);
-                               hlist_add_head(&x->byspi, init_net.xfrm.state_byspi+h);
+                               h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, family);
+                               hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
                        }
                        x->lft.hard_add_expires_seconds = sysctl_xfrm_acq_expires;
                        x->timer.expires = jiffies + sysctl_xfrm_acq_expires*HZ;
                        add_timer(&x->timer);
-                       init_net.xfrm.state_num++;
-                       xfrm_hash_grow_check(&init_net, x->bydst.next != NULL);
+                       net->xfrm.state_num++;
+                       xfrm_hash_grow_check(net, x->bydst.next != NULL);
                } else {
                        x->km.state = XFRM_STATE_DEAD;
                        to_put = x;
@@ -875,7 +875,8 @@ out:
 }
 
 struct xfrm_state *
-xfrm_stateonly_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
+xfrm_stateonly_find(struct net *net,
+                   xfrm_address_t *daddr, xfrm_address_t *saddr,
                    unsigned short family, u8 mode, u8 proto, u32 reqid)
 {
        unsigned int h;
@@ -883,8 +884,8 @@ xfrm_stateonly_find(xfrm_address_t *daddr, xfrm_address_t *saddr,
        struct hlist_node *entry;
 
        spin_lock(&xfrm_state_lock);
-       h = xfrm_dst_hash(&init_net, daddr, saddr, reqid, family);
-       hlist_for_each_entry(x, entry, init_net.xfrm.state_bydst+h, bydst) {
+       h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
+       hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
                if (x->props.family == family &&
                    x->props.reqid == reqid &&
                    !(x->props.flags & XFRM_STATE_WILDRECV) &&
@@ -970,13 +971,13 @@ void xfrm_state_insert(struct xfrm_state *x)
 EXPORT_SYMBOL(xfrm_state_insert);
 
 /* xfrm_state_lock is held */
-static struct xfrm_state *__find_acq_core(unsigned short family, u8 mode, u32 reqid, u8 proto, xfrm_address_t *daddr, xfrm_address_t *saddr, int create)
+static struct xfrm_state *__find_acq_core(struct net *net, unsigned short family, u8 mode, u32 reqid, u8 proto, xfrm_address_t *daddr, xfrm_address_t *saddr, int create)
 {
-       unsigned int h = xfrm_dst_hash(&init_net, daddr, saddr, reqid, family);
+       unsigned int h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
        struct hlist_node *entry;
        struct xfrm_state *x;
 
-       hlist_for_each_entry(x, entry, init_net.xfrm.state_bydst+h, bydst) {
+       hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) {
                if (x->props.reqid  != reqid ||
                    x->props.mode   != mode ||
                    x->props.family != family ||
@@ -1008,7 +1009,7 @@ static struct xfrm_state *__find_acq_core(unsigned short family, u8 mode, u32 re
        if (!create)
                return NULL;
 
-       x = xfrm_state_alloc(&init_net);
+       x = xfrm_state_alloc(net);
        if (likely(x)) {
                switch (family) {
                case AF_INET:
@@ -1043,23 +1044,24 @@ static struct xfrm_state *__find_acq_core(unsigned short family, u8 mode, u32 re
                xfrm_state_hold(x);
                x->timer.expires = jiffies + sysctl_xfrm_acq_expires*HZ;
                add_timer(&x->timer);
-               list_add(&x->km.all, &init_net.xfrm.state_all);
-               hlist_add_head(&x->bydst, init_net.xfrm.state_bydst+h);
-               h = xfrm_src_hash(&init_net, daddr, saddr, family);
-               hlist_add_head(&x->bysrc, init_net.xfrm.state_bysrc+h);
+               list_add(&x->km.all, &net->xfrm.state_all);
+               hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
+               h = xfrm_src_hash(net, daddr, saddr, family);
+               hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
 
-               init_net.xfrm.state_num++;
+               net->xfrm.state_num++;
 
-               xfrm_hash_grow_check(&init_net, x->bydst.next != NULL);
+               xfrm_hash_grow_check(net, x->bydst.next != NULL);
        }
 
        return x;
 }
 
-static struct xfrm_state *__xfrm_find_acq_byseq(u32 seq);
+static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 seq);
 
 int xfrm_state_add(struct xfrm_state *x)
 {
+       struct net *net = xs_net(x);
        struct xfrm_state *x1, *to_put;
        int family;
        int err;
@@ -1080,7 +1082,7 @@ int xfrm_state_add(struct xfrm_state *x)
        }
 
        if (use_spi && x->km.seq) {
-               x1 = __xfrm_find_acq_byseq(x->km.seq);
+               x1 = __xfrm_find_acq_byseq(net, x->km.seq);
                if (x1 && ((x1->id.proto != x->id.proto) ||
                    xfrm_addr_cmp(&x1->id.daddr, &x->id.daddr, family))) {
                        to_put = x1;
@@ -1089,7 +1091,7 @@ int xfrm_state_add(struct xfrm_state *x)
        }
 
        if (use_spi && !x1)
-               x1 = __find_acq_core(family, x->props.mode, x->props.reqid,
+               x1 = __find_acq_core(net, family, x->props.mode, x->props.reqid,
                                     x->id.proto,
                                     &x->id.daddr, &x->props.saddr, 0);
 
@@ -1361,40 +1363,41 @@ int xfrm_state_check_expire(struct xfrm_state *x)
 EXPORT_SYMBOL(xfrm_state_check_expire);
 
 struct xfrm_state *
-xfrm_state_lookup(xfrm_address_t *daddr, __be32 spi, u8 proto,
+xfrm_state_lookup(struct net *net, xfrm_address_t *daddr, __be32 spi, u8 proto,
                  unsigned short family)
 {
        struct xfrm_state *x;
 
        spin_lock_bh(&xfrm_state_lock);
-       x = __xfrm_state_lookup(daddr, spi, proto, family);
+       x = __xfrm_state_lookup(net, daddr, spi, proto, family);
        spin_unlock_bh(&xfrm_state_lock);
        return x;
 }
 EXPORT_SYMBOL(xfrm_state_lookup);
 
 struct xfrm_state *
-xfrm_state_lookup_byaddr(xfrm_address_t *daddr, xfrm_address_t *saddr,
+xfrm_state_lookup_byaddr(struct net *net,
+                        xfrm_address_t *daddr, xfrm_address_t *saddr,
                         u8 proto, unsigned short family)
 {
        struct xfrm_state *x;
 
        spin_lock_bh(&xfrm_state_lock);
-       x = __xfrm_state_lookup_byaddr(daddr, saddr, proto, family);
+       x = __xfrm_state_lookup_byaddr(net, daddr, saddr, proto, family);
        spin_unlock_bh(&xfrm_state_lock);
        return x;
 }
 EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
 
 struct xfrm_state *
-xfrm_find_acq(u8 mode, u32 reqid, u8 proto,
+xfrm_find_acq(struct net *net, u8 mode, u32 reqid, u8 proto,
              xfrm_address_t *daddr, xfrm_address_t *saddr,
              int create, unsigned short family)
 {
        struct xfrm_state *x;
 
        spin_lock_bh(&xfrm_state_lock);
-       x = __find_acq_core(family, mode, reqid, proto, daddr, saddr, create);
+       x = __find_acq_core(net, family, mode, reqid, proto, daddr, saddr, create);
        spin_unlock_bh(&xfrm_state_lock);
 
        return x;
@@ -1441,15 +1444,15 @@ EXPORT_SYMBOL(xfrm_state_sort);
 
 /* Silly enough, but I'm lazy to build resolution list */
 
-static struct xfrm_state *__xfrm_find_acq_byseq(u32 seq)
+static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 seq)
 {
        int i;
 
-       for (i = 0; i <= init_net.xfrm.state_hmask; i++) {
+       for (i = 0; i <= net->xfrm.state_hmask; i++) {
                struct hlist_node *entry;
                struct xfrm_state *x;
 
-               hlist_for_each_entry(x, entry, init_net.xfrm.state_bydst+i, bydst) {
+               hlist_for_each_entry(x, entry, net->xfrm.state_bydst+i, bydst) {
                        if (x->km.seq == seq &&
                            x->km.state == XFRM_STATE_ACQ) {
                                xfrm_state_hold(x);
@@ -1460,12 +1463,12 @@ static struct xfrm_state *__xfrm_find_acq_byseq(u32 seq)
        return NULL;
 }
 
-struct xfrm_state *xfrm_find_acq_byseq(u32 seq)
+struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 seq)
 {
        struct xfrm_state *x;
 
        spin_lock_bh(&xfrm_state_lock);
-       x = __xfrm_find_acq_byseq(seq);
+       x = __xfrm_find_acq_byseq(net, seq);
        spin_unlock_bh(&xfrm_state_lock);
        return x;
 }
@@ -1486,6 +1489,7 @@ EXPORT_SYMBOL(xfrm_get_acqseq);
 
 int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
 {
+       struct net *net = xs_net(x);
        unsigned int h;
        struct xfrm_state *x0;
        int err = -ENOENT;
@@ -1503,7 +1507,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
        err = -ENOENT;
 
        if (minspi == maxspi) {
-               x0 = xfrm_state_lookup(&x->id.daddr, minspi, x->id.proto, x->props.family);
+               x0 = xfrm_state_lookup(net, &x->id.daddr, minspi, x->id.proto, x->props.family);
                if (x0) {
                        xfrm_state_put(x0);
                        goto unlock;
@@ -1513,7 +1517,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
                u32 spi = 0;
                for (h=0; h<high-low+1; h++) {
                        spi = low + net_random()%(high-low+1);
-                       x0 = xfrm_state_lookup(&x->id.daddr, htonl(spi), x->id.proto, x->props.family);
+                       x0 = xfrm_state_lookup(net, &x->id.daddr, htonl(spi), x->id.proto, x->props.family);
                        if (x0 == NULL) {
                                x->id.spi = htonl(spi);
                                break;
@@ -1523,8 +1527,8 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
        }
        if (x->id.spi) {
                spin_lock_bh(&xfrm_state_lock);
-               h = xfrm_spi_hash(&init_net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family);
-               hlist_add_head(&x->byspi, init_net.xfrm.state_byspi+h);
+               h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family);
+               hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
                spin_unlock_bh(&xfrm_state_lock);
 
                err = 0;
@@ -1537,7 +1541,7 @@ unlock:
 }
 EXPORT_SYMBOL(xfrm_alloc_spi);
 
-int xfrm_state_walk(struct xfrm_state_walk *walk,
+int xfrm_state_walk(struct net *net, struct xfrm_state_walk *walk,
                    int (*func)(struct xfrm_state *, int, void*),
                    void *data)
 {
@@ -1550,10 +1554,10 @@ int xfrm_state_walk(struct xfrm_state_walk *walk,
 
        spin_lock_bh(&xfrm_state_lock);
        if (list_empty(&walk->all))
-               x = list_first_entry(&init_net.xfrm.state_all, struct xfrm_state_walk, all);
+               x = list_first_entry(&net->xfrm.state_all, struct xfrm_state_walk, all);
        else
                x = list_entry(&walk->all, struct xfrm_state_walk, all);
-       list_for_each_entry_from(x, &init_net.xfrm.state_all, all) {
+       list_for_each_entry_from(x, &net->xfrm.state_all, all) {
                if (x->state == XFRM_STATE_DEAD)
                        continue;
                state = container_of(x, struct xfrm_state, km);
@@ -1652,7 +1656,7 @@ static void xfrm_replay_timer_handler(unsigned long data)
        spin_lock(&x->lock);
 
        if (x->km.state == XFRM_STATE_VALID) {
-               if (xfrm_aevent_is_on())
+               if (xfrm_aevent_is_on(xs_net(x)))
                        xfrm_replay_notify(x, XFRM_REPLAY_TIMEOUT);
                else
                        x->xflags |= XFRM_TIME_DEFER;
@@ -1708,7 +1712,7 @@ void xfrm_replay_advance(struct xfrm_state *x, __be32 net_seq)
                x->replay.bitmap |= (1U << diff);
        }
 
-       if (xfrm_aevent_is_on())
+       if (xfrm_aevent_is_on(xs_net(x)))
                xfrm_replay_notify(x, XFRM_REPLAY_UPDATE);
 }
 
@@ -1829,7 +1833,7 @@ int km_migrate(struct xfrm_selector *sel, u8 dir, u8 type,
 EXPORT_SYMBOL(km_migrate);
 #endif
 
-int km_report(u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr)
+int km_report(struct net *net, u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr)
 {
        int err = -EINVAL;
        int ret;
@@ -1838,7 +1842,7 @@ int km_report(u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr)
        read_lock(&xfrm_km_lock);
        list_for_each_entry(km, &xfrm_km_list, list) {
                if (km->report) {
-                       ret = km->report(proto, sel, addr);
+                       ret = km->report(net, proto, sel, addr);
                        if (!ret)
                                err = ret;
                }
@@ -2110,8 +2114,16 @@ out_bydst:
 
 void xfrm_state_fini(struct net *net)
 {
+       struct xfrm_audit audit_info;
        unsigned int sz;
 
+       flush_work(&net->xfrm.state_hash_work);
+       audit_info.loginuid = -1;
+       audit_info.sessionid = -1;
+       audit_info.secid = 0;
+       xfrm_state_flush(net, IPSEC_PROTO_ANY, &audit_info);
+       flush_work(&net->xfrm.state_gc_work);
+
        WARN_ON(!list_empty(&net->xfrm.state_all));
 
        sz = (net->xfrm.state_hmask + 1) * sizeof(struct hlist_head);
This page took 0.030443 seconds and 5 git commands to generate.