Dan Williams <dan.j.williams@xxxxxxxxx> writes: > From: Xu Yilun <yilun.xu@xxxxxxxxxxxxxxx> ... > @@ -558,11 +675,11 @@ int pci_tsm_bind(struct pci_dev *pdev, struct kvm *kvm, u64 tdi_id) > if (!pdev->tsm) > return -EINVAL; > > - struct pci_dev *pf0_dev __free(pci_dev_put) = tsm_pf0_get(pdev); > - if (!pf0_dev) > + struct pci_dev *dsm_dev __free(pci_dev_put) = dsm_dev_get(pdev); > + if (!dsm_dev) > return -EINVAL; > > - struct mutex *ops_lock __free(tdi_ops_unlock) = tdi_ops_lock(pf0_dev); > + struct mutex *ops_lock __free(tdi_ops_unlock) = tdi_ops_lock(dsm_dev); > if (IS_ERR(ops_lock)) > return PTR_ERR(ops_lock); > > @@ -573,10 +690,13 @@ int pci_tsm_bind(struct pci_dev *pdev, struct kvm *kvm, u64 tdi_id) > return -EBUSY; > } > > - tdi = tsm_ops->bind(pdev, pf0_dev, kvm, tdi_id); > + tdi = tsm_ops->bind(pdev, dsm_dev, kvm, tdi_id); > if (!tdi) > return -ENXIO; > > + tdi->pdev = pdev; > + tdi->dsm_dev = dsm_dev; > + tdi->kvm = kvm; > pdev->tsm->tdi = tdi; > Should that be no_free_ptr(dsm_dev)? Also, doesn't unbind need to drop that device reference? 
modified drivers/pci/tsm.c @@ -697,7 +697,7 @@ int pci_tsm_bind(struct pci_dev *pdev, struct kvm *kvm, u64 tdi_id) return -ENXIO; tdi->pdev = pdev; - tdi->dsm_dev = dsm_dev; + tdi->dsm_dev = no_free_ptr(dsm_dev); tdi->kvm = kvm; pdev->tsm->tdi = tdi; @@ -714,10 +714,6 @@ static int __pci_tsm_unbind(struct pci_dev *pdev) if (!pdev->tsm) return -EINVAL; - struct pci_dev *dsm_dev __free(pci_dev_put) = dsm_dev_get(pdev); - if (!dsm_dev) - return -EINVAL; - struct mutex *lock __free(tdi_ops_unlock) = tdi_ops_lock(dsm_dev); if (IS_ERR(lock)) return PTR_ERR(lock); @@ -726,6 +722,10 @@ static int __pci_tsm_unbind(struct pci_dev *pdev) if (!tdi) return 0; + struct pci_dev *dsm_dev __free(pci_dev_put) = tdi->dsm_dev; + if (!dsm_dev) + return -EINVAL; + tsm_ops->unbind(tdi); pdev->tsm->tdi = NULL; > > return 0; > @@ -592,11 +712,11 @@ static int __pci_tsm_unbind(struct pci_dev *pdev) > if (!pdev->tsm) > return -EINVAL; > > - struct pci_dev *pf0_dev __free(pci_dev_put) = tsm_pf0_get(pdev); > - if (!pf0_dev) > + struct pci_dev *dsm_dev __free(pci_dev_put) = dsm_dev_get(pdev); > + if (!dsm_dev) > return -EINVAL; > > - struct mutex *lock __free(tdi_ops_unlock) = tdi_ops_lock(pf0_dev); > + struct mutex *lock __free(tdi_ops_unlock) = tdi_ops_lock(dsm_dev); > if (IS_ERR(lock)) > return PTR_ERR(lock); > > @@ -641,11 +761,11 @@ int pci_tsm_guest_req(struct pci_dev *pdev, struct pci_tsm_guest_req_info *info) > if (!pdev->tsm) > return -ENODEV; > > - struct pci_dev *pf0_dev __free(pci_dev_put) = tsm_pf0_get(pdev); > - if (!pf0_dev) > + struct pci_dev *dsm_dev __free(pci_dev_put) = dsm_dev_get(pdev); > + if (!dsm_dev) > return -EINVAL; > > - struct mutex *lock __free(tdi_ops_unlock) = tdi_ops_lock(pf0_dev); > + struct mutex *lock __free(tdi_ops_unlock) = tdi_ops_lock(dsm_dev); > if (IS_ERR(lock)) > return -ENODEV; > ... -aneesh