__ublk_check_and_get_req() attempts to atomically look up the struct
request for a ublk I/O and take a reference on it. However, the request
can be freed between the lookup on the tagset in blk_mq_tag_to_rq() and
the increment of its reference count in ublk_get_req_ref(), for example
if an elevator switch happens concurrently.

Fix the potential use after free by moving the reference count from
ublk_rq_data to ublk_io. Move the fields buf_index and buf_ctx_handle
too to reduce the number of cache lines touched when dispatching and
completing a ublk I/O, allowing ublk_rq_data to be removed entirely.

Suggested-by: Ming Lei <ming.lei@xxxxxxxxxx>
Signed-off-by: Caleb Sander Mateos <csander@xxxxxxxxxxxxxxx>
Fixes: 62fe99cef94a ("ublk: add read()/write() support for ublk char device")
---
 drivers/block/ublk_drv.c | 131 +++++++++++++++++++++------------------
 1 file changed, 71 insertions(+), 60 deletions(-)
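For context: struct ublk_io lives in the queue's ios[] array, which
stays allocated for the lifetime of the ublk device, whereas the
ublk_rq_data pdu is freed together with the request pool, e.g. on an
elevator switch. With the counter in stable memory,
__ublk_check_and_get_req() can no longer increment freed memory; a
stale or recycled request is instead rejected afterwards by the
blk_mq_request_started()/tag checks, performed while the reference is
held. As a rough userspace model of the refcount_inc_not_zero()
semantics relied on here (illustration only, not kernel code):

	#include <stdatomic.h>
	#include <stdbool.h>

	/*
	 * Take a reference only if the object is still live (count > 0).
	 * This is only safe if the counter's memory cannot be freed while
	 * lookups may still run -- the property this patch establishes by
	 * moving the counter from ublk_rq_data into ublk_io.
	 */
	static bool ref_inc_not_zero(atomic_uint *ref)
	{
		unsigned int old = atomic_load(ref);

		do {
			if (old == 0)
				return false;	/* already released */
		} while (!atomic_compare_exchange_weak(ref, &old, old + 1));

		return true;
	}

This mirrors what ublk_get_req_ref() does with refcount_inc_not_zero()
below.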
diff --git a/drivers/block/ublk_drv.c b/drivers/block/ublk_drv.c
index 467769fc388f..55d5435f1ee2 100644
--- a/drivers/block/ublk_drv.c
+++ b/drivers/block/ublk_drv.c
@@ -80,18 +80,10 @@
 #define UBLK_PARAM_TYPE_ALL                                \
 	(UBLK_PARAM_TYPE_BASIC | UBLK_PARAM_TYPE_DISCARD | \
 	 UBLK_PARAM_TYPE_DEVT | UBLK_PARAM_TYPE_ZONED |    \
 	 UBLK_PARAM_TYPE_DMA_ALIGN | UBLK_PARAM_TYPE_SEGMENT)
 
-struct ublk_rq_data {
-	refcount_t ref;
-
-	/* for auto-unregister buffer in case of UBLK_F_AUTO_BUF_REG */
-	u16 buf_index;
-	void *buf_ctx_handle;
-};
-
 struct ublk_uring_cmd_pdu {
 	/*
 	 * Store requests in same batch temporarily for queuing them to
 	 * daemon context.
 	 *
@@ -167,10 +159,26 @@ struct ublk_io {
 		/* valid if UBLK_IO_FLAG_OWNED_BY_SRV is set */
 		struct request *req;
 	};
 
 	struct task_struct *task;
+
+	/*
+	 * The number of uses of this I/O by the ublk server
+	 * if user copy or zero copy are enabled:
+	 * - 1 from dispatch to the server until UBLK_IO_COMMIT_AND_FETCH_REQ
+	 * - 1 for each inflight ublk_ch_{read,write}_iter() call
+	 * - 1 for each io_uring registered buffer
+	 * The I/O can only be completed once all references are dropped.
+	 * User copy and buffer registration operations are only permitted
+	 * if the reference count is nonzero.
+	 */
+	refcount_t ref;
+
+	/* auto-registered buffer, valid if UBLK_IO_FLAG_AUTO_BUF_REG is set */
+	u16 buf_index;
+	void *buf_ctx_handle;
 };
 
 struct ublk_queue {
 	int q_id;
 	int q_depth;
@@ -226,11 +234,12 @@ struct ublk_params_header {
 
 static void ublk_io_release(void *priv);
 static void ublk_stop_dev_unlocked(struct ublk_device *ub);
 static void ublk_abort_queue(struct ublk_device *ub, struct ublk_queue *ubq);
 static inline struct request *__ublk_check_and_get_req(struct ublk_device *ub,
-		const struct ublk_queue *ubq, int tag, size_t offset);
+		const struct ublk_queue *ubq, struct ublk_io *io,
+		size_t offset);
 static inline unsigned int ublk_req_build_flags(struct request *req);
 
 static inline struct ublksrv_io_desc *
 ublk_get_iod(const struct ublk_queue *ubq, unsigned tag)
 {
@@ -671,38 +680,30 @@ static inline bool ublk_need_req_ref(const struct ublk_queue *ubq)
 	return ublk_support_user_copy(ubq) || ublk_support_zero_copy(ubq) ||
 		ublk_support_auto_buf_reg(ubq);
 }
 
 static inline void ublk_init_req_ref(const struct ublk_queue *ubq,
-		struct request *req)
+		struct ublk_io *io)
 {
-	if (ublk_need_req_ref(ubq)) {
-		struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-		refcount_set(&data->ref, 1);
-	}
+	if (ublk_need_req_ref(ubq))
+		refcount_set(&io->ref, 1);
 }
 
 static inline bool ublk_get_req_ref(const struct ublk_queue *ubq,
-		struct request *req)
+		struct ublk_io *io)
 {
-	if (ublk_need_req_ref(ubq)) {
-		struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-		return refcount_inc_not_zero(&data->ref);
-	}
+	if (ublk_need_req_ref(ubq))
+		return refcount_inc_not_zero(&io->ref);
 
 	return true;
 }
 
 static inline void ublk_put_req_ref(const struct ublk_queue *ubq,
-		struct request *req)
+		struct ublk_io *io, struct request *req)
 {
 	if (ublk_need_req_ref(ubq)) {
-		struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-		if (refcount_dec_and_test(&data->ref))
+		if (refcount_dec_and_test(&io->ref))
 			__ublk_complete_rq(req);
 	} else {
 		__ublk_complete_rq(req);
 	}
 }
@@ -1179,55 +1180,54 @@ static inline void __ublk_abort_rq(struct ublk_queue *ubq,
 		blk_mq_requeue_request(rq, false);
 	else
 		blk_mq_end_request(rq, BLK_STS_IOERR);
 }
 
-static void ublk_auto_buf_reg_fallback(struct request *req)
+static void
+ublk_auto_buf_reg_fallback(const struct ublk_queue *ubq, struct ublk_io *io)
 {
-	const struct ublk_queue *ubq = req->mq_hctx->driver_data;
-	struct ublksrv_io_desc *iod = ublk_get_iod(ubq, req->tag);
-	struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
+	unsigned tag = io - ubq->ios;
+	struct ublksrv_io_desc *iod = ublk_get_iod(ubq, tag);
 
 	iod->op_flags |= UBLK_IO_F_NEED_REG_BUF;
-	refcount_set(&data->ref, 1);
+	refcount_set(&io->ref, 1);
 }
 
-static bool ublk_auto_buf_reg(struct request *req, struct ublk_io *io,
-			      unsigned int issue_flags)
+static bool ublk_auto_buf_reg(const struct ublk_queue *ubq, struct request *req,
+			      struct ublk_io *io, unsigned int issue_flags)
 {
 	struct ublk_uring_cmd_pdu *pdu = ublk_get_uring_cmd_pdu(io->cmd);
-	struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
 	int ret;
 
 	ret = io_buffer_register_bvec(io->cmd, req, ublk_io_release,
 				      pdu->buf.index, issue_flags);
 	if (ret) {
 		if (pdu->buf.flags & UBLK_AUTO_BUF_REG_FALLBACK) {
-			ublk_auto_buf_reg_fallback(req);
+			ublk_auto_buf_reg_fallback(ubq, io);
 			return true;
 		}
 		blk_mq_end_request(req, BLK_STS_IOERR);
 		return false;
 	}
 
 	/* one extra reference is dropped by ublk_io_release */
-	refcount_set(&data->ref, 2);
+	refcount_set(&io->ref, 2);
 
-	data->buf_ctx_handle = io_uring_cmd_ctx_handle(io->cmd);
+	io->buf_ctx_handle = io_uring_cmd_ctx_handle(io->cmd);
 	/* store buffer index in request payload */
-	data->buf_index = pdu->buf.index;
+	io->buf_index = pdu->buf.index;
 	io->flags |= UBLK_IO_FLAG_AUTO_BUF_REG;
 	return true;
 }
 
 static bool ublk_prep_auto_buf_reg(struct ublk_queue *ubq, struct request *req,
 				   struct ublk_io *io, unsigned int issue_flags)
 {
 	if (ublk_support_auto_buf_reg(ubq) && ublk_rq_has_data(req))
-		return ublk_auto_buf_reg(req, io, issue_flags);
+		return ublk_auto_buf_reg(ubq, req, io, issue_flags);
 
-	ublk_init_req_ref(ubq, req);
+	ublk_init_req_ref(ubq, io);
 	return true;
 }
 
 static bool ublk_start_io(const struct ublk_queue *ubq, struct request *req,
 			  struct ublk_io *io)
@@ -1485,10 +1485,12 @@ static void ublk_queue_reinit(struct ublk_device *ub, struct ublk_queue *ubq)
 		 */
 		if (io->task) {
 			put_task_struct(io->task);
 			io->task = NULL;
 		}
+
+		WARN_ON_ONCE(refcount_read(&io->ref));
 	}
 }
 
 static int ublk_ch_open(struct inode *inode, struct file *filp)
 {
@@ -1989,33 +1991,35 @@ static inline int ublk_set_auto_buf_reg(struct io_uring_cmd *cmd)
 
 static void ublk_io_release(void *priv)
 {
 	struct request *rq = priv;
 	struct ublk_queue *ubq = rq->mq_hctx->driver_data;
+	struct ublk_io *io = &ubq->ios[rq->tag];
 
-	ublk_put_req_ref(ubq, rq);
+	ublk_put_req_ref(ubq, io, rq);
 }
 
 static int ublk_register_io_buf(struct io_uring_cmd *cmd,
-				const struct ublk_queue *ubq, unsigned int tag,
+				const struct ublk_queue *ubq,
+				struct ublk_io *io,
 				unsigned int index, unsigned int issue_flags)
 {
 	struct ublk_device *ub = cmd->file->private_data;
 	struct request *req;
 	int ret;
 
 	if (!ublk_support_zero_copy(ubq))
 		return -EINVAL;
 
-	req = __ublk_check_and_get_req(ub, ubq, tag, 0);
+	req = __ublk_check_and_get_req(ub, ubq, io, 0);
 	if (!req)
 		return -EINVAL;
 
 	ret = io_buffer_register_bvec(cmd, req, ublk_io_release, index,
 				      issue_flags);
 	if (ret) {
-		ublk_put_req_ref(ubq, req);
+		ublk_put_req_ref(ubq, io, req);
 		return ret;
 	}
 
 	return 0;
 }
@@ -2118,14 +2122,12 @@ static int ublk_commit_and_fetch(const struct ublk_queue *ubq,
 	 * one for registering the buffer, it is ublk server's
 	 * responsibility for unregistering the buffer, otherwise
 	 * this ublk request gets stuck.
 	 */
 	if (io->flags & UBLK_IO_FLAG_AUTO_BUF_REG) {
-		struct ublk_rq_data *data = blk_mq_rq_to_pdu(req);
-
-		if (data->buf_ctx_handle == io_uring_cmd_ctx_handle(cmd))
-			io_buffer_unregister_bvec(cmd, data->buf_index,
+		if (io->buf_ctx_handle == io_uring_cmd_ctx_handle(cmd))
+			io_buffer_unregister_bvec(cmd, io->buf_index,
 						  issue_flags);
 		io->flags &= ~UBLK_IO_FLAG_AUTO_BUF_REG;
 	}
 
 	ret = ublk_set_auto_buf_reg(cmd);
@@ -2141,11 +2143,11 @@ static int ublk_commit_and_fetch(const struct ublk_queue *ubq,
 
 	if (req_op(req) == REQ_OP_ZONE_APPEND)
 		req->__sector = ub_cmd->zone_append_lba;
 
 	if (likely(!blk_should_fake_timeout(req->q)))
-		ublk_put_req_ref(ubq, req);
+		ublk_put_req_ref(ubq, io, req);
 
 	return 0;
 }
 
 static bool ublk_get_data(const struct ublk_queue *ubq, struct ublk_io *io)
@@ -2220,11 +2222,11 @@ static int __ublk_ch_uring_cmd(struct io_uring_cmd *cmd,
 		goto out;
 
 	ret = -EINVAL;
 	switch (_IOC_NR(cmd_op)) {
 	case UBLK_IO_REGISTER_IO_BUF:
-		return ublk_register_io_buf(cmd, ubq, tag, ub_cmd->addr, issue_flags);
+		return ublk_register_io_buf(cmd, ubq, io, ub_cmd->addr, issue_flags);
 	case UBLK_IO_UNREGISTER_IO_BUF:
 		return ublk_unregister_io_buf(cmd, ubq, ub_cmd->addr, issue_flags);
 	case UBLK_IO_FETCH_REQ:
 		ret = ublk_fetch(cmd, ubq, io, ub_cmd->addr);
 		if (ret)
@@ -2252,19 +2254,24 @@ static int __ublk_ch_uring_cmd(struct io_uring_cmd *cmd,
 			__func__, cmd_op, tag, ret, io->flags);
 	return ret;
 }
 
 static inline struct request *__ublk_check_and_get_req(struct ublk_device *ub,
-		const struct ublk_queue *ubq, int tag, size_t offset)
+		const struct ublk_queue *ubq, struct ublk_io *io, size_t offset)
 {
+	unsigned tag = io - ubq->ios;
 	struct request *req;
 
+	/*
+	 * can't use io->req in case of concurrent UBLK_IO_COMMIT_AND_FETCH_REQ,
+	 * which would overwrite it with io->cmd
+	 */
 	req = blk_mq_tag_to_rq(ub->tag_set.tags[ubq->q_id], tag);
 	if (!req)
 		return NULL;
 
-	if (!ublk_get_req_ref(ubq, req))
+	if (!ublk_get_req_ref(ubq, io))
 		return NULL;
 
 	if (unlikely(!blk_mq_request_started(req) || req->tag != tag))
 		goto fail_put;
 
@@ -2274,11 +2281,11 @@ static inline struct request *__ublk_check_and_get_req(struct ublk_device *ub,
 	if (offset > blk_rq_bytes(req))
 		goto fail_put;
 
 	return req;
 fail_put:
-	ublk_put_req_ref(ubq, req);
+	ublk_put_req_ref(ubq, io, req);
 	return NULL;
 }
 
 static inline int ublk_ch_uring_cmd_local(struct io_uring_cmd *cmd,
 		unsigned int issue_flags)
@@ -2341,11 +2348,12 @@ static inline bool ublk_check_ubuf_dir(const struct request *req,
 	return false;
 }
 
 static struct request *ublk_check_and_get_req(struct kiocb *iocb,
-		struct iov_iter *iter, size_t *off, int dir)
+		struct iov_iter *iter, size_t *off, int dir,
+		struct ublk_io **io)
 {
 	struct ublk_device *ub = iocb->ki_filp->private_data;
 	struct ublk_queue *ubq;
 	struct request *req;
 	size_t buf_off;
@@ -2375,11 +2383,12 @@ static struct request *ublk_check_and_get_req(struct kiocb *iocb,
 		return ERR_PTR(-EACCES);
 
 	if (tag >= ubq->q_depth)
 		return ERR_PTR(-EINVAL);
 
-	req = __ublk_check_and_get_req(ub, ubq, tag, buf_off);
+	*io = &ubq->ios[tag];
+	req = __ublk_check_and_get_req(ub, ubq, *io, buf_off);
 	if (!req)
 		return ERR_PTR(-EINVAL);
 
 	if (!req->mq_hctx || !req->mq_hctx->driver_data)
 		goto fail;
@@ -2388,46 +2397,48 @@ static struct request *ublk_check_and_get_req(struct kiocb *iocb,
 		goto fail;
 
 	*off = buf_off;
 	return req;
 fail:
-	ublk_put_req_ref(ubq, req);
+	ublk_put_req_ref(ubq, *io, req);
 	return ERR_PTR(-EACCES);
 }
 
 static ssize_t ublk_ch_read_iter(struct kiocb *iocb, struct iov_iter *to)
 {
 	struct ublk_queue *ubq;
 	struct request *req;
+	struct ublk_io *io;
 	size_t buf_off;
 	size_t ret;
 
-	req = ublk_check_and_get_req(iocb, to, &buf_off, ITER_DEST);
+	req = ublk_check_and_get_req(iocb, to, &buf_off, ITER_DEST, &io);
 	if (IS_ERR(req))
 		return PTR_ERR(req);
 
 	ret = ublk_copy_user_pages(req, buf_off, to, ITER_DEST);
 	ubq = req->mq_hctx->driver_data;
-	ublk_put_req_ref(ubq, req);
+	ublk_put_req_ref(ubq, io, req);
 
 	return ret;
 }
 
 static ssize_t ublk_ch_write_iter(struct kiocb *iocb, struct iov_iter *from)
 {
 	struct ublk_queue *ubq;
 	struct request *req;
+	struct ublk_io *io;
 	size_t buf_off;
 	size_t ret;
 
-	req = ublk_check_and_get_req(iocb, from, &buf_off, ITER_SOURCE);
+	req = ublk_check_and_get_req(iocb, from, &buf_off, ITER_SOURCE, &io);
 	if (IS_ERR(req))
 		return PTR_ERR(req);
 
 	ret = ublk_copy_user_pages(req, buf_off, from, ITER_SOURCE);
 	ubq = req->mq_hctx->driver_data;
-	ublk_put_req_ref(ubq, req);
+	ublk_put_req_ref(ubq, io, req);
 
 	return ret;
 }
 
 static const struct file_operations ublk_ch_fops = {
@@ -2448,10 +2459,11 @@ static void ublk_deinit_queue(struct ublk_device *ub, int q_id)
 	for (i = 0; i < ubq->q_depth; i++) {
 		struct ublk_io *io = &ubq->ios[i];
 
 		if (io->task)
 			put_task_struct(io->task);
+		WARN_ON_ONCE(refcount_read(&io->ref));
 	}
 
 	if (ubq->io_cmd_buf)
 		free_pages((unsigned long)ubq->io_cmd_buf, get_order(size));
 }
@@ -2600,11 +2612,10 @@ static int ublk_add_tag_set(struct ublk_device *ub)
 {
 	ub->tag_set.ops = &ublk_mq_ops;
 	ub->tag_set.nr_hw_queues = ub->dev_info.nr_hw_queues;
 	ub->tag_set.queue_depth = ub->dev_info.queue_depth;
 	ub->tag_set.numa_node = NUMA_NO_NODE;
-	ub->tag_set.cmd_size = sizeof(struct ublk_rq_data);
 	ub->tag_set.driver_data = ub;
 	return blk_mq_alloc_tag_set(&ub->tag_set);
 }
 
 static void ublk_remove(struct ublk_device *ub)
-- 
2.45.2