On Mon, 28 Jul 2025 19:21:56 +0530 "Aneesh Kumar K.V (Arm)" <aneesh.kumar@xxxxxxxxxx> wrote: > Add changes to share the device's public key with the RMM. > > Signed-off-by: Aneesh Kumar K.V (Arm) <aneesh.kumar@xxxxxxxxxx> A few minor comments inline. > diff --git a/drivers/virt/coco/arm-cca-host/rmm-da.c b/drivers/virt/coco/arm-cca-host/rmm-da.c > index ec8c5bfcee35..3715e6d58c83 100644 > --- a/drivers/virt/coco/arm-cca-host/rmm-da.c > +++ b/drivers/virt/coco/arm-cca-host/rmm-da.c > @@ -6,6 +6,9 @@ > #include <linux/pci.h> > #include <linux/pci-ecam.h> > #include <asm/rmi_cmds.h> > +#include <crypto/internal/rsa.h> > +#include <keys/asymmetric-type.h> > +#include <keys/x509-parser.h> > > #include "rmm-da.h" > > @@ -311,6 +314,136 @@ static int do_pdev_communicate(struct pci_tsm *tsm, int target_state) > return do_dev_communicate(PDEV_COMMUNICATE, tsm, target_state); > } > > +static int parse_certificate_chain(struct pci_tsm *tsm) > +{ > + struct cca_host_dsc_pf0 *dsc_pf0; > + unsigned int chain_size; > + unsigned int offset = 0; > + u8 *chain_data; > + int ret = 0; > + > + dsc_pf0 = to_cca_dsc_pf0(tsm->pdev); > + chain_size = dsc_pf0->cert_chain.cache.size; > + chain_data = dsc_pf0->cert_chain.cache.buf; > + > + while (offset < chain_size) { > + unsigned int cert_len = > + x509_get_certificate_length(chain_data + offset, > + chain_size - offset); > + struct x509_certificate *cert = > + x509_cert_parse(chain_data + offset, cert_len); > + > + if (IS_ERR(cert)) { > + pr_warn("%s(): parsing of certificate chain not successful\n", __func__); > + ret = PTR_ERR(cert); Direct return looks fine here. Maybe add a DEFINE_FREE(x509_cert,...), as then you can use direct returns throughout. 
> + break; > + } > + > + if (offset + cert_len == chain_size) { > + dsc_pf0->cert_chain.public_key = kzalloc(cert->pub->keylen, GFP_KERNEL); > + if (!dsc_pf0->cert_chain.public_key) { > + ret = -ENOMEM; > + x509_free_certificate(cert); > + break; > + } > + > + if (!strcmp("ecdsa-nist-p256", cert->pub->pkey_algo)) { > + dsc_pf0->rmi_signature_algorithm = RMI_SIG_ECDSA_P256; > + } else if (!strcmp("ecdsa-nist-p384", cert->pub->pkey_algo)) { > + dsc_pf0->rmi_signature_algorithm = RMI_SIG_ECDSA_P384; > + } else if (!strcmp("rsa", cert->pub->pkey_algo)) { > + dsc_pf0->rmi_signature_algorithm = RMI_SIG_RSASSA_3072; > + } else { > + ret = -ENXIO; > + x509_free_certificate(cert); > + break; > + } > + memcpy(dsc_pf0->cert_chain.public_key, cert->pub->key, cert->pub->keylen); > + dsc_pf0->cert_chain.public_key_size = cert->pub->keylen; > + } > + > + x509_free_certificate(cert); > + > + offset += cert_len; > + } > + > + if (ret == 0) > + dsc_pf0->cert_chain.valid = true; if (ret) return ret; dsc_pf0->cert_chain.valid = true; return 0; would be my preference for style here but others may disagree. > + > + return ret; > +} > + > +static int pdev_set_public_key(struct pci_tsm *tsm) > +{ > + struct rmi_public_key_params *key_shared; > + unsigned long expected_key_len = 0; Don't set this. It's only used in places where it is explicitly set and if it is used anywhere else we want the compiler to tell us. 
> + struct cca_host_dsc_pf0 *dsc_pf0; > + int ret; > + > + dsc_pf0 = to_cca_dsc_pf0(tsm->pdev); > + /* Check that all the necessary information was captured from communication */ > + if (!dsc_pf0->cert_chain.valid) > + return -EINVAL; > + > + key_shared = (struct rmi_public_key_params *)get_zeroed_page(GFP_KERNEL); > + if (!key_shared) > + return -ENOMEM; > + > + key_shared->rmi_signature_algorithm = dsc_pf0->rmi_signature_algorithm; > + > + switch (key_shared->rmi_signature_algorithm) { > + case RMI_SIG_ECDSA_P384: > + expected_key_len = 97; > + > + if (dsc_pf0->cert_chain.public_key_size != expected_key_len) > + return -EINVAL; > + key_shared->public_key_len = dsc_pf0->cert_chain.public_key_size; > + memcpy(key_shared->public_key, > + dsc_pf0->cert_chain.public_key, > + dsc_pf0->cert_chain.public_key_size); > + key_shared->metadata_len = 0; > + break; > + case RMI_SIG_ECDSA_P256: > + expected_key_len = 65; > + > + if (dsc_pf0->cert_chain.public_key_size != expected_key_len) > + return -EINVAL; > + key_shared->public_key_len = dsc_pf0->cert_chain.public_key_size; > + memcpy(key_shared->public_key, > + dsc_pf0->cert_chain.public_key, > + dsc_pf0->cert_chain.public_key_size); > + key_shared->metadata_len = 0; > + break; > + case RMI_SIG_RSASSA_3072: > + expected_key_len = 385; > + struct rsa_key rsa_key = {0}; Shouldn't declare this in the middle of the case body. Maybe move it up a line and add some {} to limit its scope to this case statement. 
> + int ret_rsa_parse = rsa_parse_pub_key(&rsa_key, > + dsc_pf0->cert_chain.public_key, > + dsc_pf0->cert_chain.public_key_size); > + /* This also checks the key_len */ > + if (ret_rsa_parse) > + return ret_rsa_parse; > + /* > + * exponent is usally 65537 (size = 24bits) but in rare cases > + * it size can be as large as the modulus > + */ > + if (rsa_key.e_sz > expected_key_len) > + return -EINVAL; > + key_shared->public_key_len = rsa_key.n_sz; > + key_shared->metadata_len = rsa_key.e_sz; > + memcpy(key_shared->public_key, (unsigned char *)rsa_key.n, rsa_key.n_sz); Why is the cast needed? rsa_key.n is already a const u8 * and memcpy() takes a const void *, so the cast only strips the const qualifier. The same applies to the copy of rsa_key.e below. > + memcpy(key_shared->metadata, (unsigned char *)rsa_key.e, rsa_key.e_sz); > + break; > + default: > + return -EINVAL; > + } > + > + ret = rmi_pdev_set_pubkey(virt_to_phys(dsc_pf0->rmm_pdev), > + virt_to_phys(key_shared)); > + free_page((unsigned long)key_shared); > + return ret; > +}