#include <linux/module.h>
#include <linux/errno.h>
#include <linux/socket.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/udp.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <net/genetlink.h>
#include <net/gue.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/udp.h>
#include <net/udp_tunnel.h>
#include <uapi/linux/fou.h>
#include <uapi/linux/genetlink.h>
/* Protects fou_list against concurrent add/remove. */
static DEFINE_SPINLOCK(fou_lock);
/* All active FOU/GUE receive ports, keyed by local UDP port. */
static LIST_HEAD(fou_list);
26 struct udp_offload udp_offloads
;
27 struct list_head list
;
33 struct udp_port_cfg udp_config
;
36 static inline struct fou
*fou_from_sock(struct sock
*sk
)
38 return sk
->sk_user_data
;
41 static int fou_udp_encap_recv_deliver(struct sk_buff
*skb
,
42 u8 protocol
, size_t len
)
44 struct iphdr
*iph
= ip_hdr(skb
);
46 /* Remove 'len' bytes from the packet (UDP header and
47 * FOU header if present), modify the protocol to the one
48 * we found, and then call rcv_encap.
50 iph
->tot_len
= htons(ntohs(iph
->tot_len
) - len
);
52 skb_postpull_rcsum(skb
, udp_hdr(skb
), len
);
53 skb_reset_transport_header(skb
);
58 static int fou_udp_recv(struct sock
*sk
, struct sk_buff
*skb
)
60 struct fou
*fou
= fou_from_sock(sk
);
65 return fou_udp_encap_recv_deliver(skb
, fou
->protocol
,
66 sizeof(struct udphdr
));
69 static int gue_udp_recv(struct sock
*sk
, struct sk_buff
*skb
)
71 struct fou
*fou
= fou_from_sock(sk
);
73 struct guehdr
*guehdr
;
79 len
= sizeof(struct udphdr
) + sizeof(struct guehdr
);
80 if (!pskb_may_pull(skb
, len
))
84 guehdr
= (struct guehdr
*)&uh
[1];
86 len
+= guehdr
->hlen
<< 2;
87 if (!pskb_may_pull(skb
, len
))
91 guehdr
= (struct guehdr
*)&uh
[1];
93 if (guehdr
->version
!= 0)
101 return fou_udp_encap_recv_deliver(skb
, guehdr
->next_hdr
, len
);
107 static struct sk_buff
**fou_gro_receive(struct sk_buff
**head
,
110 const struct net_offload
*ops
;
111 struct sk_buff
**pp
= NULL
;
112 u8 proto
= NAPI_GRO_CB(skb
)->proto
;
113 const struct net_offload
**offloads
;
116 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
117 ops
= rcu_dereference(offloads
[proto
]);
118 if (!ops
|| !ops
->callbacks
.gro_receive
)
121 pp
= ops
->callbacks
.gro_receive(head
, skb
);
129 static int fou_gro_complete(struct sk_buff
*skb
, int nhoff
)
131 const struct net_offload
*ops
;
132 u8 proto
= NAPI_GRO_CB(skb
)->proto
;
134 const struct net_offload
**offloads
;
137 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
138 ops
= rcu_dereference(offloads
[proto
]);
139 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_complete
))
142 err
= ops
->callbacks
.gro_complete(skb
, nhoff
);
150 static struct sk_buff
**gue_gro_receive(struct sk_buff
**head
,
153 const struct net_offload
**offloads
;
154 const struct net_offload
*ops
;
155 struct sk_buff
**pp
= NULL
;
158 struct guehdr
*guehdr
;
159 unsigned int hlen
, guehlen
;
163 off
= skb_gro_offset(skb
);
164 hlen
= off
+ sizeof(*guehdr
);
165 guehdr
= skb_gro_header_fast(skb
, off
);
166 if (skb_gro_header_hard(skb
, hlen
)) {
167 guehdr
= skb_gro_header_slow(skb
, hlen
, off
);
168 if (unlikely(!guehdr
))
172 proto
= guehdr
->next_hdr
;
175 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
176 ops
= rcu_dereference(offloads
[proto
]);
177 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_receive
))
180 guehlen
= sizeof(*guehdr
) + (guehdr
->hlen
<< 2);
182 hlen
= off
+ guehlen
;
183 if (skb_gro_header_hard(skb
, hlen
)) {
184 guehdr
= skb_gro_header_slow(skb
, hlen
, off
);
185 if (unlikely(!guehdr
))
191 for (p
= *head
; p
; p
= p
->next
) {
192 const struct guehdr
*guehdr2
;
194 if (!NAPI_GRO_CB(p
)->same_flow
)
197 guehdr2
= (struct guehdr
*)(p
->data
+ off
);
199 /* Compare base GUE header to be equal (covers
200 * hlen, version, next_hdr, and flags.
202 if (guehdr
->word
!= guehdr2
->word
) {
203 NAPI_GRO_CB(p
)->same_flow
= 0;
207 /* Compare optional fields are the same. */
208 if (guehdr
->hlen
&& memcmp(&guehdr
[1], &guehdr2
[1],
209 guehdr
->hlen
<< 2)) {
210 NAPI_GRO_CB(p
)->same_flow
= 0;
215 skb_gro_pull(skb
, guehlen
);
217 /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/
218 skb_gro_postpull_rcsum(skb
, guehdr
, guehlen
);
220 pp
= ops
->callbacks
.gro_receive(head
, skb
);
225 NAPI_GRO_CB(skb
)->flush
|= flush
;
230 static int gue_gro_complete(struct sk_buff
*skb
, int nhoff
)
232 const struct net_offload
**offloads
;
233 struct guehdr
*guehdr
= (struct guehdr
*)(skb
->data
+ nhoff
);
234 const struct net_offload
*ops
;
235 unsigned int guehlen
;
239 proto
= guehdr
->next_hdr
;
241 guehlen
= sizeof(*guehdr
) + (guehdr
->hlen
<< 2);
244 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
245 ops
= rcu_dereference(offloads
[proto
]);
246 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_complete
))
249 err
= ops
->callbacks
.gro_complete(skb
, nhoff
+ guehlen
);
256 static int fou_add_to_port_list(struct fou
*fou
)
260 spin_lock(&fou_lock
);
261 list_for_each_entry(fout
, &fou_list
, list
) {
262 if (fou
->port
== fout
->port
) {
263 spin_unlock(&fou_lock
);
268 list_add(&fou
->list
, &fou_list
);
269 spin_unlock(&fou_lock
);
274 static void fou_release(struct fou
*fou
)
276 struct socket
*sock
= fou
->sock
;
277 struct sock
*sk
= sock
->sk
;
279 udp_del_offload(&fou
->udp_offloads
);
281 list_del(&fou
->list
);
283 /* Remove hooks into tunnel socket */
284 sk
->sk_user_data
= NULL
;
291 static int fou_encap_init(struct sock
*sk
, struct fou
*fou
, struct fou_cfg
*cfg
)
293 udp_sk(sk
)->encap_rcv
= fou_udp_recv
;
294 fou
->protocol
= cfg
->protocol
;
295 fou
->udp_offloads
.callbacks
.gro_receive
= fou_gro_receive
;
296 fou
->udp_offloads
.callbacks
.gro_complete
= fou_gro_complete
;
297 fou
->udp_offloads
.port
= cfg
->udp_config
.local_udp_port
;
298 fou
->udp_offloads
.ipproto
= cfg
->protocol
;
303 static int gue_encap_init(struct sock
*sk
, struct fou
*fou
, struct fou_cfg
*cfg
)
305 udp_sk(sk
)->encap_rcv
= gue_udp_recv
;
306 fou
->udp_offloads
.callbacks
.gro_receive
= gue_gro_receive
;
307 fou
->udp_offloads
.callbacks
.gro_complete
= gue_gro_complete
;
308 fou
->udp_offloads
.port
= cfg
->udp_config
.local_udp_port
;
313 static int fou_create(struct net
*net
, struct fou_cfg
*cfg
,
314 struct socket
**sockp
)
316 struct fou
*fou
= NULL
;
318 struct socket
*sock
= NULL
;
321 /* Open UDP socket */
322 err
= udp_sock_create(net
, &cfg
->udp_config
, &sock
);
326 /* Allocate FOU port structure */
327 fou
= kzalloc(sizeof(*fou
), GFP_KERNEL
);
335 fou
->port
= cfg
->udp_config
.local_udp_port
;
337 /* Initial for fou type */
339 case FOU_ENCAP_DIRECT
:
340 err
= fou_encap_init(sk
, fou
, cfg
);
345 err
= gue_encap_init(sk
, fou
, cfg
);
354 udp_sk(sk
)->encap_type
= 1;
357 sk
->sk_user_data
= fou
;
360 udp_set_convert_csum(sk
, true);
362 sk
->sk_allocation
= GFP_ATOMIC
;
364 if (cfg
->udp_config
.family
== AF_INET
) {
365 err
= udp_add_offload(&fou
->udp_offloads
);
370 err
= fou_add_to_port_list(fou
);
387 static int fou_destroy(struct net
*net
, struct fou_cfg
*cfg
)
390 u16 port
= cfg
->udp_config
.local_udp_port
;
393 spin_lock(&fou_lock
);
394 list_for_each_entry(fou
, &fou_list
, list
) {
395 if (fou
->port
== port
) {
396 udp_del_offload(&fou
->udp_offloads
);
402 spin_unlock(&fou_lock
);
407 static struct genl_family fou_nl_family
= {
408 .id
= GENL_ID_GENERATE
,
410 .name
= FOU_GENL_NAME
,
411 .version
= FOU_GENL_VERSION
,
412 .maxattr
= FOU_ATTR_MAX
,
416 static struct nla_policy fou_nl_policy
[FOU_ATTR_MAX
+ 1] = {
417 [FOU_ATTR_PORT
] = { .type
= NLA_U16
, },
418 [FOU_ATTR_AF
] = { .type
= NLA_U8
, },
419 [FOU_ATTR_IPPROTO
] = { .type
= NLA_U8
, },
420 [FOU_ATTR_TYPE
] = { .type
= NLA_U8
, },
423 static int parse_nl_config(struct genl_info
*info
,
426 memset(cfg
, 0, sizeof(*cfg
));
428 cfg
->udp_config
.family
= AF_INET
;
430 if (info
->attrs
[FOU_ATTR_AF
]) {
431 u8 family
= nla_get_u8(info
->attrs
[FOU_ATTR_AF
]);
433 if (family
!= AF_INET
&& family
!= AF_INET6
)
436 cfg
->udp_config
.family
= family
;
439 if (info
->attrs
[FOU_ATTR_PORT
]) {
440 u16 port
= nla_get_u16(info
->attrs
[FOU_ATTR_PORT
]);
442 cfg
->udp_config
.local_udp_port
= port
;
445 if (info
->attrs
[FOU_ATTR_IPPROTO
])
446 cfg
->protocol
= nla_get_u8(info
->attrs
[FOU_ATTR_IPPROTO
]);
448 if (info
->attrs
[FOU_ATTR_TYPE
])
449 cfg
->type
= nla_get_u8(info
->attrs
[FOU_ATTR_TYPE
]);
454 static int fou_nl_cmd_add_port(struct sk_buff
*skb
, struct genl_info
*info
)
459 err
= parse_nl_config(info
, &cfg
);
463 return fou_create(&init_net
, &cfg
, NULL
);
466 static int fou_nl_cmd_rm_port(struct sk_buff
*skb
, struct genl_info
*info
)
470 parse_nl_config(info
, &cfg
);
472 return fou_destroy(&init_net
, &cfg
);
475 static const struct genl_ops fou_nl_ops
[] = {
478 .doit
= fou_nl_cmd_add_port
,
479 .policy
= fou_nl_policy
,
480 .flags
= GENL_ADMIN_PERM
,
484 .doit
= fou_nl_cmd_rm_port
,
485 .policy
= fou_nl_policy
,
486 .flags
= GENL_ADMIN_PERM
,
490 static int __init
fou_init(void)
494 ret
= genl_register_family_with_ops(&fou_nl_family
,
500 static void __exit
fou_fini(void)
502 struct fou
*fou
, *next
;
504 genl_unregister_family(&fou_nl_family
);
506 /* Close all the FOU sockets */
508 spin_lock(&fou_lock
);
509 list_for_each_entry_safe(fou
, next
, &fou_list
, list
)
511 spin_unlock(&fou_lock
);
/* Module registration and metadata. */
module_init(fou_init);
module_exit(fou_fini);
MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
MODULE_LICENSE("GPL");