Switch LocalSocket to android::base::{Send,Receive}FileDescriptorVector.

The previous implementation allocated an array of size
CMSG_SPACE(count) to store CMSG_LEN(count * sizeof(int)) elements, which
leads to bad things happening for values of count greater than 1 on
32-bit, and 2 on 64-bit.

Test: atest android.net.LocalSocketTest
Test: atest android.net.cts.LocalSocketTest
Change-Id: I0a9502c3358d8fa92d2d20e344c6270d6baedc07
This commit is contained in:
Josh Gao
2019-02-11 14:37:21 -08:00
parent e24b30b7d4
commit 79e3be8a84

View File

@@ -33,14 +33,16 @@
#include <unistd.h>
#include <sys/ioctl.h>
#include <android-base/cmsg.h>
#include <android-base/macros.h>
#include <cutils/sockets.h>
#include <netinet/tcp.h>
#include <nativehelper/ScopedUtfChars.h>
namespace android {
using android::base::ReceiveFileDescriptorVector;
using android::base::SendFileDescriptorVector;
template <typename T>
void UNUSED(T t) {}
namespace android {
static jfieldID field_inboundFileDescriptors;
static jfieldID field_outboundFileDescriptors;
@@ -117,67 +119,6 @@ socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
}
}
/**
* Processes ancillary data, handling only
* SCM_RIGHTS. Creates appropriate objects and sets appropriate
* fields in the LocalSocketImpl object. Returns 0 on success
* or -1 if an exception was thrown.
*/
static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
{
struct cmsghdr *cmsgptr;
for (cmsgptr = CMSG_FIRSTHDR(pMsg);
cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
if (cmsgptr->cmsg_level != SOL_SOCKET) {
continue;
}
if (cmsgptr->cmsg_type == SCM_RIGHTS) {
int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
jobjectArray fdArray;
int count
= ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
if (count < 0) {
jniThrowException(env, "java/io/IOException",
"invalid cmsg length");
return -1;
}
fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
if (fdArray == NULL) {
return -1;
}
for (int i = 0; i < count; i++) {
jobject fdObject
= jniCreateFileDescriptor(env, pDescriptors[i]);
if (env->ExceptionCheck()) {
return -1;
}
env->SetObjectArrayElement(fdArray, i, fdObject);
if (env->ExceptionCheck()) {
return -1;
}
}
env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
if (env->ExceptionCheck()) {
return -1;
}
}
}
return 0;
}
/**
* Reads data from a socket into buf, processing any ancillary data
* and adding it to thisJ.
@@ -189,47 +130,48 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
void *buffer, size_t len)
{
ssize_t ret;
struct msghdr msg;
struct iovec iv;
unsigned char *buf = (unsigned char *)buffer;
// Enough buffer for a pile of fd's. We throw an exception if
// this buffer is too small.
struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
std::vector<android::base::unique_fd> received_fds;
memset(&msg, 0, sizeof(msg));
memset(&iv, 0, sizeof(iv));
iv.iov_base = buf;
iv.iov_len = len;
msg.msg_iov = &iv;
msg.msg_iovlen = 1;
msg.msg_control = cmsgbuf;
msg.msg_controllen = sizeof(cmsgbuf);
ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
if (ret < 0 && errno == EPIPE) {
// Treat this as an end of stream
return 0;
}
ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds);
if (ret < 0) {
if (errno == EPIPE) {
// Treat this as an end of stream
return 0;
}
jniThrowIOException(env, errno);
return -1;
}
if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
// To us, any of the above flags are a fatal error
if (received_fds.size() > 0) {
jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL);
jniThrowException(env, "java/io/IOException",
"Unexpected error or truncation during recvmsg()");
if (fdArray == NULL) {
// NewObjectArray has thrown.
return -1;
}
return -1;
}
for (size_t i = 0; i < received_fds.size(); i++) {
jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get());
if (ret >= 0) {
socket_process_cmsg(env, thisJ, &msg);
if (env->ExceptionCheck()) {
return -1;
}
env->SetObjectArrayElement(fdArray, i, fdObject);
if (env->ExceptionCheck()) {
return -1;
}
}
for (auto &fd : received_fds) {
// The fds are stored in java.io.FileDescriptors now.
static_cast<void>(fd.release());
}
env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
}
return ret;
@@ -243,7 +185,6 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
static int socket_write_all(JNIEnv *env, jobject object, int fd,
void *buf, size_t len)
{
ssize_t ret;
struct msghdr msg;
unsigned char *buffer = (unsigned char *)buf;
memset(&msg, 0, sizeof(msg));
@@ -256,14 +197,11 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
return -1;
}
struct cmsghdr *cmsg;
int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
int fds[countFds];
char msgbuf[CMSG_SPACE(countFds)];
std::vector<int> fds;
// Add any pending outbound file descriptors to the message
if (outboundFds != NULL) {
if (env->ExceptionCheck()) {
return -1;
}
@@ -274,47 +212,25 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
return -1;
}
fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
fds.push_back(jniGetFDFromFileDescriptor(env, fdObject));
if (env->ExceptionCheck()) {
return -1;
}
}
// See "man cmsg" really
msg.msg_control = msgbuf;
msg.msg_controllen = sizeof msgbuf;
cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(sizeof fds);
memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
}
// We only write our msg_control during the first write
while (len > 0) {
struct iovec iv;
memset(&iv, 0, sizeof(iv));
ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds);
iv.iov_base = buffer;
iv.iov_len = len;
msg.msg_iov = &iv;
msg.msg_iovlen = 1;
do {
ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
} while (ret < 0 && errno == EINTR);
if (ret < 0) {
while (rc != len) {
if (rc == -1) {
jniThrowIOException(env, errno);
return -1;
}
buffer += ret;
len -= ret;
buffer += rc;
len -= rc;
// Wipes out any msg_control too
memset(&msg, 0, sizeof(msg));
rc = send(fd, buffer, len, MSG_NOSIGNAL);
}
return 0;