1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
12 #include <net/protocol.h>
14 #include <net/udp_tunnel.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
/* Module-global state: fou_lock serializes all access to fou_list, the
 * list of active FOU listener ports.
 * NOTE(review): this file is a corrupted extraction — the lines below
 * beginning "26", "27" and "33" are fields of the (missing) declarations
 * of struct fou and struct fou_cfg, not top-level definitions. Restore
 * the enclosing struct definitions before compiling.
 */
19 static DEFINE_SPINLOCK(fou_lock
);
20 static LIST_HEAD(fou_list
);
/* Fields of struct fou: per-port GRO offload registration and the
 * linkage into fou_list (enclosing struct declaration missing here).
 */
26 struct udp_offload udp_offloads
;
27 struct list_head list
;
/* Field of struct fou_cfg: the UDP socket configuration for the
 * listener (enclosing struct declaration missing here).
 */
33 struct udp_port_cfg udp_config
;
36 static inline struct fou
*fou_from_sock(struct sock
*sk
)
38 return sk
->sk_user_data
;
41 static void fou_recv_pull(struct sk_buff
*skb
, size_t len
)
43 struct iphdr
*iph
= ip_hdr(skb
);
45 /* Remove 'len' bytes from the packet (UDP header and
46 * FOU header if present).
48 iph
->tot_len
= htons(ntohs(iph
->tot_len
) - len
);
50 skb_postpull_rcsum(skb
, udp_hdr(skb
), len
);
51 skb_reset_transport_header(skb
);
54 static int fou_udp_recv(struct sock
*sk
, struct sk_buff
*skb
)
56 struct fou
*fou
= fou_from_sock(sk
);
61 fou_recv_pull(skb
, sizeof(struct udphdr
));
63 return -fou
->protocol
;
/* Process a GUE remote-checksum-offload option on the non-GRO receive
 * path: compute the inner checksum from skb->csum, write it into the
 * packet at the offset named by the option, and fix up skb->csum for
 * the bytes changed.
 * NOTE(review): corrupted extraction — the declarations of pd, plen,
 * poffset, csum, psum and delta, the error-return statements after the
 * bounds/pull checks, and the closing braces are missing from this view.
 */
66 static struct guehdr
*gue_remcsum(struct sk_buff
*skb
, struct guehdr
*guehdr
,
67 void *data
, int hdrlen
, u8 ipproto
)
/* start/offset come from the two big-endian u16s of the remcsum
 * option payload (pd presumably points into 'data' — confirm).
 */
70 u16 start
= ntohs(pd
[0]);
71 u16 offset
= ntohs(pd
[1]);
77 if (skb
->remcsum_offload
) {
78 /* Already processed in GRO path */
79 skb
->remcsum_offload
= 0;
/* Bounds-check the option against the linear packet length before
 * touching anything (missing error return here in this view).
 */
83 if (start
> skb
->len
- hdrlen
||
84 offset
> skb
->len
- hdrlen
- sizeof(u16
))
/* Need a full packet checksum to derive the inner one from. */
87 if (unlikely(skb
->ip_summed
!= CHECKSUM_COMPLETE
))
88 __skb_checksum_complete(skb
);
90 plen
= hdrlen
+ offset
+ sizeof(u16
);
91 if (!pskb_may_pull(skb
, plen
))
/* pskb_may_pull may have reallocated the header — recompute guehdr. */
93 guehdr
= (struct guehdr
*)&udp_hdr(skb
)[1];
95 if (ipproto
== IPPROTO_IP
&& sizeof(struct iphdr
) < plen
) {
96 struct iphdr
*ip
= (struct iphdr
*)(skb
->data
+ hdrlen
);
98 /* If next header happens to be IP we can skip that for the
99 * checksum calculation since the IP header checksum is zero
102 poffset
= ip
->ihl
* 4;
/* Subtract the bytes before 'start' so csum covers only the region
 * the remote side deferred.
 */
105 csum
= csum_sub(skb
->csum
, skb_checksum(skb
, poffset
+ hdrlen
,
106 start
- poffset
- hdrlen
, 0));
108 /* Set derived checksum in packet */
109 psum
= (__sum16
*)(skb
->data
+ hdrlen
+ offset
);
110 delta
= csum_sub(csum_fold(csum
), *psum
);
111 *psum
= csum_fold(csum
);
113 /* Adjust skb->csum since we changed the packet */
114 skb
->csum
= csum_add(skb
->csum
, delta
);
/* Handler for GUE control messages (guehdr->control set).
 * NOTE(review): the function body is missing from this corrupted view;
 * upstream this is a stub returning an error until control messages
 * are implemented — restore before building.
 */
119 static int gue_control_message(struct sk_buff
*skb
, struct guehdr
*guehdr
)
/* UDP encap_rcv handler for GUE: validate the GUE header, walk its
 * private-flag options (including remote checksum offload), then strip
 * the encapsulation and resubmit as the inner protocol.
 * NOTE(review): corrupted extraction — declarations of data/doffset,
 * the 'len += optlen' step before the second pull, the goto-drop error
 * paths and closing braces are missing from this view.
 */
126 static int gue_udp_recv(struct sock
*sk
, struct sk_buff
*skb
)
128 struct fou
*fou
= fou_from_sock(sk
);
129 size_t len
, optlen
, hdrlen
;
130 struct guehdr
*guehdr
;
/* First pull: UDP header plus the fixed GUE header. */
137 len
= sizeof(struct udphdr
) + sizeof(struct guehdr
);
138 if (!pskb_may_pull(skb
, len
))
141 guehdr
= (struct guehdr
*)&udp_hdr(skb
)[1];
/* hlen is in 32-bit words; second pull must cover the options too. */
143 optlen
= guehdr
->hlen
<< 2;
146 if (!pskb_may_pull(skb
, len
))
149 /* guehdr may change after pull */
150 guehdr
= (struct guehdr
*)&udp_hdr(skb
)[1];
152 hdrlen
= sizeof(struct guehdr
) + optlen
;
/* Only GUE version 0 with valid flag/option lengths is accepted. */
154 if (guehdr
->version
!= 0 || validate_gue_flags(guehdr
, optlen
))
157 hdrlen
= sizeof(struct guehdr
) + optlen
;
159 ip_hdr(skb
)->tot_len
= htons(ntohs(ip_hdr(skb
)->tot_len
) - len
);
161 /* Pull UDP header now, skb->data points to guehdr */
162 __skb_pull(skb
, sizeof(struct udphdr
));
164 /* Pull csum through the guehdr now . This can be used if
165 * there is a remote checksum offload.
167 skb_postpull_rcsum(skb
, udp_hdr(skb
), len
);
/* Walk the private-flags block when present; each recognized pflag
 * advances doffset past its payload.
 */
171 if (guehdr
->flags
& GUE_FLAG_PRIV
) {
172 __be32 flags
= *(__be32
*)(data
+ doffset
);
174 doffset
+= GUE_LEN_PRIV
;
176 if (flags
& GUE_PFLAG_REMCSUM
) {
177 guehdr
= gue_remcsum(skb
, guehdr
, data
+ doffset
,
178 hdrlen
, guehdr
->proto_ctype
);
184 doffset
+= GUE_PLEN_REMCSUM
;
188 if (unlikely(guehdr
->control
))
189 return gue_control_message(skb
, guehdr
);
191 __skb_pull(skb
, hdrlen
);
192 skb_reset_transport_header(skb
);
/* Negative return asks the UDP encap layer to resubmit the packet
 * with the inner protocol number.
 */
194 return -guehdr
->proto_ctype
;
/* GRO receive for plain FOU: dispatch straight to the inner protocol's
 * gro_receive callback, selected from the IPv4/IPv6 offload table by
 * the protocol number stashed in the GRO control block.
 * NOTE(review): corrupted extraction — the 'struct sk_buff *skb'
 * parameter line, the rcu_read_lock()/unlock pair around the
 * dereference, the out label and the return statement are missing.
 */
201 static struct sk_buff
**fou_gro_receive(struct sk_buff
**head
,
204 const struct net_offload
*ops
;
205 struct sk_buff
**pp
= NULL
;
206 u8 proto
= NAPI_GRO_CB(skb
)->proto
;
207 const struct net_offload
**offloads
;
210 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
211 ops
= rcu_dereference(offloads
[proto
]);
/* No offload registered for the inner protocol: bail out (error path
 * missing from this view).
 */
212 if (!ops
|| !ops
->callbacks
.gro_receive
)
215 pp
= ops
->callbacks
.gro_receive(head
, skb
);
/* GRO complete for plain FOU: forward to the inner protocol's
 * gro_complete at the same nhoff (FOU direct mode adds no header).
 * NOTE(review): corrupted extraction — 'int err' declaration, RCU
 * lock/unlock, out_unlock label and 'return err' are missing.
 */
223 static int fou_gro_complete(struct sk_buff
*skb
, int nhoff
)
225 const struct net_offload
*ops
;
226 u8 proto
= NAPI_GRO_CB(skb
)->proto
;
228 const struct net_offload
**offloads
;
231 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
232 ops
= rcu_dereference(offloads
[proto
]);
/* WARN: a merged flow whose protocol has no gro_complete indicates a
 * bug elsewhere, hence the WARN_ON rather than a silent drop.
 */
233 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_complete
))
236 err
= ops
->callbacks
.gro_complete(skb
, nhoff
);
/* GUE remote-checksum-offload handling on the GRO path: derive the
 * inner checksum from the GRO-validated csum, patch it into the packet,
 * and mark skb->remcsum_offload so the normal receive path skips it.
 * NOTE(review): corrupted extraction — declarations of pd, plen, ptr,
 * poffset, csum, psum, delta, the NULL-guehdr check after the slow
 * header pull, error returns and the trailing 'return guehdr' are
 * missing from this view.
 */
244 static struct guehdr
*gue_gro_remcsum(struct sk_buff
*skb
, unsigned int off
,
245 struct guehdr
*guehdr
, void *data
,
246 size_t hdrlen
, u8 ipproto
)
249 u16 start
= ntohs(pd
[0]);
250 u16 offset
= ntohs(pd
[1]);
/* Already handled for this skb — nothing to do. */
257 if (skb
->remcsum_offload
)
/* Validate option offsets against the GRO window and require a valid
 * precomputed checksum (error path missing from this view).
 */
260 if (start
> skb_gro_len(skb
) - hdrlen
||
261 offset
> skb_gro_len(skb
) - hdrlen
- sizeof(u16
) ||
262 !NAPI_GRO_CB(skb
)->csum_valid
|| skb
->remcsum_offload
)
265 plen
= hdrlen
+ offset
+ sizeof(u16
);
267 /* Pull checksum that will be written */
268 if (skb_gro_header_hard(skb
, off
+ plen
)) {
269 guehdr
= skb_gro_header_slow(skb
, off
+ plen
, off
);
274 ptr
= (void *)guehdr
+ hdrlen
;
276 if (ipproto
== IPPROTO_IP
&&
277 (hdrlen
+ sizeof(struct iphdr
) < plen
)) {
278 struct iphdr
*ip
= (struct iphdr
*)(ptr
+ hdrlen
);
280 /* If next header happens to be IP we can skip
281 * that for the checksum calculation since the
282 * IP header checksum is zero if correct.
284 poffset
= ip
->ihl
* 4;
287 csum
= csum_sub(NAPI_GRO_CB(skb
)->csum
,
288 csum_partial(ptr
+ poffset
, start
- poffset
, 0));
290 /* Set derived checksum in packet */
291 psum
= (__sum16
*)(ptr
+ offset
);
292 delta
= csum_sub(csum_fold(csum
), *psum
);
293 *psum
= csum_fold(csum
);
295 /* Adjust skb->csum since we changed the packet */
296 skb
->csum
= csum_add(skb
->csum
, delta
);
/* Keep the GRO checksum coherent with the modified payload too. */
297 NAPI_GRO_CB(skb
)->csum
= csum_add(NAPI_GRO_CB(skb
)->csum
, delta
);
/* Flag so the non-GRO path (gue_remcsum) does not redo the work. */
299 skb
->remcsum_offload
= 1;
/* GRO receive for GUE: pull and validate the GUE header, process the
 * remcsum option, match this packet against held flows by comparing
 * GUE headers, then hand off to the inner protocol's gro_receive.
 * NOTE(review): corrupted extraction — the 'struct sk_buff *skb'
 * parameter, declarations of p/flush/data/doffset, the 'len += optlen'
 * step, NULL checks after slow pulls, the flush/out labels, RCU
 * lock/unlock and 'return pp' are all missing from this view.
 */
304 static struct sk_buff
**gue_gro_receive(struct sk_buff
**head
,
307 const struct net_offload
**offloads
;
308 const struct net_offload
*ops
;
309 struct sk_buff
**pp
= NULL
;
311 struct guehdr
*guehdr
;
312 size_t len
, optlen
, hdrlen
, off
;
317 off
= skb_gro_offset(skb
);
318 len
= off
+ sizeof(*guehdr
);
/* Fast path first; fall back to the slow linearizing pull only when
 * the fixed header is not yet in the GRO header area.
 */
320 guehdr
= skb_gro_header_fast(skb
, off
);
321 if (skb_gro_header_hard(skb
, len
)) {
322 guehdr
= skb_gro_header_slow(skb
, len
, off
);
323 if (unlikely(!guehdr
))
327 optlen
= guehdr
->hlen
<< 2;
330 if (skb_gro_header_hard(skb
, len
)) {
331 guehdr
= skb_gro_header_slow(skb
, len
, off
);
332 if (unlikely(!guehdr
))
/* Control messages and unknown versions are never merged. */
336 if (unlikely(guehdr
->control
) || guehdr
->version
!= 0 ||
337 validate_gue_flags(guehdr
, optlen
))
340 hdrlen
= sizeof(*guehdr
) + optlen
;
342 /* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
343 * this is needed if there is a remote checkcsum offload.
345 skb_gro_postpull_rcsum(skb
, guehdr
, hdrlen
);
349 if (guehdr
->flags
& GUE_FLAG_PRIV
) {
350 __be32 flags
= *(__be32
*)(data
+ doffset
);
352 doffset
+= GUE_LEN_PRIV
;
354 if (flags
& GUE_PFLAG_REMCSUM
) {
355 guehdr
= gue_gro_remcsum(skb
, off
, guehdr
,
356 data
+ doffset
, hdrlen
,
357 guehdr
->proto_ctype
);
363 doffset
+= GUE_PLEN_REMCSUM
;
367 skb_gro_pull(skb
, hdrlen
);
/* Compare against every held packet: flows merge only when the whole
 * GUE header (base word and options) is identical.
 */
371 for (p
= *head
; p
; p
= p
->next
) {
372 const struct guehdr
*guehdr2
;
374 if (!NAPI_GRO_CB(p
)->same_flow
)
377 guehdr2
= (struct guehdr
*)(p
->data
+ off
);
379 /* Compare base GUE header to be equal (covers
380 * hlen, version, proto_ctype, and flags.
382 if (guehdr
->word
!= guehdr2
->word
) {
383 NAPI_GRO_CB(p
)->same_flow
= 0;
387 /* Compare optional fields are the same. */
388 if (guehdr
->hlen
&& memcmp(&guehdr
[1], &guehdr2
[1],
389 guehdr
->hlen
<< 2)) {
390 NAPI_GRO_CB(p
)->same_flow
= 0;
396 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
397 ops
= rcu_dereference(offloads
[guehdr
->proto_ctype
]);
398 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_receive
))
401 pp
= ops
->callbacks
.gro_receive(head
, skb
);
406 NAPI_GRO_CB(skb
)->flush
|= flush
;
/* GRO complete for GUE: re-read the GUE header at nhoff, then call the
 * inner protocol's gro_complete just past the GUE header (nhoff +
 * guehlen).
 * NOTE(review): corrupted extraction — declarations of proto and err,
 * RCU lock/unlock, the out_unlock label and 'return err' are missing.
 */
411 static int gue_gro_complete(struct sk_buff
*skb
, int nhoff
)
413 const struct net_offload
**offloads
;
414 struct guehdr
*guehdr
= (struct guehdr
*)(skb
->data
+ nhoff
);
415 const struct net_offload
*ops
;
416 unsigned int guehlen
;
420 proto
= guehdr
->proto_ctype
;
/* Total GUE header size: fixed header plus hlen 32-bit words. */
422 guehlen
= sizeof(*guehdr
) + (guehdr
->hlen
<< 2);
425 offloads
= NAPI_GRO_CB(skb
)->is_ipv6
? inet6_offloads
: inet_offloads
;
426 ops
= rcu_dereference(offloads
[proto
]);
427 if (WARN_ON(!ops
|| !ops
->callbacks
.gro_complete
))
430 err
= ops
->callbacks
.gro_complete(skb
, nhoff
+ guehlen
);
/* Register a new FOU instance on the global port list, rejecting
 * duplicate local ports. Caller-visible contract: returns an error
 * when the port is already bound (error return missing from this
 * corrupted view), 0 otherwise.
 * NOTE(review): the 'struct fou *fout' iterator declaration and both
 * return statements are missing from this view.
 */
437 static int fou_add_to_port_list(struct fou
*fou
)
441 spin_lock(&fou_lock
);
442 list_for_each_entry(fout
, &fou_list
, list
) {
443 if (fou
->port
== fout
->port
) {
/* Duplicate port: unlock and fail (return missing here). */
444 spin_unlock(&fou_lock
);
449 list_add(&fou
->list
, &fou_list
);
450 spin_unlock(&fou_lock
);
/* Tear down one FOU instance: unregister its GRO offload, unlink it
 * from fou_list, and detach it from the tunnel socket.
 * NOTE(review): corrupted extraction — the socket release and
 * kfree(fou) that must follow are missing from this view; also note
 * list_del here is done without fou_lock, which matches upstream only
 * because callers hold the lock or run at teardown — confirm.
 */
455 static void fou_release(struct fou
*fou
)
457 struct socket
*sock
= fou
->sock
;
458 struct sock
*sk
= sock
->sk
;
460 udp_del_offload(&fou
->udp_offloads
);
462 list_del(&fou
->list
);
464 /* Remove hooks into tunnel socket */
465 sk
->sk_user_data
= NULL
;
/* Configure a socket for plain FOU encapsulation: install the direct
 * receive handler and populate the GRO offload descriptor from the
 * user-supplied config.
 * NOTE(review): the 'return 0;' and braces are missing from this
 * corrupted view.
 */
472 static int fou_encap_init(struct sock
*sk
, struct fou
*fou
, struct fou_cfg
*cfg
)
474 udp_sk(sk
)->encap_rcv
= fou_udp_recv
;
475 fou
->protocol
= cfg
->protocol
;
476 fou
->udp_offloads
.callbacks
.gro_receive
= fou_gro_receive
;
477 fou
->udp_offloads
.callbacks
.gro_complete
= fou_gro_complete
;
478 fou
->udp_offloads
.port
= cfg
->udp_config
.local_udp_port
;
479 fou
->udp_offloads
.ipproto
= cfg
->protocol
;
/* Configure a socket for GUE encapsulation: install the GUE receive
 * handler and GRO callbacks. Unlike fou_encap_init, no fixed inner
 * protocol is recorded — GUE carries it per-packet in proto_ctype.
 * NOTE(review): the 'return 0;' and braces are missing from this
 * corrupted view.
 */
484 static int gue_encap_init(struct sock
*sk
, struct fou
*fou
, struct fou_cfg
*cfg
)
486 udp_sk(sk
)->encap_rcv
= gue_udp_recv
;
487 fou
->udp_offloads
.callbacks
.gro_receive
= gue_gro_receive
;
488 fou
->udp_offloads
.callbacks
.gro_complete
= gue_gro_complete
;
489 fou
->udp_offloads
.port
= cfg
->udp_config
.local_udp_port
;
/* Create a FOU listener: open the UDP socket, allocate and initialize
 * the struct fou for the chosen encapsulation type, hook the socket's
 * encap path, register GRO offloads (IPv4 only) and publish the port.
 * NOTE(review): corrupted extraction — declarations of err/sk, all
 * error-handling gotos and the 'error:' cleanup label, the switch
 * statement head, the default case, and 'return 0' are missing; the
 * fragment 'case FOU_ENCAP_DIRECT:' implies a switch on cfg->type.
 */
494 static int fou_create(struct net
*net
, struct fou_cfg
*cfg
,
495 struct socket
**sockp
)
497 struct fou
*fou
= NULL
;
499 struct socket
*sock
= NULL
;
502 /* Open UDP socket */
503 err
= udp_sock_create(net
, &cfg
->udp_config
, &sock
);
507 /* Allocate FOU port structure */
508 fou
= kzalloc(sizeof(*fou
), GFP_KERNEL
);
516 fou
->port
= cfg
->udp_config
.local_udp_port
;
518 /* Initial for fou type */
520 case FOU_ENCAP_DIRECT
:
521 err
= fou_encap_init(sk
, fou
, cfg
);
526 err
= gue_encap_init(sk
, fou
, cfg
);
/* Mark the socket as an encapsulation socket so UDP delivers to
 * encap_rcv instead of the normal receive path.
 */
535 udp_sk(sk
)->encap_type
= 1;
538 sk
->sk_user_data
= fou
;
/* Convert CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE on receive so
 * remote checksum offload can derive inner checksums.
 */
541 udp_set_convert_csum(sk
, true);
543 sk
->sk_allocation
= GFP_ATOMIC
;
545 if (cfg
->udp_config
.family
== AF_INET
) {
546 err
= udp_add_offload(&fou
->udp_offloads
);
551 err
= fou_add_to_port_list(fou
);
/* Find the FOU instance bound to the configured port and tear it down.
 * NOTE(review): corrupted extraction — declarations of fou/err, the
 * fou_release() call and break inside the match, and the final return
 * are missing from this view.
 */
568 static int fou_destroy(struct net
*net
, struct fou_cfg
*cfg
)
571 u16 port
= cfg
->udp_config
.local_udp_port
;
574 spin_lock(&fou_lock
);
575 list_for_each_entry(fou
, &fou_list
, list
) {
576 if (fou
->port
== port
) {
577 udp_del_offload(&fou
->udp_offloads
);
583 spin_unlock(&fou_lock
);
/* Generic-netlink family for configuring FOU ports from userspace;
 * GENL_ID_GENERATE lets the core pick a free family id.
 * NOTE(review): the '.hdrsize' member and closing '};' are missing
 * from this corrupted view.
 */
588 static struct genl_family fou_nl_family
= {
589 .id
= GENL_ID_GENERATE
,
591 .name
= FOU_GENL_NAME
,
592 .version
= FOU_GENL_VERSION
,
593 .maxattr
= FOU_ATTR_MAX
,
/* Attribute validation policy for the family's commands.
 * NOTE(review): closing '};' missing from this view.
 */
597 static struct nla_policy fou_nl_policy
[FOU_ATTR_MAX
+ 1] = {
598 [FOU_ATTR_PORT
] = { .type
= NLA_U16
, },
599 [FOU_ATTR_AF
] = { .type
= NLA_U8
, },
600 [FOU_ATTR_IPPROTO
] = { .type
= NLA_U8
, },
601 [FOU_ATTR_TYPE
] = { .type
= NLA_U8
, },
/* Translate netlink attributes into a struct fou_cfg. Defaults the
 * family to AF_INET; every attribute is optional.
 * NOTE(review): corrupted extraction — the 'struct fou_cfg *cfg'
 * parameter line, the -EINVAL/-EAFNOSUPPORT error returns and the
 * final 'return 0' are missing from this view.
 */
604 static int parse_nl_config(struct genl_info
*info
,
607 memset(cfg
, 0, sizeof(*cfg
));
609 cfg
->udp_config
.family
= AF_INET
;
611 if (info
->attrs
[FOU_ATTR_AF
]) {
612 u8 family
= nla_get_u8(info
->attrs
[FOU_ATTR_AF
]);
/* Only IPv4/IPv6 are meaningful address families here (error return
 * missing from this view).
 */
614 if (family
!= AF_INET
&& family
!= AF_INET6
)
617 cfg
->udp_config
.family
= family
;
620 if (info
->attrs
[FOU_ATTR_PORT
]) {
621 u16 port
= nla_get_u16(info
->attrs
[FOU_ATTR_PORT
]);
623 cfg
->udp_config
.local_udp_port
= port
;
626 if (info
->attrs
[FOU_ATTR_IPPROTO
])
627 cfg
->protocol
= nla_get_u8(info
->attrs
[FOU_ATTR_IPPROTO
]);
629 if (info
->attrs
[FOU_ATTR_TYPE
])
630 cfg
->type
= nla_get_u8(info
->attrs
[FOU_ATTR_TYPE
]);
/* FOU_CMD_ADD handler: parse the netlink config and create a listener
 * in init_net.
 * NOTE(review): the 'struct fou_cfg cfg' declaration, the 'if (err)
 * return err' check and braces are missing from this corrupted view.
 */
635 static int fou_nl_cmd_add_port(struct sk_buff
*skb
, struct genl_info
*info
)
640 err
= parse_nl_config(info
, &cfg
);
644 return fou_create(&init_net
, &cfg
, NULL
);
/* FOU_CMD_DEL handler: parse the netlink config (return value ignored
 * since only the port is needed) and destroy the matching listener.
 * NOTE(review): the 'struct fou_cfg cfg' declaration and braces are
 * missing from this corrupted view.
 */
647 static int fou_nl_cmd_rm_port(struct sk_buff
*skb
, struct genl_info
*info
)
651 parse_nl_config(info
, &cfg
);
653 return fou_destroy(&init_net
, &cfg
);
/* Generic-netlink command table: both add and remove require
 * CAP_NET_ADMIN (GENL_ADMIN_PERM).
 * NOTE(review): the '.cmd = FOU_CMD_ADD/DEL' members, the per-entry
 * braces and the closing '};' are missing from this corrupted view.
 */
656 static const struct genl_ops fou_nl_ops
[] = {
659 .doit
= fou_nl_cmd_add_port
,
660 .policy
= fou_nl_policy
,
661 .flags
= GENL_ADMIN_PERM
,
665 .doit
= fou_nl_cmd_rm_port
,
666 .policy
= fou_nl_policy
,
667 .flags
= GENL_ADMIN_PERM
,
/* Push and fill the outer UDP header on a transmit skb, set its
 * checksum per the tunnel flags, and report IPPROTO_UDP as the outer
 * protocol to the caller.
 * NOTE(review): corrupted extraction — the 'struct udphdr *uh'
 * declaration and the uh->dest/uh->source assignments are missing
 * from this view.
 */
671 static void fou_build_udp(struct sk_buff
*skb
, struct ip_tunnel_encap
*e
,
672 struct flowi4
*fl4
, u8
*protocol
, __be16 sport
)
676 skb_push(skb
, sizeof(struct udphdr
));
677 skb_reset_transport_header(skb
);
683 uh
->len
= htons(skb
->len
);
/* nocheck is true when the CSUM flag is NOT set — i.e. checksum only
 * when the user asked for it.
 */
685 udp_set_csum(!(e
->flags
& TUNNEL_ENCAP_FLAG_CSUM
), skb
,
686 fl4
->saddr
, fl4
->daddr
, skb
->len
);
688 *protocol
= IPPROTO_UDP
;
/* Transmit-side entry point for plain FOU: prepare GSO/checksum
 * offload state, pick a source port (configured, or hashed from the
 * flow), then prepend the outer UDP header.
 * NOTE(review): corrupted extraction — the '__be16 sport' declaration,
 * the IS_ERR(skb) check after iptunnel_handle_offloads, the second
 * udp_flow_src_port argument lines and 'return 0' are missing.
 */
691 int fou_build_header(struct sk_buff
*skb
, struct ip_tunnel_encap
*e
,
692 u8
*protocol
, struct flowi4
*fl4
)
694 bool csum
= !!(e
->flags
& TUNNEL_ENCAP_FLAG_CSUM
);
695 int type
= csum
? SKB_GSO_UDP_TUNNEL_CSUM
: SKB_GSO_UDP_TUNNEL
;
698 skb
= iptunnel_handle_offloads(skb
, csum
, type
);
/* 'e->sport ?:' — use the configured source port if nonzero, else a
 * flow-hash-derived one for RSS/ECMP spreading.
 */
703 sport
= e
->sport
? : udp_flow_src_port(dev_net(skb
->dev
),
705 fou_build_udp(skb
, e
, fl4
, protocol
, sport
);
709 EXPORT_SYMBOL(fou_build_header
);
/* Transmit-side entry point for GUE: size and build the GUE header
 * (including the private-flags block and the remote-checksum-offload
 * option when requested and applicable), then prepend the outer UDP
 * header.
 * NOTE(review): corrupted extraction — declarations of sport/data/pd,
 * the 'need_priv = true' assignment in the remcsum branch, the
 * 'if (need_priv) {' opener, the csum_start alignment check's error
 * return, several closing braces and 'return 0' are missing.
 */
711 int gue_build_header(struct sk_buff
*skb
, struct ip_tunnel_encap
*e
,
712 u8
*protocol
, struct flowi4
*fl4
)
714 bool csum
= !!(e
->flags
& TUNNEL_ENCAP_FLAG_CSUM
);
715 int type
= csum
? SKB_GSO_UDP_TUNNEL_CSUM
: SKB_GSO_UDP_TUNNEL
;
716 struct guehdr
*guehdr
;
717 size_t hdrlen
, optlen
= 0;
720 bool need_priv
= false;
/* Remote checksum offload only applies when the inner packet still
 * carries a partial checksum for the stack to defer.
 */
722 if ((e
->flags
& TUNNEL_ENCAP_FLAG_REMCSUM
) &&
723 skb
->ip_summed
== CHECKSUM_PARTIAL
) {
725 optlen
+= GUE_PLEN_REMCSUM
;
726 type
|= SKB_GSO_TUNNEL_REMCSUM
;
730 optlen
+= need_priv
? GUE_LEN_PRIV
: 0;
732 skb
= iptunnel_handle_offloads(skb
, csum
, type
);
737 /* Get source port (based on flow hash) before skb_push */
738 sport
= e
->sport
? : udp_flow_src_port(dev_net(skb
->dev
),
741 hdrlen
= sizeof(struct guehdr
) + optlen
;
743 skb_push(skb
, hdrlen
);
745 guehdr
= (struct guehdr
*)skb
->data
;
/* hlen is expressed in 32-bit words of options. */
749 guehdr
->hlen
= optlen
>> 2;
751 guehdr
->proto_ctype
= *protocol
;
756 __be32
*flags
= data
;
758 guehdr
->flags
|= GUE_FLAG_PRIV
;
760 data
+= GUE_LEN_PRIV
;
762 if (type
& SKB_GSO_TUNNEL_REMCSUM
) {
763 u16 csum_start
= skb_checksum_start_offset(skb
);
/* The checksum start must lie beyond the headers we just pushed
 * (error return missing from this view).
 */
766 if (csum_start
< hdrlen
)
769 csum_start
-= hdrlen
;
770 pd
[0] = htons(csum_start
);
771 pd
[1] = htons(csum_start
+ skb
->csum_offset
);
/* Non-GSO: the deferred checksum is now carried in the GUE option,
 * so drop the partial-checksum/encapsulation state.
 */
773 if (!skb_is_gso(skb
)) {
774 skb
->ip_summed
= CHECKSUM_NONE
;
775 skb
->encapsulation
= 0;
778 *flags
|= GUE_PFLAG_REMCSUM
;
779 data
+= GUE_PLEN_REMCSUM
;
784 fou_build_udp(skb
, e
, fl4
, protocol
, sport
);
788 EXPORT_SYMBOL(gue_build_header
);
/* Module init: register the FOU generic-netlink family and commands.
 * NOTE(review): corrupted extraction — the 'int ret' declaration, the
 * fou_nl_ops argument and 'return ret' are missing from this view.
 */
790 static int __init
fou_init(void)
794 ret
= genl_register_family_with_ops(&fou_nl_family
,
/* Module exit: unregister the netlink family, then release every
 * remaining FOU listener under fou_lock (safe iteration, since
 * release unlinks entries).
 * NOTE(review): the fou_release(fou) loop body is missing from this
 * corrupted view.
 */
800 static void __exit
fou_fini(void)
802 struct fou
*fou
, *next
;
804 genl_unregister_family(&fou_nl_family
);
806 /* Close all the FOU sockets */
808 spin_lock(&fou_lock
);
809 list_for_each_entry_safe(fou
, next
, &fou_list
, list
)
811 spin_unlock(&fou_lock
);
/* Standard kernel module boilerplate: entry/exit hooks and metadata. */
814 module_init(fou_init
);
815 module_exit(fou_fini
);
816 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
817 MODULE_LICENSE("GPL");