tcp: metrics: Allow selective get/del of tcp-metrics based on src IP
[deliverable/linux.git] / net / ipv4 / tcp_metrics.c
index e150f264c8e2b139771de6814a9c83969c795d1a..699a42faab9ceebaf638f96d60670df689f3e677 100644 (file)
@@ -877,44 +877,66 @@ done:
        return skb->len;
 }
 
-static int parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
-                        unsigned int *hash, int optional)
+static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
+                          unsigned int *hash, int optional, int v4, int v6)
 {
        struct nlattr *a;
 
-       a = info->attrs[TCP_METRICS_ATTR_ADDR_IPV4];
+       a = info->attrs[v4];
        if (a) {
                addr->family = AF_INET;
                addr->addr.a4 = nla_get_be32(a);
-               *hash = (__force unsigned int) addr->addr.a4;
+               if (hash)
+                       *hash = (__force unsigned int) addr->addr.a4;
                return 0;
        }
-       a = info->attrs[TCP_METRICS_ATTR_ADDR_IPV6];
+       a = info->attrs[v6];
        if (a) {
                if (nla_len(a) != sizeof(struct in6_addr))
                        return -EINVAL;
                addr->family = AF_INET6;
                memcpy(addr->addr.a6, nla_data(a), sizeof(addr->addr.a6));
-               *hash = ipv6_addr_hash((struct in6_addr *) addr->addr.a6);
+               if (hash)
+                       *hash = ipv6_addr_hash((struct in6_addr *) addr->addr.a6);
                return 0;
        }
        return optional ? 1 : -EAFNOSUPPORT;
 }
 
+static int parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr,
+                        unsigned int *hash, int optional)
+{
+       return __parse_nl_addr(info, addr, hash, optional,
+                              TCP_METRICS_ATTR_ADDR_IPV4,
+                              TCP_METRICS_ATTR_ADDR_IPV6);
+}
+
+static int parse_nl_saddr(struct genl_info *info, struct inetpeer_addr *addr)
+{
+       return __parse_nl_addr(info, addr, NULL, 0,
+                              TCP_METRICS_ATTR_SADDR_IPV4,
+                              TCP_METRICS_ATTR_SADDR_IPV6);
+}
+
 static int tcp_metrics_nl_cmd_get(struct sk_buff *skb, struct genl_info *info)
 {
        struct tcp_metrics_block *tm;
-       struct inetpeer_addr daddr;
+       struct inetpeer_addr saddr, daddr;
        unsigned int hash;
        struct sk_buff *msg;
        struct net *net = genl_info_net(info);
        void *reply;
        int ret;
+       bool src = true;
 
        ret = parse_nl_addr(info, &daddr, &hash, 0);
        if (ret < 0)
                return ret;
 
+       ret = parse_nl_saddr(info, &saddr);
+       if (ret < 0)
+               src = false;
+
        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
        if (!msg)
                return -ENOMEM;
@@ -929,7 +951,8 @@ static int tcp_metrics_nl_cmd_get(struct sk_buff *skb, struct genl_info *info)
        rcu_read_lock();
        for (tm = rcu_dereference(net->ipv4.tcp_metrics_hash[hash].chain); tm;
             tm = rcu_dereference(tm->tcpm_next)) {
-               if (addr_same(&tm->tcpm_daddr, &daddr)) {
+               if (addr_same(&tm->tcpm_daddr, &daddr) &&
+                   (!src || addr_same(&tm->tcpm_saddr, &saddr))) {
                        ret = tcp_metrics_fill_info(msg, tm);
                        break;
                }
@@ -984,23 +1007,28 @@ static int tcp_metrics_nl_cmd_del(struct sk_buff *skb, struct genl_info *info)
        struct tcpm_hash_bucket *hb;
        struct tcp_metrics_block *tm, *tmlist = NULL;
        struct tcp_metrics_block __rcu **pp;
-       struct inetpeer_addr daddr;
+       struct inetpeer_addr saddr, daddr;
        unsigned int hash;
        struct net *net = genl_info_net(info);
        int ret;
+       bool src = true;
 
        ret = parse_nl_addr(info, &daddr, &hash, 1);
        if (ret < 0)
                return ret;
        if (ret > 0)
                return tcp_metrics_flush_all(net);
+       ret = parse_nl_saddr(info, &saddr);
+       if (ret < 0)
+               src = false;
 
        hash = hash_32(hash, net->ipv4.tcp_metrics_hash_log);
        hb = net->ipv4.tcp_metrics_hash + hash;
        pp = &hb->chain;
        spin_lock_bh(&tcp_metrics_lock);
        for (tm = deref_locked_genl(*pp); tm; tm = deref_locked_genl(*pp)) {
-               if (addr_same(&tm->tcpm_daddr, &daddr)) {
+               if (addr_same(&tm->tcpm_daddr, &daddr) &&
+                   (!src || addr_same(&tm->tcpm_saddr, &saddr))) {
                        *pp = tm->tcpm_next;
                        tm->tcpm_next = tmlist;
                        tmlist = tm;
This page took 0.040585 seconds and 5 git commands to generate.