Implement abstract unix domain sockets

This commit is contained in:
Theodore Dubois 2019-07-13 21:51:03 -07:00
parent 1fb6a0ab53
commit 24d2eaa34a
2 changed files with 128 additions and 34 deletions

View File

@ -43,8 +43,11 @@ struct fd {
int domain;
int type;
int protocol;
// These are only used as strong references, to keep the inode
// alive while there is a listener.
struct inode_data *unix_name_inode;
struct inode_data *unix_peer_inode;
struct unix_abstract *unix_name_abstract;
// TODO add a field for unix socket name
} socket;
};
// fs data

157
fs/sock.c
View File

@ -69,17 +69,26 @@ static struct fd *sock_getfd(fd_t sock_fd) {
return sock;
}
static struct inode_data *unix_socket_get(const char *path_raw, int flag) {
// Hands out a fresh ID for a unix socket name. The ID ends up in the path of
// the backing socket on the host, so it has to be unique per name.
// Never returns 0: a socket_id of 0 is used to mean "no ID assigned yet".
static uint32_t unix_socket_next_id(void) {
    static uint32_t next_id = 0;
    static lock_t next_id_lock = LOCK_INITIALIZER;
    lock(&next_id_lock);
    // If the counter ever wraps around, skip over the reserved value 0.
    if (++next_id == 0)
        next_id = 1;
    uint32_t id = next_id;
    unlock(&next_id_lock);
    return id;
}
static int unix_socket_get(const char *path_raw, struct fd *bind_fd, uint32_t *socket_id) {
char path[MAX_PATH];
int err = path_normalize(AT_PWD, path_raw, path, true);
if (err < 0)
return ERR_PTR(err);
return err;
struct mount *mount = find_mount_and_trim_path(path);
struct statbuf stat;
err = mount->fs->stat(mount, path, &stat, true);
// If bind was called, there are some funny semantics.
if (flag & O_CREAT_) {
if (bind_fd != NULL) {
// If the file exists, fail.
if (err == 0) {
err = _EADDRINUSE;
@ -103,7 +112,7 @@ static struct inode_data *unix_socket_get(const char *path_raw, int flag) {
// If something other than bind was called, just do the obvious thing and
// fail if stat failed.
if (!(flag & O_CREAT_) && err < 0)
if (bind_fd == NULL && err < 0)
goto out;
if (!S_ISSOCK(stat.mode)) {
@ -114,26 +123,94 @@ static struct inode_data *unix_socket_get(const char *path_raw, int flag) {
// Look up the socket ID for the inode number.
struct inode_data *inode = inode_get(mount, stat.inode);
lock(&inode->lock);
if (inode->socket_id == 0) {
static uint32_t next_socket_id = 0;
static lock_t next_socket_id_lock = LOCK_INITIALIZER;
lock(&next_socket_id_lock);
inode->socket_id = ++next_socket_id;
unlock(&next_socket_id_lock);
}
if (inode->socket_id == 0)
inode->socket_id = unix_socket_next_id();
unlock(&inode->lock);
*socket_id = inode->socket_id;
mount_release(mount);
return inode;
if (bind_fd != NULL)
bind_fd->socket.unix_name_inode = inode;
else
inode_release(inode);
return 0;
out:
mount_release(mount);
return ERR_PTR(err);
return err;
}
// Dan Bernstein's simple and decently effective hash function
static uint32_t str_hash(const char *str) {
uint32_t hash = 5381;
for (int i = 0; str[i] != '\0'; i++) {
hash = 33 * hash ^ str[i];
}
return hash;
}
// The abstract socket namespace is a lot simpler than it sounds: if the first
// byte of the path is a null byte, then it gets looked up in this hashtable
// instead of the filesystem.
struct unix_abstract {
unsigned refcount;
uint32_t hash;
uint32_t socket_id;
struct list links;
};
#define ABSTRACT_HASH_SIZE 1024
static struct list abstract_hash[ABSTRACT_HASH_SIZE];
static lock_t unix_abstract_lock = LOCK_INITIALIZER;
static int unix_abstract_get(const char *name, struct fd *bind_fd, uint32_t *socket_id) {
uint32_t hash = str_hash(name);
lock(&unix_abstract_lock);
struct unix_abstract *sock_tmp;
struct unix_abstract *sock = NULL;
struct list *bucket = &abstract_hash[hash % ABSTRACT_HASH_SIZE];
if (list_null(bucket))
list_init(bucket);
list_for_each_entry(bucket, sock_tmp, links) {
if (sock_tmp->hash == hash) {
sock = sock_tmp;
break;
}
}
if (bind_fd != NULL && sock != NULL)
return _EEXIST;
if (bind_fd == NULL && sock == NULL)
return _ENOENT;
if (sock == NULL) {
sock = malloc(sizeof(struct unix_abstract));
sock->refcount = 0;
sock->hash = hash;
sock->socket_id = unix_socket_next_id();
list_add(bucket, &sock->links);
}
sock->refcount++;
unlock(&unix_abstract_lock);
*socket_id = sock->socket_id;
if (bind_fd != NULL)
bind_fd->socket.unix_name_abstract = sock;
return 0;
}
// Drop one reference to an abstract socket name. When the count reaches zero
// the entry is unlinked from its list and freed, so the caller must not touch
// `name` again after this returns.
static void unix_abstract_release(struct unix_abstract *name) {
    lock(&unix_abstract_lock);
    name->refcount -= 1;
    unsigned remaining = name->refcount;
    if (remaining == 0) {
        // That was the last reference
        list_remove(&name->links);
        free(name);
    }
    unlock(&unix_abstract_lock);
}
const char *sock_tmp_prefix = "/tmp/ishsock";
static int sockaddr_read_get_inode(addr_t sockaddr_addr, void *sockaddr, uint_t *sockaddr_len, struct inode_data **inode_out, int flag) {
static int sockaddr_read_bind(addr_t sockaddr_addr, void *sockaddr, uint_t *sockaddr_len, struct fd *bind_fd) {
// Make sure we can read things without overflowing buffers
if (*sockaddr_len < sizeof(socklen_t))
return _EINVAL;
@ -159,22 +236,31 @@ static int sockaddr_read_get_inode(addr_t sockaddr_addr, void *sockaddr, uint_t
case PF_LOCAL: {
// First pull out the path, being careful to not overflow anything.
char path[sizeof(struct sockaddr_max_) - offsetof(struct sockaddr_max_, data) + 1]; // big enough
size_t addr_path_size = *sockaddr_len - offsetof(struct sockaddr_, data);
memcpy(path, fake_addr->data, addr_path_size);
path[addr_path_size] = '\0';
size_t path_size = *sockaddr_len - offsetof(struct sockaddr_, data);
memcpy(path, fake_addr->data, path_size);
path[path_size] = '\0';
struct inode_data *inode = unix_socket_get(path, flag);
if (IS_ERR(inode))
return PTR_ERR(inode);
*inode_out = inode;
uint32_t socket_id;
int err;
if (path_size == 0) {
return _ENOENT;
} else if (path[0] != '\0') {
STRACE(" unix socket %s", path);
err = unix_socket_get(path, bind_fd, &socket_id);
} else {
STRACE(" unix abstract socket %s", path + 1);
err = unix_abstract_get(path + 1, bind_fd, &socket_id);
}
if (err < 0)
return err;
struct sockaddr_un *real_addr_un = sockaddr;
size_t path_len = sprintf(real_addr_un->sun_path, "%s%d.%d", sock_tmp_prefix, getpid(), inode->socket_id);
size_t path_len = sprintf(real_addr_un->sun_path, "%s%d.%u", sock_tmp_prefix, getpid(), socket_id);
// The call to real bind will fail if the backing socket already
// exists from a previous run or something. We already checked that
// the fake file doesn't exist in unix_socket_get, so try a simple
// solution.
if (flag & O_CREAT_)
if (bind_fd != NULL)
unlink(real_addr_un->sun_path);
*sockaddr_len = offsetof(struct sockaddr_un, sun_path) + path_len;
break;
@ -187,7 +273,7 @@ static int sockaddr_read_get_inode(addr_t sockaddr_addr, void *sockaddr, uint_t
static int sockaddr_read(addr_t sockaddr_addr, void *sockaddr, uint_t *sockaddr_len) {
struct inode_data *inode = NULL;
int err = sockaddr_read_get_inode(sockaddr_addr, sockaddr, sockaddr_len, &inode, 0);
int err = sockaddr_read_bind(sockaddr_addr, sockaddr, sockaddr_len, NULL);
inode_release_if_exist(inode);
return err;
}
@ -232,13 +318,15 @@ int_t sys_bind(fd_t sock_fd, addr_t sockaddr_addr, uint_t sockaddr_len) {
return _EBADF;
struct sockaddr_max_ sockaddr;
struct inode_data *inode = NULL;
int err = sockaddr_read_get_inode(sockaddr_addr, &sockaddr, &sockaddr_len, &inode, O_CREAT_);
int err = sockaddr_read_bind(sockaddr_addr, &sockaddr, &sockaddr_len, sock);
if (err < 0)
return err;
err = bind(sock->real_fd, (void *) &sockaddr, sockaddr_len);
if (err < 0) {
inode_release_if_exist(inode);
inode_release_if_exist(sock->socket.unix_name_inode);
if (sock->socket.unix_name_abstract != NULL)
unix_abstract_release(sock->socket.unix_name_abstract);
return errno_map();
}
sock->socket.unix_name_inode = inode;
@ -251,17 +339,13 @@ int_t sys_connect(fd_t sock_fd, addr_t sockaddr_addr, uint_t sockaddr_len) {
if (sock == NULL)
return _EBADF;
struct sockaddr_max_ sockaddr;
struct inode_data *inode = NULL;
int err = sockaddr_read_get_inode(sockaddr_addr, &sockaddr, &sockaddr_len, &inode, 0);
int err = sockaddr_read(sockaddr_addr, &sockaddr, &sockaddr_len);
if (err < 0)
return err;
err = connect(sock->real_fd, (void *) &sockaddr, sockaddr_len);
if (err < 0) {
inode_release_if_exist(inode);
if (err < 0)
return errno_map();
}
sock->socket.unix_peer_inode = inode;
return err;
}
@ -325,6 +409,8 @@ int_t sys_getsockname(fd_t sock_fd, addr_t sockaddr_addr, addr_t sockaddr_len_ad
if (user_get(sockaddr_len_addr, sockaddr_len))
return _EFAULT;
// TODO if this is a unix socket, return the same string passed to bind
char sockaddr[sockaddr_len];
int res = getsockname(sock->real_fd, (void *) sockaddr, &sockaddr_len);
if (res < 0)
@ -347,6 +433,9 @@ int_t sys_getpeername(fd_t sock_fd, addr_t sockaddr_addr, addr_t sockaddr_len_ad
if (user_get(sockaddr_len_addr, sockaddr_len))
return _EFAULT;
// TODO if this is a unix socket, return the same string the peer passed to
// bind once the peer pointer is available
char sockaddr[sockaddr_len];
int res = getpeername(sock->real_fd, (void *) sockaddr, &sockaddr_len);
if (res < 0)
@ -774,8 +863,10 @@ static ssize_t sock_write(struct fd *fd, const void *buf, size_t size) {
static int sock_close(struct fd *fd) {
sockrestart_end_listen(fd);
// FIXME next 3 lines should go in a function like release_unix_names
inode_release_if_exist(fd->socket.unix_name_inode);
inode_release_if_exist(fd->socket.unix_peer_inode);
if (fd->socket.unix_name_abstract != NULL)
unix_abstract_release(fd->socket.unix_name_abstract);
return realfs_close(fd);
}