Switch LocalSocket to android::base::{Send,Receive}FileDescriptorVector.
The previous implementation allocated an array of size CMSG_SPACE(count) to store CMSG_LEN(count * sizeof(int)) elements, which leads to bad things happening for values of count greater than 1 on 32-bit, and 2 on 64-bit. Test: atest android.net.LocalSocketTest Test: atest android.net.cts.LocalSocketTest Change-Id: I0a9502c3358d8fa92d2d20e344c6270d6baedc07
This commit is contained in:
@@ -33,14 +33,16 @@
|
||||
#include <unistd.h>
|
||||
#include <sys/ioctl.h>
|
||||
|
||||
#include <android-base/cmsg.h>
|
||||
#include <android-base/macros.h>
|
||||
#include <cutils/sockets.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <nativehelper/ScopedUtfChars.h>
|
||||
|
||||
namespace android {
|
||||
using android::base::ReceiveFileDescriptorVector;
|
||||
using android::base::SendFileDescriptorVector;
|
||||
|
||||
template <typename T>
|
||||
void UNUSED(T t) {}
|
||||
namespace android {
|
||||
|
||||
static jfieldID field_inboundFileDescriptors;
|
||||
static jfieldID field_outboundFileDescriptors;
|
||||
@@ -117,67 +119,6 @@ socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes ancillary data, handling only
|
||||
* SCM_RIGHTS. Creates appropriate objects and sets appropriate
|
||||
* fields in the LocalSocketImpl object. Returns 0 on success
|
||||
* or -1 if an exception was thrown.
|
||||
*/
|
||||
static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
|
||||
{
|
||||
struct cmsghdr *cmsgptr;
|
||||
|
||||
for (cmsgptr = CMSG_FIRSTHDR(pMsg);
|
||||
cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
|
||||
|
||||
if (cmsgptr->cmsg_level != SOL_SOCKET) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cmsgptr->cmsg_type == SCM_RIGHTS) {
|
||||
int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
|
||||
jobjectArray fdArray;
|
||||
int count
|
||||
= ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
|
||||
|
||||
if (count < 0) {
|
||||
jniThrowException(env, "java/io/IOException",
|
||||
"invalid cmsg length");
|
||||
return -1;
|
||||
}
|
||||
|
||||
fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
|
||||
|
||||
if (fdArray == NULL) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < count; i++) {
|
||||
jobject fdObject
|
||||
= jniCreateFileDescriptor(env, pDescriptors[i]);
|
||||
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
env->SetObjectArrayElement(fdArray, i, fdObject);
|
||||
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
|
||||
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads data from a socket into buf, processing any ancillary data
|
||||
* and adding it to thisJ.
|
||||
@@ -189,47 +130,48 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
|
||||
void *buffer, size_t len)
|
||||
{
|
||||
ssize_t ret;
|
||||
struct msghdr msg;
|
||||
struct iovec iv;
|
||||
unsigned char *buf = (unsigned char *)buffer;
|
||||
// Enough buffer for a pile of fd's. We throw an exception if
|
||||
// this buffer is too small.
|
||||
struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
|
||||
std::vector<android::base::unique_fd> received_fds;
|
||||
|
||||
memset(&msg, 0, sizeof(msg));
|
||||
memset(&iv, 0, sizeof(iv));
|
||||
|
||||
iv.iov_base = buf;
|
||||
iv.iov_len = len;
|
||||
|
||||
msg.msg_iov = &iv;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = cmsgbuf;
|
||||
msg.msg_controllen = sizeof(cmsgbuf);
|
||||
|
||||
ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
|
||||
|
||||
if (ret < 0 && errno == EPIPE) {
|
||||
// Treat this as an end of stream
|
||||
return 0;
|
||||
}
|
||||
ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds);
|
||||
|
||||
if (ret < 0) {
|
||||
if (errno == EPIPE) {
|
||||
// Treat this as an end of stream
|
||||
return 0;
|
||||
}
|
||||
|
||||
jniThrowIOException(env, errno);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
|
||||
// To us, any of the above flags are a fatal error
|
||||
if (received_fds.size() > 0) {
|
||||
jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL);
|
||||
|
||||
jniThrowException(env, "java/io/IOException",
|
||||
"Unexpected error or truncation during recvmsg()");
|
||||
if (fdArray == NULL) {
|
||||
// NewObjectArray has thrown.
|
||||
return -1;
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
for (size_t i = 0; i < received_fds.size(); i++) {
|
||||
jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get());
|
||||
|
||||
if (ret >= 0) {
|
||||
socket_process_cmsg(env, thisJ, &msg);
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
env->SetObjectArrayElement(fdArray, i, fdObject);
|
||||
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &fd : received_fds) {
|
||||
// The fds are stored in java.io.FileDescriptors now.
|
||||
static_cast<void>(fd.release());
|
||||
}
|
||||
|
||||
env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
|
||||
}
|
||||
|
||||
return ret;
|
||||
@@ -243,7 +185,6 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
|
||||
static int socket_write_all(JNIEnv *env, jobject object, int fd,
|
||||
void *buf, size_t len)
|
||||
{
|
||||
ssize_t ret;
|
||||
struct msghdr msg;
|
||||
unsigned char *buffer = (unsigned char *)buf;
|
||||
memset(&msg, 0, sizeof(msg));
|
||||
@@ -256,14 +197,11 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct cmsghdr *cmsg;
|
||||
int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
|
||||
int fds[countFds];
|
||||
char msgbuf[CMSG_SPACE(countFds)];
|
||||
std::vector<int> fds;
|
||||
|
||||
// Add any pending outbound file descriptors to the message
|
||||
if (outboundFds != NULL) {
|
||||
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
@@ -274,47 +212,25 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
|
||||
return -1;
|
||||
}
|
||||
|
||||
fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
|
||||
fds.push_back(jniGetFDFromFileDescriptor(env, fdObject));
|
||||
if (env->ExceptionCheck()) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// See "man cmsg" really
|
||||
msg.msg_control = msgbuf;
|
||||
msg.msg_controllen = sizeof msgbuf;
|
||||
cmsg = CMSG_FIRSTHDR(&msg);
|
||||
cmsg->cmsg_level = SOL_SOCKET;
|
||||
cmsg->cmsg_type = SCM_RIGHTS;
|
||||
cmsg->cmsg_len = CMSG_LEN(sizeof fds);
|
||||
memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
|
||||
}
|
||||
|
||||
// We only write our msg_control during the first write
|
||||
while (len > 0) {
|
||||
struct iovec iv;
|
||||
memset(&iv, 0, sizeof(iv));
|
||||
ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds);
|
||||
|
||||
iv.iov_base = buffer;
|
||||
iv.iov_len = len;
|
||||
|
||||
msg.msg_iov = &iv;
|
||||
msg.msg_iovlen = 1;
|
||||
|
||||
do {
|
||||
ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
|
||||
} while (ret < 0 && errno == EINTR);
|
||||
|
||||
if (ret < 0) {
|
||||
while (rc != len) {
|
||||
if (rc == -1) {
|
||||
jniThrowIOException(env, errno);
|
||||
return -1;
|
||||
}
|
||||
|
||||
buffer += ret;
|
||||
len -= ret;
|
||||
buffer += rc;
|
||||
len -= rc;
|
||||
|
||||
// Wipes out any msg_control too
|
||||
memset(&msg, 0, sizeof(msg));
|
||||
rc = send(fd, buffer, len, MSG_NOSIGNAL);
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user