diff --git a/core/jni/android_net_LocalSocketImpl.cpp b/core/jni/android_net_LocalSocketImpl.cpp index a1f2377041e89..1163b860977d8 100644 --- a/core/jni/android_net_LocalSocketImpl.cpp +++ b/core/jni/android_net_LocalSocketImpl.cpp @@ -33,14 +33,16 @@ #include #include +#include +#include #include #include #include -namespace android { +using android::base::ReceiveFileDescriptorVector; +using android::base::SendFileDescriptorVector; -template -void UNUSED(T t) {} +namespace android { static jfieldID field_inboundFileDescriptors; static jfieldID field_outboundFileDescriptors; @@ -117,67 +119,6 @@ socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor, } } -/** - * Processes ancillary data, handling only - * SCM_RIGHTS. Creates appropriate objects and sets appropriate - * fields in the LocalSocketImpl object. Returns 0 on success - * or -1 if an exception was thrown. - */ -static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg) -{ - struct cmsghdr *cmsgptr; - - for (cmsgptr = CMSG_FIRSTHDR(pMsg); - cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) { - - if (cmsgptr->cmsg_level != SOL_SOCKET) { - continue; - } - - if (cmsgptr->cmsg_type == SCM_RIGHTS) { - int *pDescriptors = (int *)CMSG_DATA(cmsgptr); - jobjectArray fdArray; - int count - = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int)); - - if (count < 0) { - jniThrowException(env, "java/io/IOException", - "invalid cmsg length"); - return -1; - } - - fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL); - - if (fdArray == NULL) { - return -1; - } - - for (int i = 0; i < count; i++) { - jobject fdObject - = jniCreateFileDescriptor(env, pDescriptors[i]); - - if (env->ExceptionCheck()) { - return -1; - } - - env->SetObjectArrayElement(fdArray, i, fdObject); - - if (env->ExceptionCheck()) { - return -1; - } - } - - env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray); - - if (env->ExceptionCheck()) { - return -1; - } - } - } - - return 0; -} - /** * Reads data from a socket into buf, processing any ancillary data * and adding it to thisJ. @@ -189,47 +130,48 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd, void *buffer, size_t len) { ssize_t ret; - struct msghdr msg; - struct iovec iv; - unsigned char *buf = (unsigned char *)buffer; - // Enough buffer for a pile of fd's. We throw an exception if - // this buffer is too small. - struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100]; + std::vector received_fds; - memset(&msg, 0, sizeof(msg)); - memset(&iv, 0, sizeof(iv)); - - iv.iov_base = buf; - iv.iov_len = len; - - msg.msg_iov = &iv; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC)); - - if (ret < 0 && errno == EPIPE) { - // Treat this as an end of stream - return 0; - } + ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds); if (ret < 0) { + if (errno == EPIPE) { + // Treat this as an end of stream + return 0; + } + jniThrowIOException(env, errno); return -1; } - if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) { - // To us, any of the above flags are a fatal error + if (received_fds.size() > 0) { + jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL); - jniThrowException(env, "java/io/IOException", - "Unexpected error or truncation during recvmsg()"); + if (fdArray == NULL) { + // NewObjectArray has thrown. + return -1; + } - return -1; - } + for (size_t i = 0; i < received_fds.size(); i++) { + jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get()); - if (ret >= 0) { - socket_process_cmsg(env, thisJ, &msg); + if (env->ExceptionCheck()) { + return -1; + } + + env->SetObjectArrayElement(fdArray, i, fdObject); + + if (env->ExceptionCheck()) { + return -1; + } + } + + for (auto &fd : received_fds) { + // The fds are stored in java.io.FileDescriptors now. + static_cast(fd.release()); + } + + env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray); } return ret; @@ -243,7 +185,6 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd, static int socket_write_all(JNIEnv *env, jobject object, int fd, void *buf, size_t len) { - ssize_t ret; struct msghdr msg; unsigned char *buffer = (unsigned char *)buf; memset(&msg, 0, sizeof(msg)); @@ -256,14 +197,11 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd, return -1; } - struct cmsghdr *cmsg; int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds); - int fds[countFds]; - char msgbuf[CMSG_SPACE(countFds)]; + std::vector fds; // Add any pending outbound file descriptors to the message if (outboundFds != NULL) { - if (env->ExceptionCheck()) { return -1; } @@ -274,47 +212,25 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd, return -1; } - fds[i] = jniGetFDFromFileDescriptor(env, fdObject); + fds.push_back(jniGetFDFromFileDescriptor(env, fdObject)); if (env->ExceptionCheck()) { return -1; } } - - // See "man cmsg" really - msg.msg_control = msgbuf; - msg.msg_controllen = sizeof msgbuf; - cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof fds); - memcpy(CMSG_DATA(cmsg), fds, sizeof fds); } - // We only write our msg_control during the first write - while (len > 0) { - struct iovec iv; - memset(&iv, 0, sizeof(iv)); + ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds); - iv.iov_base = buffer; - iv.iov_len = len; - - msg.msg_iov = &iv; - msg.msg_iovlen = 1; - - do { - ret = sendmsg(fd, &msg, MSG_NOSIGNAL); - } while (ret < 0 && errno == EINTR); - - if (ret < 0) { + while (rc != len) { + if (rc == -1) { jniThrowIOException(env, errno); return -1; } - buffer += ret; - len -= ret; + buffer += rc; + len -= rc; - // Wipes out any msg_control too - memset(&msg, 0, sizeof(msg)); + rc = send(fd, buffer, len, MSG_NOSIGNAL); } return 0;