diff --git a/net/kcm/kcmsock.c b/net/kcm/kcmsock.c
index f938d7d3e6e2a2d851922db48cfabedd442665ad..40662d73204f73356a76a108b6e82ce6e5454c86 100644
--- a/net/kcm/kcmsock.c
+++ b/net/kcm/kcmsock.c
@@ -55,6 +55,8 @@ static void kcm_abort_rx_psock(struct kcm_psock *psock, int err,
 
        /* Unrecoverable error in receive */
 
+       del_timer(&psock->rx_msg_timer);
+
        if (psock->rx_stopped)
                return;
 
@@ -351,6 +353,12 @@ static void unreserve_rx_kcm(struct kcm_psock *psock,
        spin_unlock_bh(&mux->rx_lock);
 }
 
+static void kcm_start_rx_timer(struct kcm_psock *psock)
+{
+       if (psock->sk->sk_rcvtimeo)
+               mod_timer(&psock->rx_msg_timer, psock->sk->sk_rcvtimeo);
+}
+
 /* Macro to invoke filter function. */
 #define KCM_RUN_FILTER(prog, ctx) \
        (*prog->bpf_func)(ctx, prog->insnsi)
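
For context: the timer armed by kcm_start_rx_timer() above is gated on sk_rcvtimeo of the attached TCP socket, so the receive-message timeout is opt-in per lower socket. A hypothetical userspace sketch (error handling trimmed; names are illustrative) of enabling it before handing the socket to a KCM mux:

/* Hypothetical sketch: enable the KCM receive-message timeout by
 * setting SO_RCVTIMEO on the TCP socket before attaching it with
 * SIOCKCMATTACH; leaving sk_rcvtimeo at 0 keeps the timer off.
 */
#include <sys/time.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <linux/kcm.h>

static int attach_with_timeout(int kcm_fd, int tcp_fd, int bpf_prog_fd)
{
	struct timeval tv = { .tv_sec = 5 };	/* 5s assembly limit */
	struct kcm_attach attach = {
		.fd = tcp_fd,
		.bpf_fd = bpf_prog_fd,
	};

	if (setsockopt(tcp_fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
		return -1;

	return ioctl(kcm_fd, SIOCKCMATTACH, &attach);
}
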
@@ -375,6 +383,19 @@ static int kcm_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
        if (head) {
                /* Message already in progress */
 
+               rxm = kcm_rx_msg(head);
+               if (unlikely(rxm->early_eaten)) {
+                       /* Some bytes arriving on the receive socket have
+                        * already been saved in rx_skb_head; just indicate
+                        * that they are consumed.
+                        */
+                       eaten = orig_len <= rxm->early_eaten ?
+                               orig_len : rxm->early_eaten;
+                       rxm->early_eaten -= eaten;
+
+                       return eaten;
+               }
+
                if (unlikely(orig_offset)) {
                        /* Getting data with a non-zero offset when a message is
                         * in progress is not expected. If it does happen, we
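
For context on the "eaten" accounting (not part of this diff): kcm_tcp_recv() runs as the actor callback of tcp_read_sock(), which invokes it for each skb on the receive queue and takes the return value as the number of bytes consumed; a later hunk also clears desc->count to stop the read loop. The existing driver helper in this file looks approximately like this:

/* Shown for reference, approximately: drives kcm_tcp_recv() over
 * the TCP receive queue via tcp_read_sock().
 */
static int psock_tcp_read_sock(struct kcm_psock *psock)
{
	read_descriptor_t desc;

	desc.arg.data = psock;
	desc.error = 0;
	desc.count = 1; /* give more than one skb per read */

	/* sk should be locked here, so okay to do tcp_read_sock */
	tcp_read_sock(psock->sk, &desc, kcm_tcp_recv);

	return desc.error;
}
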
@@ -487,11 +508,22 @@ static int kcm_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
 
                        if (!len) {
                                /* Need more header to determine length */
+                               if (!rxm->accum_len) {
+                                       /* Start RX timer for new message */
+                                       kcm_start_rx_timer(psock);
+                               }
                                rxm->accum_len += cand_len;
                                eaten += cand_len;
                                KCM_STATS_INCR(psock->stats.rx_need_more_hdr);
                                WARN_ON(eaten != orig_len);
                                break;
+                       } else if (len > psock->sk->sk_rcvbuf) {
+                               /* Message length exceeds maximum allowed */
+                               KCM_STATS_INCR(psock->stats.rx_msg_too_big);
+                               desc->error = -EMSGSIZE;
+                               psock->rx_skb_head = NULL;
+                               kcm_abort_rx_psock(psock, EMSGSIZE, head);
+                               break;
                        } else if (len <= (ssize_t)head->len -
                                          skb->len - rxm->offset) {
                                /* Length must be into new skb (and also
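
The new rx_msg_too_big check above bounds the parsed message length by sk_rcvbuf of the attached socket, so userspace effectively caps the maximum KCM message size by sizing the lower socket's receive buffer. A hypothetical sketch:

/* Hypothetical sketch: cap parsed KCM messages at roughly 1 MB by
 * bounding the attached TCP socket's receive buffer (note the
 * kernel stores twice the requested SO_RCVBUF value).
 */
#include <sys/socket.h>

static int cap_kcm_msg_size(int tcp_fd)
{
	int rcvbuf = 1 << 20;

	return setsockopt(tcp_fd, SOL_SOCKET, SO_RCVBUF,
			  &rcvbuf, sizeof(rcvbuf));
}
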
@@ -511,6 +543,28 @@ static int kcm_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
 
                if (extra < 0) {
                        /* Message not complete yet. */
+                       if (rxm->full_len - rxm->accum_len >
+                           tcp_inq(psock->sk)) {
+                               /* Don't have the whole message in the socket
+                                * buffer. Set psock->rx_need_bytes to wait for
+                                * the rest of the message. Also, set "early
+                                * eaten" since we've buffered the skb but not
+                                * yet reported it consumed to tcp_read_sock.
+                                */
+
+                               if (!rxm->accum_len) {
+                                       /* Start RX timer for new message */
+                                       kcm_start_rx_timer(psock);
+                               }
+
+                               psock->rx_need_bytes = rxm->full_len -
+                                                      rxm->accum_len;
+                               rxm->accum_len += cand_len;
+                               rxm->early_eaten = cand_len;
+                               KCM_STATS_ADD(psock->stats.rx_bytes, cand_len);
+                               desc->count = 0; /* Stop reading socket */
+                               break;
+                       }
                        rxm->accum_len += cand_len;
                        eaten += cand_len;
                        WARN_ON(eaten != orig_len);
@@ -526,6 +580,7 @@ static int kcm_tcp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
                eaten += (cand_len - extra);
 
                /* Hurray, we have a new message! */
+               del_timer(&psock->rx_msg_timer);
                psock->rx_skb_head = NULL;
                KCM_STATS_INCR(psock->stats.rx_msgs);
 
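
For reference, the per-message receive state (rxm) updated throughout this function lives in the control buffer of the head skb. Approximate definitions (early_eaten being the field this series adds; exact layout per include/net/kcm.h):

/* Approximate definitions, shown for orientation only */
struct kcm_rx_msg {
	int full_len;	 /* total message length from the BPF parser */
	int accum_len;	 /* bytes accumulated so far */
	int offset;	 /* offset of payload in the head skb */
	int early_eaten; /* bytes buffered but not yet reported eaten */
};

static inline struct kcm_rx_msg *kcm_rx_msg(struct sk_buff *skb)
{
	return (struct kcm_rx_msg *)((void *)skb->cb);
}
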
@@ -582,6 +637,13 @@ static void psock_tcp_data_ready(struct sock *sk)
        if (psock->ready_rx_msg)
                goto out;
 
+       if (psock->rx_need_bytes) {
+               if (tcp_inq(sk) >= psock->rx_need_bytes)
+                       psock->rx_need_bytes = 0;
+               else
+                       goto out;
+       }
+
        if (psock_tcp_read_sock(psock) == -ENOMEM)
                queue_delayed_work(kcm_wq, &psock->rx_delayed_work, 0);
 
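
tcp_inq() used above reports bytes that have arrived on the TCP socket but have not yet been read by the consumer. Conceptually (urgent-data and FIN corner cases elided) it computes:

/* Rough sketch of tcp_inq() from <net/tcp.h>, corner cases elided */
static inline int tcp_inq_sketch(struct sock *sk)
{
	const struct tcp_sock *tp = tcp_sk(sk);

	return tp->rcv_nxt - tp->copied_seq;
}
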
@@ -990,6 +1052,149 @@ static void kcm_push(struct kcm_sock *kcm)
                kcm_write_msgs(kcm);
 }
 
+static ssize_t kcm_sendpage(struct socket *sock, struct page *page,
+                           int offset, size_t size, int flags)
+{
+       struct sock *sk = sock->sk;
+       struct kcm_sock *kcm = kcm_sk(sk);
+       struct sk_buff *skb = NULL, *head = NULL;
+       long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+       bool eor;
+       int err = 0;
+       int i;
+
+       if (flags & MSG_SENDPAGE_NOTLAST)
+               flags |= MSG_MORE;
+
+       /* No MSG_EOR from splice, only look at MSG_MORE */
+       eor = !(flags & MSG_MORE);
+
+       lock_sock(sk);
+
+       sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
+
+       err = -EPIPE;
+       if (sk->sk_err)
+               goto out_error;
+
+       if (kcm->seq_skb) {
+               /* Previously opened message */
+               head = kcm->seq_skb;
+               skb = kcm_tx_msg(head)->last_skb;
+               i = skb_shinfo(skb)->nr_frags;
+
+               if (skb_can_coalesce(skb, i, page, offset)) {
+                       skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size);
+                       skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+                       goto coalesced;
+               }
+
+               if (i >= MAX_SKB_FRAGS) {
+                       struct sk_buff *tskb;
+
+                       tskb = alloc_skb(0, sk->sk_allocation);
+                       while (!tskb) {
+                               kcm_push(kcm);
+                               err = sk_stream_wait_memory(sk, &timeo);
+                               if (err)
+                                       goto out_error;
+
+                               tskb = alloc_skb(0, sk->sk_allocation);
+                       }
+
+                       if (head == skb)
+                               skb_shinfo(head)->frag_list = tskb;
+                       else
+                               skb->next = tskb;
+
+                       skb = tskb;
+                       skb->ip_summed = CHECKSUM_UNNECESSARY;
+                       i = 0;
+               }
+       } else {
+               /* Call the sk_stream functions to manage the sndbuf mem. */
+               if (!sk_stream_memory_free(sk)) {
+                       kcm_push(kcm);
+                       set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+                       err = sk_stream_wait_memory(sk, &timeo);
+                       if (err)
+                               goto out_error;
+               }
+
+               head = alloc_skb(0, sk->sk_allocation);
+               while (!head) {
+                       kcm_push(kcm);
+                       err = sk_stream_wait_memory(sk, &timeo);
+                       if (err)
+                               goto out_error;
+
+                       head = alloc_skb(0, sk->sk_allocation);
+               }
+
+               skb = head;
+               i = 0;
+       }
+
+       get_page(page);
+       skb_fill_page_desc(skb, i, page, offset, size);
+       skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+
+coalesced:
+       skb->len += size;
+       skb->data_len += size;
+       skb->truesize += size;
+       sk->sk_wmem_queued += size;
+       sk_mem_charge(sk, size);
+
+       if (head != skb) {
+               head->len += size;
+               head->data_len += size;
+               head->truesize += size;
+       }
+
+       if (eor) {
+               bool not_busy = skb_queue_empty(&sk->sk_write_queue);
+
+               /* Message complete, queue it on send buffer */
+               __skb_queue_tail(&sk->sk_write_queue, head);
+               kcm->seq_skb = NULL;
+               KCM_STATS_INCR(kcm->stats.tx_msgs);
+
+               if (flags & MSG_BATCH) {
+                       kcm->tx_wait_more = true;
+               } else if (kcm->tx_wait_more || not_busy) {
+                       err = kcm_write_msgs(kcm);
+                       if (err < 0) {
+                               /* We got a hard error in write_msgs but have
+                                * already queued this message. Report an error
+                                * in the socket, but don't affect return value
+                                * from sendpage
+                                */
+                               pr_warn("KCM: Hard failure on kcm_write_msgs\n");
+                               report_csk_error(&kcm->sk, -err);
+                       }
+               }
+       } else {
+               /* Message not complete, save state */
+               kcm->seq_skb = head;
+               kcm_tx_msg(head)->last_skb = skb;
+       }
+
+       KCM_STATS_ADD(kcm->stats.tx_bytes, size);
+
+       release_sock(sk);
+       return size;
+
+out_error:
+       kcm_push(kcm);
+
+       err = sk_stream_error(sk, flags, err);
+
+       /* make sure we wake any epoll edge trigger waiter */
+       if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
+               sk->sk_write_space(sk);
+
+       release_sock(sk);
+       return err;
+}
+
 static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 {
        struct sock *sk = sock->sk;
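
A hypothetical userspace view of the new path: sendfile() on a KCM socket reaches kcm_sendpage() through the socket's ->sendpage op. Intermediate chunks carry MSG_SENDPAGE_NOTLAST (folded into MSG_MORE above), so a single call transmits the file as one KCM message without copying through userspace:

/* Hypothetical sketch: send a whole file as one KCM message */
#include <sys/sendfile.h>
#include <sys/stat.h>

static ssize_t send_file_as_msg(int kcm_fd, int file_fd)
{
	struct stat st;

	if (fstat(file_fd, &st))
		return -1;

	return sendfile(kcm_fd, file_fd, NULL, st.st_size);
}
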
@@ -1256,6 +1461,76 @@ out:
        return copied ? : err;
 }
 
+static ssize_t kcm_sock_splice(struct sock *sk,
+                              struct pipe_inode_info *pipe,
+                              struct splice_pipe_desc *spd)
+{
+       int ret;
+
+       release_sock(sk);
+       ret = splice_to_pipe(pipe, spd);
+       lock_sock(sk);
+
+       return ret;
+}
+
+static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
+                              struct pipe_inode_info *pipe, size_t len,
+                              unsigned int flags)
+{
+       struct sock *sk = sock->sk;
+       struct kcm_sock *kcm = kcm_sk(sk);
+       long timeo;
+       struct kcm_rx_msg *rxm;
+       int err = 0;
+       ssize_t copied;
+       struct sk_buff *skb;
+
+       /* Only support splice for SOCK_SEQPACKET */
+
+       timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
+       lock_sock(sk);
+
+       skb = kcm_wait_data(sk, flags, timeo, &err);
+       if (!skb)
+               goto err_out;
+
+       /* Okay, have a message on the receive queue */
+
+       rxm = kcm_rx_msg(skb);
+
+       if (len > rxm->full_len)
+               len = rxm->full_len;
+
+       copied = skb_splice_bits(skb, sk, rxm->offset, pipe, len, flags,
+                                kcm_sock_splice);
+       if (copied < 0) {
+               err = copied;
+               goto err_out;
+       }
+
+       KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
+
+       rxm->offset += copied;
+       rxm->full_len -= copied;
+
+       /* We have no way to return MSG_EOR. If all the bytes have been
+        * read we still leave the message in the receive socket buffer.
+        * A subsequent recvmsg needs to be done to return MSG_EOR and
+        * finish reading the message.
+        */
+
+       release_sock(sk);
+
+       return copied;
+
+err_out:
+       release_sock(sk);
+
+       return err;
+}
+
 /* kcm sock lock held */
 static void kcm_recv_disable(struct kcm_sock *kcm)
 {
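
The receive-side counterpart from userspace might look like the following hypothetical sketch; as the comment in kcm_splice_read() notes, a trailing recvmsg() is still needed to observe MSG_EOR and retire the message:

/* Hypothetical sketch: splice message bytes from a KCM socket into
 * a pipe; splice() lands in kcm_splice_read() via ->splice_read.
 */
#define _GNU_SOURCE
#include <fcntl.h>
#include <unistd.h>

static ssize_t read_msg_bytes(int kcm_fd, int pipe_wr_fd, size_t len)
{
	return splice(kcm_fd, NULL, pipe_wr_fd, NULL, len, SPLICE_F_MOVE);
}
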
@@ -1399,6 +1674,15 @@ static void init_kcm_sock(struct kcm_sock *kcm, struct kcm_mux *mux)
        spin_unlock_bh(&mux->rx_lock);
 }
 
+static void kcm_rx_msg_timeout(unsigned long arg)
+{
+       struct kcm_psock *psock = (struct kcm_psock *)arg;
+
+       /* Message assembly timed out */
+       KCM_STATS_INCR(psock->stats.rx_msg_timeouts);
+       kcm_abort_rx_psock(psock, ETIMEDOUT, NULL);
+}
+
 static int kcm_attach(struct socket *sock, struct socket *csock,
                      struct bpf_prog *prog)
 {
@@ -1428,6 +1712,10 @@ static int kcm_attach(struct socket *sock, struct socket *csock,
        psock->mux = mux;
        psock->sk = csk;
        psock->bpf_prog = prog;
+
+       setup_timer(&psock->rx_msg_timer, kcm_rx_msg_timeout,
+                   (unsigned long)psock);
+
        INIT_WORK(&psock->rx_work, psock_rx_work);
        INIT_DELAYED_WORK(&psock->rx_delayed_work, psock_rx_delayed_work);
 
@@ -1539,6 +1827,7 @@ static void kcm_unattach(struct kcm_psock *psock)
 
        write_unlock_bh(&csk->sk_callback_lock);
 
+       del_timer_sync(&psock->rx_msg_timer);
        cancel_work_sync(&psock->rx_work);
        cancel_delayed_work_sync(&psock->rx_delayed_work);
 
@@ -1907,7 +2196,7 @@ static int kcm_release(struct socket *sock)
        return 0;
 }
 
-static const struct proto_ops kcm_ops = {
+static const struct proto_ops kcm_dgram_ops = {
        .family =       PF_KCM,
        .owner =        THIS_MODULE,
        .release =      kcm_release,
@@ -1925,7 +2214,29 @@ static const struct proto_ops kcm_ops = {
        .sendmsg =      kcm_sendmsg,
        .recvmsg =      kcm_recvmsg,
        .mmap =         sock_no_mmap,
-       .sendpage =     sock_no_sendpage,
+       .sendpage =     kcm_sendpage,
+};
+
+static const struct proto_ops kcm_seqpacket_ops = {
+       .family =       PF_KCM,
+       .owner =        THIS_MODULE,
+       .release =      kcm_release,
+       .bind =         sock_no_bind,
+       .connect =      sock_no_connect,
+       .socketpair =   sock_no_socketpair,
+       .accept =       sock_no_accept,
+       .getname =      sock_no_getname,
+       .poll =         datagram_poll,
+       .ioctl =        kcm_ioctl,
+       .listen =       sock_no_listen,
+       .shutdown =     sock_no_shutdown,
+       .setsockopt =   kcm_setsockopt,
+       .getsockopt =   kcm_getsockopt,
+       .sendmsg =      kcm_sendmsg,
+       .recvmsg =      kcm_recvmsg,
+       .mmap =         sock_no_mmap,
+       .sendpage =     kcm_sendpage,
+       .splice_read =  kcm_splice_read,
 };
 
 /* Create proto operation for kcm sockets */
@@ -1938,8 +2249,10 @@ static int kcm_create(struct net *net, struct socket *sock,
 
        switch (sock->type) {
        case SOCK_DGRAM:
+               sock->ops = &kcm_dgram_ops;
+               break;
        case SOCK_SEQPACKET:
-               sock->ops = &kcm_ops;
+               sock->ops = &kcm_seqpacket_ops;
                break;
        default:
                return -ESOCKTNOSUPPORT;
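
The effect of the ops split is visible at socket creation time; a hypothetical userspace sketch (only the SEQPACKET flavor gains splice_read):

/* Hypothetical sketch: both KCM socket types, per the switch above */
#include <sys/socket.h>
#include <linux/kcm.h>

static int make_kcm_sockets(int *dgram_fd, int *seqpacket_fd)
{
	*dgram_fd = socket(AF_KCM, SOCK_DGRAM, KCMPROTO_CONNECTED);
	*seqpacket_fd = socket(AF_KCM, SOCK_SEQPACKET, KCMPROTO_CONNECTED);

	return (*dgram_fd < 0 || *seqpacket_fd < 0) ? -1 : 0;
}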