
2023-12-15
Felix Koltermann

CISPA researchers test new method for uncertainty quantification in machine learning applications

For a machine learning algorithm to be trustworthy, the end user needs to know how confident the model is for every single prediction. To date, accuracy has been a major criterion for evaluating a model. However, accuracy does not reflect how confident the model is in processing each individual input. Scientists have devised methods to quantify this “uncertainty”, but most of these methods are computationally expensive. The task becomes even more challenging when inputs come in the form of a network with meaningful connections, as happens, for example, in drug discovery, medical diagnosis, or traffic forecasting. At the same time, these meaningful connections in the data can be used to arrive at a better understanding of the uncertainty. In their recent paper “Conformal Prediction Sets for Graph Neural Networks”, CISPA researcher Soroush H. Zargarbashi and his colleagues have successfully tested a new method to define a set of possible predictions for such networks that is guaranteed, with a user-defined probability, to include the true prediction.

Many machine learning applications are based on graph neural networks (GNNs). "In many real-life scenarios, we deal with graphs. There are meaningful connections between our datapoints, and with GNNs we take those connections into account", CISPA researcher Soroush H. Zargarbashi explains. Graphs are an abstract data structure consisting of two elements: nodes and connections between nodes, called edges. Graphs can, for example, model social networks, meshes of sensors, scientific papers with their references, and so on. There is, however, one particular characteristic that causes problems in certain areas of application, Soroush continues: "If you use a model as a black box, an input always results in an output, like your car seeing the scene and deciding to steer left. But if you don't know whether the model is sure about this particular output, it becomes highly untrustworthy, especially in safety-critical domains where the user needs an uncertainty estimate of the model." The problem is that models often overestimate the quality of their predictions while underestimating their uncertainty.
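
To make the idea of a graph concrete, here is a minimal, purely illustrative Python sketch (not taken from the paper) of a small graph stored as nodes and edges, together with one round of neighborhood averaging, the basic operation that GNN layers build on. All sizes and values are made up.

```python
import numpy as np

# A tiny, made-up graph: 4 nodes connected in a cycle, each with an 8-dimensional feature vector
num_nodes = 4
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]     # undirected edges
features = np.random.rand(num_nodes, 8)      # one feature vector per node

# Adjacency matrix with self-loops, row-normalized so each row sums to 1
adj = np.eye(num_nodes)
for u, v in edges:
    adj[u, v] = adj[v, u] = 1.0
adj = adj / adj.sum(axis=1, keepdims=True)

# One message-passing step: each node mixes its own features with its neighbors'
updated = adj @ features
print(updated.shape)   # (4, 8)
```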

To illustrate how important it is for models to provide a reliable uncertainty estimate for their predictions, Soroush gives an example: "Imagine you are using a medical diagnostic system to decide whether a patient has a certain disease. In this case, it is very important that your model can predict this with a high degree of certainty. If the model cannot do this, further diagnoses have to be made. The idea behind the quantification of uncertainty is to refine these predictions." Consider, for example, procedures in which AI is used to determine automatically whether an organ is cancerous by analyzing MRI images. Here, it is important to know the quality of the prediction for each individual input. In other words, an overall accuracy of 90 percent can still be very risky for the remaining 10 percent of cases.

Conformal prediction as a solution for uncertainty quantification

"There are methods to quantify this uncertainty. However, they are computationally expensive, hard to apply and worst of all, many of them don’t work on graphs”, Soroush says. Many of these methods usually require modifications to the model architecture or retraining of the model. "However, there is growing interest in an alternative approach known as conformal prediction", he continues. Conformal prediction (CP) is a statistical method for creating prediction sets that has been known since the late 1990s and that does not require any assumptions about the prediction algorithm. Soroush mentions that CP can work around the model like a kind of wrapper and generate a set of possible predictions with a user-defined probability guarantee for the correct answer. But how exactly does it work? "For example, for a new patient, you tune the algorithm to produce sets that guarantee the true answer with a probability of 95 percent. This works for any model, even those who are 60 percent accurate. You just need a random sample of previous patients with their correct diagnosis. In this way, for each patient, we have a set of possible diagnoses which we know includes the correct answer with very high probability.”

Making machine learning trustworthy

In essence, the method developed by Soroush and his colleagues is an approach to uncertainty quantification that not only works on graphs (data with connections) but also uses the information contained in the relations between datapoints. Their method is computationally inexpensive and easy to implement, provided that some additional labeled data are available for calibration. The key advantage of this approach, according to Soroush, is that CP is "model independent, meaning that it doesn't matter which model is used, so you don't need to train anything from scratch." Building on this, they developed a method called Diffusion Adaptive Prediction Sets, which uses the connections between datapoints to improve the quality of the uncertainty estimates. In the published paper, the detailed empirical analysis of this method is embedded in a comprehensive theoretical investigation of when CP can be applied to GNNs. With their study, Soroush and his colleagues make an important contribution to increasing the trustworthiness of machine learning models on graph data.
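
To give a rough feel for how connections can sharpen uncertainty estimates, the sketch below smooths per-node conformity scores over the edges of a graph before the usual conformal thresholding. It only illustrates the general idea of score diffusion; the graph, scores, and mixing weight are invented, and it is not a reimplementation of the paper's Diffusion Adaptive Prediction Sets.

```python
import numpy as np

num_nodes, num_classes = 6, 3
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
lam = 0.5                                      # mixing weight (illustrative)

rng = np.random.default_rng(1)
scores = rng.random((num_nodes, num_classes))  # per-node, per-class conformity scores

# Row-normalized adjacency matrix (purely illustrative, no self-loops)
adj = np.zeros((num_nodes, num_nodes))
for u, v in edges:
    adj[u, v] = adj[v, u] = 1.0
adj = adj / np.maximum(adj.sum(axis=1, keepdims=True), 1.0)

# Diffused score: each node's own score mixed with the average of its neighbors' scores
diffused = (1 - lam) * scores + lam * (adj @ scores)
print(diffused.shape)   # (6, 3)
```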