From 24d2eaa34a3aaece4afa4846bb08781ecb98fc20 Mon Sep 17 00:00:00 2001
From: Theodore Dubois
Date: Sat, 13 Jul 2019 21:51:03 -0700
Subject: [PATCH] Implement abstract unix domain sockets

---
 fs/fd.h   |   5 +-
 fs/sock.c | 188 +++++++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 154 insertions(+), 39 deletions(-)

diff --git a/fs/fd.h b/fs/fd.h
index 672f5112..253b50b9 100644
--- a/fs/fd.h
+++ b/fs/fd.h
@@ -43,8 +43,11 @@ struct fd {
             int domain;
             int type;
             int protocol;
+            // These are only used as strong references, to keep the inode
+            // alive while there is a listener.
             struct inode_data *unix_name_inode;
-            struct inode_data *unix_peer_inode;
+            struct unix_abstract *unix_name_abstract;
+            // TODO add a field for unix socket name
         } socket;
     };
     // fs data
diff --git a/fs/sock.c b/fs/sock.c
index 61d79f8f..504c0b76 100644
--- a/fs/sock.c
+++ b/fs/sock.c
@@ -69,17 +69,27 @@ static struct fd *sock_getfd(fd_t sock_fd) {
     return sock;
 }
 
-static struct inode_data *unix_socket_get(const char *path_raw, int flag) {
+// Returns a fresh socket ID. IDs start at 1, so 0 can mean "not assigned yet".
+static uint32_t unix_socket_next_id(void) {
+    static uint32_t next_id = 0;
+    static lock_t next_id_lock = LOCK_INITIALIZER;
+    lock(&next_id_lock);
+    uint32_t id = ++next_id;
+    unlock(&next_id_lock);
+    return id;
+}
+
+static int unix_socket_get(const char *path_raw, struct fd *bind_fd, uint32_t *socket_id) {
     char path[MAX_PATH];
     int err = path_normalize(AT_PWD, path_raw, path, true);
     if (err < 0)
-        return ERR_PTR(err);
+        return err;
     struct mount *mount = find_mount_and_trim_path(path);
     struct statbuf stat;
     err = mount->fs->stat(mount, path, &stat, true);
 
     // If bind was called, there are some funny semantics.
-    if (flag & O_CREAT_) {
+    if (bind_fd != NULL) {
         // If the file exists, fail.
         if (err == 0) {
             err = _EADDRINUSE;
@@ -103,7 +113,7 @@ static struct inode_data *unix_socket_get(const char *path_raw, int flag) {
 
     // If something other than bind was called, just do the obvious thing and
     // fail if stat failed.
-    if (!(flag & O_CREAT_) && err < 0)
+    if (bind_fd == NULL && err < 0)
         goto out;
 
     if (!S_ISSOCK(stat.mode)) {
@@ -114,26 +124,116 @@ static struct inode_data *unix_socket_get(const char *path_raw, int flag) {
     // Look up the socket ID for the inode number.
     struct inode_data *inode = inode_get(mount, stat.inode);
     lock(&inode->lock);
-    if (inode->socket_id == 0) {
-        static uint32_t next_socket_id = 0;
-        static lock_t next_socket_id_lock = LOCK_INITIALIZER;
-        lock(&next_socket_id_lock);
-        inode->socket_id = ++next_socket_id;
-        unlock(&next_socket_id_lock);
-    }
+    if (inode->socket_id == 0)
+        inode->socket_id = unix_socket_next_id();
+    *socket_id = inode->socket_id;
     unlock(&inode->lock);
 
     mount_release(mount);
-    return inode;
+    // A bound socket keeps the inode alive; a lookup only needed the ID.
+    if (bind_fd != NULL)
+        bind_fd->socket.unix_name_inode = inode;
+    else
+        inode_release(inode);
+    return 0;
 
 out:
     mount_release(mount);
-    return ERR_PTR(err);
+    return err;
+}
+
+// Dan Bernstein's simple and decently effective hash function
+static uint32_t str_hash(const char *str) {
+    uint32_t hash = 5381;
+    for (int i = 0; str[i] != '\0'; i++) {
+        // Hash the byte value; a plain char may be negative.
+        hash = (33 * hash) ^ (unsigned char) str[i];
+    }
+    return hash;
+}
+
+// The abstract socket namespace is a lot simpler than it sounds: if the first
+// byte of the path is a null byte, then it gets looked up in this hashtable
+// instead of the filesystem.
+
+struct unix_abstract {
+    unsigned refcount;
+    uint32_t hash;
+    uint32_t socket_id;
+    struct list links;
+    // The name (without the leading null byte), to tell apart hash collisions
+    char name[];
+};
+#define ABSTRACT_HASH_SIZE 1024
+static struct list abstract_hash[ABSTRACT_HASH_SIZE];
+static lock_t unix_abstract_lock = LOCK_INITIALIZER;
+
+// Looks up the abstract socket called name and stores its ID in *socket_id.
+// If bind_fd is NULL the name must already be bound; otherwise it must not be,
+// and a new entry is created whose reference is owned by bind_fd (dropped with
+// unix_abstract_release). Returns 0 or a negative error code.
+static int unix_abstract_get(const char *name, struct fd *bind_fd, uint32_t *socket_id) {
+    uint32_t hash = str_hash(name);
+    int err = 0;
+    lock(&unix_abstract_lock);
+    struct unix_abstract *sock_tmp;
+    struct unix_abstract *sock = NULL;
+    struct list *bucket = &abstract_hash[hash % ABSTRACT_HASH_SIZE];
+    if (list_null(bucket))
+        list_init(bucket);
+    list_for_each_entry(bucket, sock_tmp, links) {
+        if (sock_tmp->hash == hash && strcmp(sock_tmp->name, name) == 0) {
+            sock = sock_tmp;
+            break;
+        }
+    }
+
+    if (bind_fd == NULL) {
+        // Nothing would ever drop a reference taken for a lookup, so just
+        // copy the ID out while the lock keeps the entry alive.
+        if (sock == NULL)
+            err = _ENOENT;
+        else
+            *socket_id = sock->socket_id;
+        goto out;
+    }
+
+    if (sock != NULL) {
+        // Same error as binding to an existing path in unix_socket_get
+        err = _EADDRINUSE;
+        goto out;
+    }
+    sock = malloc(sizeof(*sock) + strlen(name) + 1);
+    if (sock == NULL) {
+        err = _ENOMEM;
+        goto out;
+    }
+    sock->refcount = 1;
+    sock->hash = hash;
+    sock->socket_id = unix_socket_next_id();
+    strcpy(sock->name, name);
+    list_add(bucket, &sock->links);
+    *socket_id = sock->socket_id;
+    bind_fd->socket.unix_name_abstract = sock;
+
+out:
+    unlock(&unix_abstract_lock);
+    return err;
+}
+
+// Drops one reference to name, freeing the entry when it was the last one.
+static void unix_abstract_release(struct unix_abstract *name) {
+    lock(&unix_abstract_lock);
+    if (--name->refcount == 0) {
+        list_remove(&name->links);
+        free(name);
+    }
+    unlock(&unix_abstract_lock);
 }
 
 const char *sock_tmp_prefix = "/tmp/ishsock";
 
-static int sockaddr_read_get_inode(addr_t sockaddr_addr, void *sockaddr, uint_t *sockaddr_len, struct inode_data **inode_out, int flag) {
+static int sockaddr_read_bind(addr_t sockaddr_addr, void *sockaddr, uint_t *sockaddr_len, struct fd *bind_fd) {
     // Make sure we can read things without overflowing buffers
     if (*sockaddr_len < sizeof(socklen_t))
         return _EINVAL;
@@ -159,22 +259,31 @@ static int sockaddr_read_get_inode(addr_t sockaddr_addr, void *sockaddr, uint_t
         case PF_LOCAL: {
             // First pull out the path, being careful to not overflow anything.
             char path[sizeof(struct sockaddr_max_) - offsetof(struct sockaddr_max_, data) + 1]; // big enough
-            size_t addr_path_size = *sockaddr_len - offsetof(struct sockaddr_, data);
-            memcpy(path, fake_addr->data, addr_path_size);
-            path[addr_path_size] = '\0';
+            size_t path_size = *sockaddr_len - offsetof(struct sockaddr_, data);
+            memcpy(path, fake_addr->data, path_size);
+            path[path_size] = '\0';
 
-            struct inode_data *inode = unix_socket_get(path, flag);
-            if (IS_ERR(inode))
-                return PTR_ERR(inode);
-            *inode_out = inode;
+            uint32_t socket_id;
+            int err;
+            if (path_size == 0) {
+                return _ENOENT;
+            } else if (path[0] != '\0') {
+                STRACE(" unix socket %s", path);
+                err = unix_socket_get(path, bind_fd, &socket_id);
+            } else {
+                STRACE(" unix abstract socket %s", path + 1);
+                err = unix_abstract_get(path + 1, bind_fd, &socket_id);
+            }
+            if (err < 0)
+                return err;
 
             struct sockaddr_un *real_addr_un = sockaddr;
-            size_t path_len = sprintf(real_addr_un->sun_path, "%s%d.%d", sock_tmp_prefix, getpid(), inode->socket_id);
+            size_t path_len = sprintf(real_addr_un->sun_path, "%s%d.%u", sock_tmp_prefix, getpid(), socket_id);
             // The call to real bind will fail if the backing socket already
             // exists from a previous run or something. We already checked that
             // the fake file doesn't exist in unix_socket_get, so try a simple
             // solution.
-            if (flag & O_CREAT_)
+            if (bind_fd != NULL)
                 unlink(real_addr_un->sun_path);
             *sockaddr_len = offsetof(struct sockaddr_un, sun_path) + path_len;
             break;
@@ -187,7 +296,4 @@ static int sockaddr_read_get_inode(addr_t sockaddr_addr, void *sockaddr, uint_t
 
 static int sockaddr_read(addr_t sockaddr_addr, void *sockaddr, uint_t *sockaddr_len) {
-    struct inode_data *inode = NULL;
-    int err = sockaddr_read_get_inode(sockaddr_addr, sockaddr, sockaddr_len, &inode, 0);
-    inode_release_if_exist(inode);
-    return err;
+    return sockaddr_read_bind(sockaddr_addr, sockaddr, sockaddr_len, NULL);
 }
@@ -232,13 +338,16 @@ int_t sys_bind(fd_t sock_fd, addr_t sockaddr_addr, uint_t sockaddr_len) {
         return _EBADF;
     struct sockaddr_max_ sockaddr;
-    struct inode_data *inode = NULL;
-    int err = sockaddr_read_get_inode(sockaddr_addr, &sockaddr, &sockaddr_len, &inode, O_CREAT_);
+    int err = sockaddr_read_bind(sockaddr_addr, &sockaddr, &sockaddr_len, sock);
     if (err < 0)
         return err;
 
     err = bind(sock->real_fd, (void *) &sockaddr, sockaddr_len);
     if (err < 0) {
-        inode_release_if_exist(inode);
+        // Forget the name again so sock_close doesn't release it a second time.
+        inode_release_if_exist(sock->socket.unix_name_inode);
+        sock->socket.unix_name_inode = NULL;
+        if (sock->socket.unix_name_abstract != NULL)
+            unix_abstract_release(sock->socket.unix_name_abstract);
+        sock->socket.unix_name_abstract = NULL;
         return errno_map();
     }
-    sock->socket.unix_name_inode = inode;
@@ -251,17 +360,13 @@ int_t sys_connect(fd_t sock_fd, addr_t sockaddr_addr, uint_t sockaddr_len) {
     if (sock == NULL)
         return _EBADF;
     struct sockaddr_max_ sockaddr;
-    struct inode_data *inode = NULL;
-    int err = sockaddr_read_get_inode(sockaddr_addr, &sockaddr, &sockaddr_len, &inode, 0);
+    int err = sockaddr_read(sockaddr_addr, &sockaddr, &sockaddr_len);
     if (err < 0)
         return err;
 
     err = connect(sock->real_fd, (void *) &sockaddr, sockaddr_len);
-    if (err < 0) {
-        inode_release_if_exist(inode);
+    if (err < 0)
         return errno_map();
-    }
-    sock->socket.unix_peer_inode = inode;
     return err;
 }
 
@@ -325,6 +430,8 @@ int_t sys_getsockname(fd_t sock_fd, addr_t sockaddr_addr, addr_t sockaddr_len_ad
     if (user_get(sockaddr_len_addr, sockaddr_len))
         return _EFAULT;
 
+    // TODO if this is a unix socket, return the same string passed to bind
+
     char sockaddr[sockaddr_len];
     int res = getsockname(sock->real_fd, (void *) sockaddr, &sockaddr_len);
     if (res < 0)
@@ -347,6 +454,9 @@ int_t sys_getpeername(fd_t sock_fd, addr_t sockaddr_addr, addr_t sockaddr_len_ad
     if (user_get(sockaddr_len_addr, sockaddr_len))
         return _EFAULT;
 
+    // TODO if this is a unix socket, return the same string the peer passed to
+    // bind once the peer pointer is available
+
     char sockaddr[sockaddr_len];
     int res = getpeername(sock->real_fd, (void *) sockaddr, &sockaddr_len);
     if (res < 0)
@@ -774,8 +884,10 @@ static ssize_t sock_write(struct fd *fd, const void *buf, size_t size) {
 
 static int sock_close(struct fd *fd) {
     sockrestart_end_listen(fd);
+    // FIXME next 3 lines should go in a function like release_unix_names
     inode_release_if_exist(fd->socket.unix_name_inode);
-    inode_release_if_exist(fd->socket.unix_peer_inode);
+    if (fd->socket.unix_name_abstract != NULL)
+        unix_abstract_release(fd->socket.unix_name_abstract);
     return realfs_close(fd);
 }
 