+ * @reachable: Callback to tell if two devices can reach each other
+ *
+ * Compute a bitmap where every set bit is a device on the bus that is reachable
+ * from the start device, including the start device. Reachability between two
+ * devices is determined by a callback function.
+ *
+ * This is a non-recursive implementation that invokes the callback at most
+ * once per unordered pair of devices, so the callback must be symmetric:
+ * reachable(a, b) == reachable(b, a)
+ * The relation described by reachable() may contain cycles, for example:
+ * reachable(a,b) == reachable(b,c) == reachable(c,a) == true
+ *
+ * Since this function is limited to a single bus, the set can contain at
+ * most 256 devices (one per devfn).
+ */
+void pci_reachable_set(struct pci_dev *start, struct pci_reachable_set *devfns,
+		       bool (*reachable)(struct pci_dev *deva,
+					 struct pci_dev *devb))
+{
+	struct pci_reachable_set todo_devfns = {};
+	struct pci_reachable_set next_devfns = {};
+	struct pci_bus *bus = start->bus;
+
+	/* Assume devfn of all PCI devices is bounded by MAX_NR_DEVFNS */
+	static_assert(sizeof(next_devfns.devfns) * BITS_PER_BYTE >=
+		      MAX_NR_DEVFNS);
+
+	memset(devfns, 0, sizeof(*devfns));
+	__set_bit(start->devfn, devfns->devfns);
+	__set_bit(start->devfn, next_devfns.devfns);
+
+	down_read(&pci_bus_sem);
+	while (true) {
+		unsigned int devfna;
+		bool again = false;
+
+		/*
+		 * For each device that hasn't been checked compare every
+		 * device on the bus that isn't in the set yet against it.
+		 */
+		for_each_set_bit(devfna, next_devfns.devfns, MAX_NR_DEVFNS) {
+			struct pci_dev *deva = NULL;
+			struct pci_dev *devb;
+
+			list_for_each_entry(devb, &bus->devices, bus_list) {
+				if (devb->devfn == devfna)
+					deva = devb;
+				if (test_bit(devb->devfn, devfns->devfns))
+					continue;
+
+				/* deva comes later in the list: look ahead for it */
+				if (!deva) {
+					deva = devb;
+					list_for_each_entry_continue(
+						deva, &bus->devices, bus_list)
+						if (deva->devfn == devfna)
+							break;
+					/* devfna not on the bus: never pass a bogus deva */
+					if (list_entry_is_head(deva, &bus->devices,
+							       bus_list))
+						break;
+				}
+
+				if (!reachable(deva, devb))
+					continue;
+
+				__set_bit(devb->devfn, todo_devfns.devfns);
+				again = true;
+			}
+		}
+
+		if (!again)
+			break;
+
+		/*
+		 * Every new bit adds a new deva to check, reloop the whole
+		 * thing. Expect this to be rare.
+		 */
+		bitmap_or(devfns->devfns, devfns->devfns, todo_devfns.devfns, MAX_NR_DEVFNS);
+		bitmap_copy(next_devfns.devfns, todo_devfns.devfns, MAX_NR_DEVFNS);
+		bitmap_zero(todo_devfns.devfns, MAX_NR_DEVFNS);
+	}
+	up_read(&pci_bus_sem);
+}
+EXPORT_SYMBOL_GPL(pci_reachable_set);
diff --git a/include/linux/pci.h b/include/linux/pci.h
index fb9adf0562f8ef..21f6b20b487f8d 100644
--- a/include/linux/pci.h
+++ b/include/linux/pci.h
@@ -855,6 +855,10 @@ struct pci_dynids {
struct list_head list; /* For IDs added at runtime */
};
+struct pci_reachable_set {
+ DECLARE_BITMAP(devfns, 256); /* bit N set: the device with devfn N is in the set */
+};
+
enum pci_bus_isolation {
/*
* The bus is off a root port and the root port has isolated ACS flags
@@ -1269,6 +1273,9 @@ struct pci_dev *pci_get_domain_bus_and_slot(int domain, unsigned int bus,
struct pci_dev *pci_get_class(unsigned int class, struct pci_dev *from);
struct pci_dev *pci_get_base_class(unsigned int class, struct pci_dev *from);
+void pci_reachable_set(struct pci_dev *start, struct pci_reachable_set *devfns,
+ bool (*reachable)(struct pci_dev *deva,
+ struct pci_dev *devb));
enum pci_bus_isolation pci_bus_isolated(struct pci_bus *bus);
int pci_dev_present(const struct pci_device_id *ids);
@@ -2084,6 +2091,11 @@ static inline struct pci_dev *pci_get_base_class(unsigned int class,
struct pci_dev *from)
{ return NULL; }
+static inline void
+pci_reachable_set(struct pci_dev *start, struct pci_reachable_set *devfns,
+		  bool (*reachable)(struct pci_dev *deva, struct pci_dev *devb))
+{ memset(devfns, 0, sizeof(*devfns)); }	/* report an empty set */
+
static inline enum pci_bus_isolation pci_bus_isolated(struct pci_bus *bus)
{ return PCIE_NON_ISOLATED; }
--
2.43.0