AI memory enters the era of selective forgetting. Machine unlearning, a novel technology, selectively erases information from trained models without full retraining.
It's the absolute cutting edge of AI model safety and data privacy.
Microsoft researchers showcased this capability by removing Harry Potter references from Meta's Llama 2 model. OpenAI researchers applied unlearning to eliminate biased content, enhancing model safety.
Unlearning addresses GDPR's "right to be forgotten" mandate. For example, Google held a machine unlearning competition focusing on removing specific facial images from an AI model trained to predict age from images.
For cybersecurity, unlearning mitigates data poisoning risks. JPMorgan Chase applied unlearning to consumer systems, preserving customer privacy without compromising effectiveness.
The future of scalable AI systems lies in selective forgetting, not indiscriminate data accumulation. This breakthrough ushers in a new paradigm of AI privacy, where machines learn, unlearn, and relearn with unprecedented flexibility.
As we stand on this new frontier, we have researched the machine unlearning techniques that could lead us to a world where machines possess human-like memory malleability.
And if you're in a rush, we have a TLDR at the end.
What is Machine Unlearning
Machine unlearning erases specific data from pre-trained models without full retraining. Several of these techniques require embedding unlearning mechanisms during initial model training.
We have broken down the techniques into four major ones: Exact, Approximate, Prompt-Based, Decentralized.
Exact methods remove specific data precisely from AI models. They divide data or transform algorithms to target unwanted information. This approach deletes or updates only the necessary parts, avoiding full model retraining.
Approximate methods take a broader approach. They fine-tune models on slightly altered data to reduce targeted information's influence. Some add controlled noise during training, limiting each data point's impact from the start.
Prompt-Based techniques guide AI language models to "forget" without changing their core parameters. They use carefully crafted prompts and examples that steer the model away from unwanted information. The model receives new instructions to ignore certain knowledge, without erasing it from memory.
Decentralized approaches erase data from AI systems spread across multiple devices. They remove information when a participant exits the network. These methods use compact models to efficiently propagate the effect of this exit across the network. This eliminates the leaving party's influence without retraining the entire system. The process maintains collaborative learning while enabling selective data removal.
As these techniques evolve, machine unlearning becomes more effective, enabling models to forget specific information while maintaining strong performance.
1. Exact Methods for Machine Unlearning
Exact unlearning removes specific data from machine learning models. It targets a subset called the "forget set" within the main dataset. The process produces a new model that is identical (in distribution) to one retrained from scratch on the remaining data, without the forget set. This technique deletes chosen data points rapidly, without full retraining.
The updated model maintains high accuracy on remaining data. Think of it as selectively erasing memories from a human brain while preserving overall knowledge and functionality.
We have broken down the exact methods into two major types.
SISA.
Statistical Query.
1.1 SISA
SISA (Sharded, Isolated, Sliced, and Aggregated training) is an algorithmic framework designed to accelerate machine unlearning (Bourtoule et al. 2019). This technique strategically limits the influence of individual data points during model training, enabling efficient removal of specific data without complete model retraining.
SISA Architecture
The diagram shows a layered machine learning system. Training data forms the base, split into 'n' shards (Shard 1 to Shard n). Each shard is an isolated subset of the dataset. Shards divide into 'R' slices (slice 1 to slice R), enabling incremental training and unlearning.
Models 1 through n, shown as connected nodes, train on corresponding shards. These models constitute the SISA (Sharded, Isolated, Sliced, Aggregated) system core. This structure allows parallel processing and targeted updates.
An aggregation layer tops the system. It combines outputs from all models to generate the final prediction. This method leverages insights from different data subsets.
The design enhances data management and system flexibility. It processes large datasets efficiently while allowing precise adjustments. The architecture balances scalability with fine-tuned control.
SISA Training
The SISA training process operates systematically. The main dataset splits into 'n' shards, each training a separate model. For each shard:
The first slice initializes a new model, training for a preset number of epochs.
Subsequent slices continue training from the model's previous state.
The last slice finalizes training, producing the shard's constituent model.
This slice-wise approach enables incremental learning. Models update efficiently as new data arrives. It also facilitates data removal without full retraining. The process repeats for all shards.
The final step assembles the complete SISA model by combining all trained constituent models. This structure allows for parallel processing and targeted updates.
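The training loop above can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions: the constituent learner is a toy per-class running [sum, count] of a single scalar feature, chosen only because it can resume from a checkpoint trivially, and the names `train_on` and `sisa_train` are hypothetical. A real deployment would train a neural network at each slice step, typically for several epochs.

```python
import random

def train_on(state, points):
    """Toy incremental learner (our assumption, not from the paper): the model
    state is a per-class running [sum, count] of a scalar feature. Returns a
    new state so that earlier checkpoints are never mutated."""
    new = {label: list(acc) for label, acc in state.items()}
    for x, y in points:
        acc = new.setdefault(y, [0.0, 0])
        acc[0] += x
        acc[1] += 1
    return new

def sisa_train(data, n_shards, n_slices, seed=0):
    """Shard the data, train one constituent model per shard slice by slice,
    and keep a checkpoint after every slice."""
    data = list(data)
    random.Random(seed).shuffle(data)
    shards = [data[i::n_shards] for i in range(n_shards)]       # disjoint, isolated shards
    models = []
    for shard in shards:
        slices = [shard[j::n_slices] for j in range(n_slices)]  # R slices per shard
        state, checkpoints = {}, []
        for sl in slices:
            state = train_on(state, sl)    # resume from the previous slice's state
            checkpoints.append(state)      # saved for later unlearning
        models.append({"slices": slices, "checkpoints": checkpoints})
    return models
```

Each shard's final checkpoint is its constituent model; the earlier checkpoints exist purely so that unlearning can roll back without starting over.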
SISA Inference
SISA inference runs input through all models (Model 1 to n). Each model produces a prediction. The system aggregates these predictions via majority voting. It selects the most common label as output.
This method reduces individual model biases. The aggregation layer functions as an ensemble. It combines insights from all data shards.
Votes can be weighted toward more reliable models, which allows the inference step to be tuned. The system can also track confidence by measuring how strongly the models agree.
This enables parallel processing during inference and can reduce latency in large-scale applications.
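The aggregation step can be sketched as a weighted majority vote. This is an illustrative sketch, not the paper's code: `aggregate` is a hypothetical helper, weights default to 1 (plain majority voting), and the returned agreement fraction is one simple way to measure model agreement. Ties go to the label seen first, because `Counter.most_common` keeps first-encountered order among equal counts.

```python
from collections import Counter

def aggregate(predictions, weights=None):
    """Majority (optionally weighted) vote over constituent-model predictions.

    Returns the winning label and the fraction of the total weight that voted
    for it, an agreement-based confidence score."""
    if weights is None:
        weights = [1.0] * len(predictions)
    tally = Counter()
    for label, w in zip(predictions, weights):
        tally[label] += w
    label, score = tally.most_common(1)[0]
    return label, score / sum(weights)
```

For example, `aggregate(["cat", "dog", "cat"])` returns "cat" with an agreement of about 0.67, while weighting the second model heavily, as in `aggregate(["cat", "dog", "cat"], [0.2, 0.9, 0.3])`, flips the output to "dog".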
SISA Unlearning Procedure
SISA unlearning first locates the data to be deleted within a specific shard and slice, shown as the red box in Shard 2, slice 2. The system then reverts to the checkpoint saved just before that slice was added, shown as the green outline around slice 2 in Shard 2.
From this past state, SISA begins selective retraining. Model 2 Retrain, in green, shows this process. It retrains on the affected slice, minus the unlearning target, and on every later slice in that shard.
Once complete, the retrained Model 2 replaces its original version. The aggregation layer at the top then uses the new model's output.
This selective retraining approach allows SISA to precisely remove specific data. It does this without disturbing other parts of the system, as Model 1 and Model n remain unchanged.
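The rollback-and-retrain procedure for a single shard can be sketched as follows. It is a minimal sketch assuming a toy incremental learner (a per-class running [sum, count] of a scalar feature) and hypothetical helpers `train_on`, `train_shard` and `unlearn`; it checks that the result matches retraining the shard from scratch without the deleted point.

```python
def train_on(state, points):
    """Toy incremental learner (hypothetical): per-class running [sum, count]
    of a scalar feature. Returns a new state, leaving checkpoints intact."""
    new = {label: list(acc) for label, acc in state.items()}
    for x, y in points:
        acc = new.setdefault(y, [0.0, 0])
        acc[0] += x
        acc[1] += 1
    return new

def train_shard(slices):
    """Train slice by slice, saving a checkpoint after each slice."""
    state, checkpoints = {}, []
    for sl in slices:
        state = train_on(state, sl)
        checkpoints.append(state)
    return checkpoints

def unlearn(slices, checkpoints, point):
    """Forget `point`: roll back to the checkpoint saved before its slice,
    then retrain only that slice (minus the point) and the later slices."""
    r = next(i for i, sl in enumerate(slices) if point in sl)
    new_slices = [list(sl) for sl in slices]
    new_slices[r].remove(point)
    state = checkpoints[r - 1] if r > 0 else {}
    new_checkpoints = list(checkpoints[:r])   # checkpoints before slice r are reused
    for sl in new_slices[r:]:
        state = train_on(state, sl)
        new_checkpoints.append(state)
    return new_slices, new_checkpoints
```

Only this shard's checkpoints change; the other constituent models and the aggregation logic are untouched. The price is storing one checkpoint per slice.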
SISA Advantages
Computational Efficiency: SISA significantly reduces the computational overhead of unlearning by limiting retraining to a single shard and subset of slices.
Scalability: The sharding approach allows for parallel processing and efficient handling of large datasets.
Isolation: By confining data points to specific shards and slices, SISA minimizes the impact of individual data removals on the overall model.
Adaptability: The framework can be extended to various machine learning algorithms and data types.
SISA-based Methods
Several adaptations of SISA have been developed for specific use cases.
DaRE Forest: A random forest variant using a two-level approach with random and greedy nodes, enabling efficient updates of subtrees upon data removal (Brophy & Lowd 2021).
HedgeCut: Focuses on low-latency unlearning in extremely randomized trees, introducing split robustness concepts (Schelter et al. 2021).
GraphEraser: Tailored for Graph Neural Networks (GNNs), incorporating balanced graph partition and optimized shard model importance scoring (Chen et al. 2021).
RecEraser: Specialized for recommendation tasks, employing adaptive aggregation methods to combine sub-model predictions (Chen et al. 2022).
In conclusion, SISA represents a flexible and efficient framework for managing data removal requests in machine learning models. Its architecture and process flow, as illustrated in the provided images, demonstrate a sophisticated approach to balancing model performance with data privacy and right-to-be-forgotten requirements.
1.2 Statistical Query (SQ) Approach
Cao & Yang (2015) introduced an innovative approach to machine unlearning, distinct from the SISA framework. This method uses statistical queries and summation forms to enable efficient unlearning without complete model retraining.
SQ Core Concepts
Summation Form Transformation: The key idea is transforming learning algorithms into a summation form. In this form, the algorithm relies on a small number of summations, each representing an efficiently computable transformation of the training data samples. These summations are saved alongside the trained model.
Statistical Query (SQ) Learning: To achieve the summation form, Cao & Yang employ Statistical Query (SQ) learning, introduced by Kearns (1993) as a refinement of the PAC learning model. In SQ learning, the algorithm interacts with a statistical query oracle instead of directly examining individual examples. The learning algorithm queries the oracle with a function and a tolerance parameter, receiving an estimate of the expected value over a distribution of labeled examples.
Formal Definition: More formally, the learning algorithm queries an oracle with a function ξ: X × {0,1} → {0,1} and a tolerance parameter τ. The oracle responds with an estimate of the expected value of ξ over a distribution D of labeled examples, accurate to within additive error τ of E[ξ(x,y)]. The training algorithm only has access to these statistics; it does not directly access the input data.
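In symbols, the oracle guarantee and the summation form that makes unlearning cheap can be written as follows. The names $g_k$, $G_k$, $K$ and $\mathrm{Learn}$ are our own illustrative notation, not Cao & Yang's:

$$\left|\,\hat{E}_\xi - \mathbb{E}_{(x,y)\sim D}\big[\xi(x,y)\big]\right| \le \tau$$

$$G_k = \sum_{i=1}^{n} g_k(x_i, y_i), \quad k = 1, \dots, K, \qquad \text{model} = \mathrm{Learn}(G_1, \dots, G_K)$$

$$G_k' = G_k - g_k(x_p, y_p) \quad \text{(forgetting sample } p\text{)}$$

Forgetting therefore costs $K$ subtractions plus one call to $\mathrm{Learn}$, independent of the dataset size $n$.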
How SQ Works
The image below illustrates the process flow.
The image shows a sophisticated approach to machine learning and unlearning using Statistical Query (SQ) learning. Here's a technical breakdown:
Data Points: The left side shows input data points x₁, x₂, x₃. These represent individual training examples from the dataset D.
Statistical Queries: Dashed lines connect data points to q₁ and q₂, representing statistical query functions ξ: X × {0,1} → {0,1}. These queries transform raw data into relevant statistics.
Summations: The Σ symbols denote summation operators. They aggregate query results across the dataset; normalized by the dataset size, each sum estimates E[ξ(x,y)] for its query function.
Tolerance: While not explicitly shown, each summation has an associated tolerance parameter τ, defining the acceptable error margin for the computed statistic.
Neural Network: The orange box on the right depicts a neural network architecture. It contains multiple layers (Input, Hidden, Output) interconnected by edges.
SQ to Model Pipeline: Solid arrows from summations to the neural network illustrate how aggregated statistics directly inform model parameters or training.
Unlearning Process:
To forget a data point, the system updates stored summations by subtracting that point's contributions.
This involves simple arithmetic operations on the summations.
The model is then efficiently recomputed using these updated summations.
Unlearning Efficiency: This structure allows for efficient unlearning. Removing a data point only requires updating the summations, not retraining the entire model.
Scalability: The approach scales well with large datasets, as the model depends on a fixed number of summations rather than individual data points.
This SQ learning framework transforms complex learning algorithms into a summation form, enabling rapid updates and unlearning while maintaining model accuracy.
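To make the summation form concrete, here is a naive Bayes classifier over binary features whose entire state is a set of counts, i.e. summations over the training set. This is an illustrative sketch: the class name `CountingNaiveBayes`, the binary-feature setting and the Laplace smoothing are our choices, not Cao & Yang's implementation. Forgetting a sample subtracts its contribution from each count, and the result is identical to training from scratch without it.

```python
import math
from collections import Counter

class CountingNaiveBayes:
    """Naive Bayes stored purely as summations over the training data."""

    def __init__(self):
        self.class_count = Counter()    # sum_i 1[y_i = c]
        self.feature_count = Counter()  # sum_i 1[y_i = c] * x_ij, keyed by (c, j)

    def _update(self, x, y, sign):
        self.class_count[y] += sign
        for j, xj in enumerate(x):
            if xj:
                self.feature_count[(y, j)] += sign

    def learn(self, x, y):
        self._update(x, y, +1)

    def forget(self, x, y):
        # Unlearning = subtracting this sample's contribution from every summation.
        self._update(x, y, -1)

    def predict(self, x):
        total = sum(self.class_count.values())

        def log_posterior(c):
            n_c = self.class_count[c]
            score = math.log(n_c / total)
            for j, xj in enumerate(x):
                p = (self.feature_count[(c, j)] + 1) / (n_c + 2)  # Laplace smoothing
                score += math.log(p if xj else 1 - p)
            return score

        live = [c for c in self.class_count if self.class_count[c] > 0]
        return max(live, key=log_posterior)
```

The model is recomputed from the counts at prediction time, so no separate retraining step is needed after `forget`.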
SQ Technical Considerations
The effectiveness of this method depends on expressing the learning algorithm in terms of statistical queries. The choice of queries and summations can significantly impact unlearning efficiency and model performance. There's a balance to strike between the granularity of summations and the efficiency of unlearning.
This non-SISA approach provides a novel perspective on machine unlearning, focusing on data representation and model dependency rather than architectural modifications. Its application in various machine learning contexts remains an active area of research and development.
2. Approximate Methods for Machine Unlearning
While exact methods for machine unlearning are effective, they often require significant computational resources and storage overhead. Xu et al. (2024) present approximate methods as a more efficient alternative, especially for large datasets and complex models.
General Approach (Xu et al. 2024)
The process, as illustrated in the image, consists of four key steps:
1. Influence Computation: The process begins with the "Dataset" shown on the left side of the image. From this dataset, we identify the "data to forget", the specific training examples we want the model to unlearn. The "Influence" step analyzes how these data points impact the model's predictions. This analysis employs influence functions, a technique from robust statistics. For a given data point z, its influence I(z) on the model parameters is calculated as:
$$I(z) = -H_\theta^{-1}\,\nabla_\theta L(z, \theta)$$
Where:
$H_\theta$ is the Hessian matrix of the loss function $L$,
$\theta$ represents the model parameters,
$\nabla_\theta L(z, \theta)$ is the gradient of the loss with respect to $\theta$.
Computing this for large models can be challenging because of the inverse Hessian. Practical implementations often use approximation techniques such as the conjugate gradient method to estimate $H_\theta^{-1}\nabla_\theta L(z, \theta)$ efficiently, without forming the inverse explicitly.
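The conjugate gradient idea can be illustrated on a small ridge-regression model: CG needs only Hessian-vector products, so $H_\theta^{-1}\nabla_\theta L(z,\theta)$ is obtained without ever forming or inverting $H_\theta$. This is a sketch under assumptions: the model (squared loss summed over the data plus an L2 penalty $\lambda\lVert\theta\rVert^2$), the helper names and the stopping tolerance are ours, and `theta` stands in for trained parameters.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def conjugate_gradient(matvec, b, tol=1e-10, max_iter=100):
    """Solve H v = b for symmetric positive-definite H using only `matvec`."""
    x = [0.0] * len(b)
    r = list(b)            # residual b - H x, with x starting at zero
    p = list(r)
    rs = dot(r, r)
    for _ in range(max_iter):
        if rs ** 0.5 < tol:
            break
        hp = matvec(p)
        alpha = rs / dot(p, hp)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hi for ri, hi in zip(r, hp)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x

def hvp_ridge(X, lam):
    """Hessian-vector product for sum_i (theta.x_i - y_i)^2 + lam*|theta|^2,
    whose Hessian is 2 * sum_i x_i x_i^T + 2 * lam * I."""
    def matvec(v):
        out = [2 * lam * vi for vi in v]
        for x in X:
            s = 2 * dot(x, v)
            out = [o + s * xi for o, xi in zip(out, x)]
        return out
    return matvec

def influence(X, lam, theta, z):
    """I(z) = -H^{-1} grad L(z, theta), with the inverse applied via CG."""
    x, y = z
    g = [2 * (dot(theta, x) - y) * xi for xi in x]   # gradient of (theta.x - y)^2
    return [-vi for vi in conjugate_gradient(hvp_ridge(X, lam), g)]
```

For a neural network, `matvec` would be replaced by an automatic-differentiation Hessian-vector product; the CG loop itself stays the same.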
2. Model Parameter Adjustment: The image shows a transition from an initial neural network to an adjusted one. This step modifies the model parameters to counteract the influence of the data to be forgotten. The adjustment typically moves in the direction opposite to the computed influence:
$$\theta_{\text{new}} = \theta - \eta\, I(z) = \theta + \eta\, H_\theta^{-1}\nabla_\theta L(z, \theta)$$
where $\eta$ is a small step size. In practice, this update may be applied iteratively or combined with optimization techniques like Adam or RMSprop for stability.
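For a quadratic loss this adjustment can be made exact. The sketch below assumes ridge regression with two parameters (squared loss summed over the data plus $\lambda\lVert\theta\rVert^2$), takes η = 1, and evaluates the Hessian on the remaining data rather than the full set; this one-step Newton form is the update Guo et al. (2019) use for certified removal. All function names are hypothetical.

```python
def hessian(data, lam):
    """Hessian of sum_i (theta.x_i - y_i)^2 + lam*|theta|^2 for 2-D inputs."""
    return [[2 * sum(x[i] * x[j] for x, _ in data) + (2 * lam if i == j else 0.0)
             for j in range(2)] for i in range(2)]

def solve2(A, b):
    """Solve the 2x2 system A v = b by Cramer's rule."""
    (a11, a12), (a21, a22) = A
    det = a11 * a22 - a12 * a21
    return [(a22 * b[0] - a12 * b[1]) / det, (a11 * b[1] - a21 * b[0]) / det]

def fit_ridge(data, lam):
    """Exact minimiser via the normal equations H theta = 2 * sum_i y_i x_i."""
    b = [2 * sum(y * x[k] for x, y in data) for k in range(2)]
    return solve2(hessian(data, lam), b)

def remove_point(theta, data, lam, idx):
    """theta - eta * I(z) with eta = 1 and H taken over the remaining data."""
    x, y = data[idx]
    remaining = data[:idx] + data[idx + 1:]
    residual = theta[0] * x[0] + theta[1] * x[1] - y
    grad_z = [2 * residual * xk for xk in x]          # grad of L(z, theta)
    step = solve2(hessian(remaining, lam), grad_z)    # H^{-1} grad L(z, theta)
    return [t + s for t, s in zip(theta, step)]
```

Because the remaining objective is quadratic, this single step lands exactly on the model retrained without the point; for non-quadratic losses such as logistic regression it is only an approximation.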
3. Noise Addition: The "noise" element above the neural network in the image represents the introduction of controlled noise to the model. This step implements differential privacy techniques to prevent inference of recently removed data. The noise is calibrated based on the sensitivity of the unlearning operation. For Gaussian noise, we might add noise drawn from $\mathcal{N}(0, \sigma^2)$ to each parameter, where $\sigma$ is calculated as:
$$\sigma = \frac{c\,\Delta f}{\varepsilon}$$
Here, $c$ is a constant, $\Delta f$ is the sensitivity of the unlearning operation, and $\varepsilon$ is the desired privacy parameter. The choice between Gaussian and Laplacian noise depends on the specific privacy guarantees required.
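The noise scale can be computed directly. This sketch assumes the classical Gaussian mechanism, where $c = \sqrt{2\ln(1.25/\delta)}$ for a failure probability δ and 0 < ε < 1; the text above leaves c unspecified, so this choice of c, the extra parameter δ and the helper names are all assumptions.

```python
import math
import random

def gaussian_sigma(sensitivity, epsilon, delta):
    """sigma = c * sensitivity / epsilon with the classical Gaussian-mechanism
    constant c = sqrt(2 ln(1.25 / delta)) (assumed; valid for 0 < epsilon < 1)."""
    if not 0 < epsilon < 1:
        raise ValueError("the classical analysis assumes 0 < epsilon < 1")
    c = math.sqrt(2 * math.log(1.25 / delta))
    return c * sensitivity / epsilon

def add_gaussian_noise(params, sigma, rng):
    """Perturb each parameter with independent N(0, sigma^2) noise."""
    return [p + rng.gauss(0.0, sigma) for p in params]
```

With Δf = 1, ε = 0.5 and δ = 10⁻⁵, σ comes out at roughly 9.69; halving ε doubles the noise.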
4. Updated Model Validation: The final neural network in the image, labeled "New model with influence minimized" represents the outcome of the unlearning process. This step assesses the performance of the new model to ensure effective unlearning while maintaining overall functionality.
Validation employs metrics such as:
Accuracy: (TP + TN) / (TP + TN + FP + FN)
F1 Score: 2 * (Precision * Recall) / (Precision + Recall)
AUC-ROC: Area under the Receiver Operating Characteristic curve
These metrics are compared against both the original model and a model retrained from scratch without the forgotten data.
Additionally, specific tests may be conducted to ensure the removed data cannot be inferred, such as membership inference attacks or analyzing the model's uncertainty on the forgotten examples.
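The accuracy and F1 formulas translate directly into code. The helpers below are a minimal sketch for binary labels with positive class 1; the names are hypothetical, and AUC-ROC is omitted because it needs scores rather than hard predictions.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Return (TP, TN, FP, FN) for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    tn = sum(t != positive and p != positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    return tp, tn, fp, fn

def accuracy(y_true, y_pred):
    tp, tn, fp, fn = confusion_counts(y_true, y_pred)
    return (tp + tn) / (tp + tn + fp + fn)

def f1_score(y_true, y_pred):
    tp, _, fp, fn = confusion_counts(y_true, y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0
```

In an unlearning evaluation these would be computed separately on the forget set and on the retained data, so that a drop on the former can be checked against stability on the latter.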
This process provides a balance between unlearning efficiency and effectiveness, making it applicable to a wide range of machine learning models and datasets.
Below, we will go in-depth into two approximate methods of machine unlearning: Certified Removal and Unlearning in Spiking Neural Networks.
2.1 Certified Removal
Guo et al. (2019) unveiled certified removal, a big step in approximate machine unlearning. This method offers theoretical safeguards against adversarial extraction of removed training data. The image provided illustrates the key components and process flow of certified removal.
How Certified Removal Works
Certified removal operates by strategically adjusting the model to "forget" specific data points. Here's a simplified breakdown of the process.
Noise Injection: During the initial training, a controlled amount of random noise is added to the training objective (a random linear perturbation of the loss). This noise acts as a form of protection, making it harder for an adversary to extract information about individual data points.
Influence Calculation: For each piece of data in the forget set, the method calculates its "influence" on the model. This influence represents how much that specific data point contributed to the model's current state.
Parameter Adjustment: Using the calculated influence, the model's internal parameters are carefully adjusted. This adjustment effectively cancels out the impact of the forgotten data points.
The result is a new model that is provably hard to distinguish from one that never saw the forgotten data, up to a chosen privacy parameter.
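The three steps can be illustrated on a one-parameter logistic regression. This is a toy sketch, not Guo et al.'s code: noise injection is modeled as a random linear term $b\theta$ added to the training loss (the Gaussian draw for b is our assumption), the parameter adjustment is one Newton step using the Hessian of the remaining data, and the final check measures the gradient residual of the remaining objective, the quantity the injected noise is meant to mask. Function names, data and λ are hypothetical.

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def grad(theta, data, lam, b):
    """Derivative of sum_i log(1 + exp(-y_i*theta*x_i)) + lam/2*theta^2 + b*theta,
    with labels y in {-1, +1}."""
    return sum(-y * x * sigmoid(-y * theta * x) for x, y in data) + lam * theta + b

def hess(theta, data, lam):
    return sum(x * x * sigmoid(theta * x) * sigmoid(-theta * x) for x, _ in data) + lam

def train(data, lam, b, steps=60):
    """Minimise the perturbed objective with Newton's method."""
    theta = 0.0
    for _ in range(steps):
        theta -= grad(theta, data, lam, b) / hess(theta, data, lam)
    return theta

def certified_remove(theta, data, lam, idx):
    """theta + H^{-1} * grad of the removed point's loss, H over remaining data."""
    x, y = data[idx]
    remaining = data[:idx] + data[idx + 1:]
    grad_z = -y * x * sigmoid(-y * theta * x)
    return theta + grad_z / hess(theta, remaining, lam), remaining

rng = random.Random(0)
data = [(0.5, 1), (-0.4, -1), (1.0, 1), (-0.8, -1), (0.3, -1)]
lam = 1.0
b = rng.gauss(0.0, 0.1)            # random linear perturbation of the loss
theta = train(data, lam, b)
theta_new, remaining = certified_remove(theta, data, lam, idx=2)
residual_before = abs(grad(theta, remaining, lam, b))
residual_after = abs(grad(theta_new, remaining, lam, b))
```

The residual after removal is small but not zero, because the logistic loss is not quadratic. With λ larger than a quarter of Σx², as here, a single Newton step is guaranteed to shrink it.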
Why Certified Removal Matters
Guo et al. (2019) demonstrated the efficiency of certified removal compared to full retraining. In their experiments with a linear model trained on the MNIST dataset:
Certified removal took 0.04 seconds to remove a data point.
Full retraining of the model took 15.6 seconds.
These results highlight the significant time savings offered by certified removal (roughly a 390-fold speedup in this experiment), especially when dealing with large datasets or frequent removal requests.
2.2 Unlearning in Spiking Neural Networks
Spiking Neural Networks (SNNs) mimic the behavior of biological neurons in the human brain (Maass 1997). As shown in the figure below, SNNs consist of interconnected neuron-like units that communicate through discrete spikes across the network.
In SNNs, each neuron has a membrane potential that changes over time. When this potential reaches a specific threshold, the neuron "fires" or "spikes," sending a signal along its axon (represented by orange lines in the image). This spiking behavior is a key characteristic that distinguishes SNNs from traditional artificial neural networks.
Wang et al. (2023) introduced a technique for unlearning in SNNs, addressing the need for privacy mechanisms and unlearning techniques in these biologically inspired models. SNNs have shown success in pattern recognition, particularly in speech (Wang et al. 2023) and image recognition (Su et al. 2023). They have also demonstrated potential in medical applications, such as constructing stimulation systems for Parkinson's patients (Geiger 2023).
How SNN Unlearning Works
It happens in three phases, which can be understood in the context of the image above.
Selective Retraining: This phase identifies neurons (the branching structures) and synapses (the connections between neurons, represented by arrows in the central network) responsible for the information to be forgotten. It estimates the correlation between a neuron's spike train (the pattern of signals sent along the orange lines) and the targeted data. Synapses are selected based on their weight change due to learning the targeted data. The weights of these synapses are then adjusted using a modified learning rule.
Synaptic Pruning: This step removes synapses whose weight change surpasses a given threshold. In the context of the image, this would involve selectively removing some of the neuron connections in the interconnected network. Every synapse in the network is checked against this threshold and removed if necessary, eliminating traces of the targeted data.
Adaptive Thresholding: In this phase, neuron firing thresholds are dynamically modified based on their activity in relation to the targeted data. This would affect how easily the neurons "fire" or send signals along the orange lines. This reduces the neuron's response to stimuli linked with the data to be unlearned.
The image shows input neurons on the left (Input 1, Input 2), hidden layers in the center (H11, H12, H13, H21, H22, H23), and output neurons on the right (Output 1, Output 2). The unlearning process would involve modifying the connections and behaviors of these components to "forget" specific information.
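The pruning and thresholding phases can be illustrated on a single leaky integrate-and-fire neuron. This is a toy sketch under loud assumptions: the leak factor, the pruning threshold, the weight changes attributed to the forget data, and the rule that scales the firing threshold by the neuron's firing rate on forget inputs are all hypothetical, not Wang et al.'s learning rules.

```python
def lif_spike_count(inputs, weights, threshold, leak=0.9):
    """Leaky integrate-and-fire neuron: the membrane potential decays by `leak`
    each step, integrates weighted input spikes, and fires (then resets to 0)
    when it reaches `threshold`. Returns the number of output spikes."""
    v, spikes = 0.0, 0
    for frame in inputs:                       # frame: one 0/1 spike per synapse
        v = leak * v + sum(w * s for w, s in zip(weights, frame))
        if v >= threshold:
            spikes += 1
            v = 0.0
    return spikes

def prune_synapses(weights, weight_change, prune_threshold):
    """Synaptic pruning: zero out synapses whose weight change attributed to
    the forget data exceeds `prune_threshold` (hypothetical attribution)."""
    return [0.0 if abs(dw) > prune_threshold else w
            for w, dw in zip(weights, weight_change)]

def adapt_threshold(threshold, forget_rate, gain=0.5):
    """Adaptive thresholding (hypothetical rule): raise the firing threshold in
    proportion to the neuron's firing rate on forget-set inputs."""
    return threshold * (1 + gain * forget_rate)

forget_input = [[1, 1, 0]] * 10                # spike pattern linked to the forget data
weights = [0.6, 0.5, 0.4]
weight_change = [0.05, 0.45, 0.02]             # assumed learned from the forget data
before = lif_spike_count(forget_input, weights, threshold=1.0)
pruned = prune_synapses(weights, weight_change, prune_threshold=0.3)
new_threshold = adapt_threshold(1.0, before / len(forget_input))
after = lif_spike_count(forget_input, pruned, new_threshold)
```

In this toy run the neuron fires on every step before unlearning (10 spikes) and only every third step afterwards (3 spikes), i.e. its response to the forget pattern drops while the untouched synapses keep their weights.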
Why SNN Unlearning Matters
Wang et al. (2023) tested the unlearning effectiveness on two datasets: UCI HAR and MNIST. The results showed:
A decrease in performance metrics (accuracy, precision, recall, etc.) for both datasets after unlearning. This indicates that the network successfully "forgot" the targeted information.
The UCI HAR dataset exhibited a bigger drop in performance compared to MNIST, indicating that resilience to unlearning is dataset-specific. This suggests that the complexity of the data and the way it's encoded in the network can affect the unlearning process.
Retraining after unlearning recovered a considerable portion of the lost performance, especially for the MNIST dataset. This demonstrates the network's ability to relearn and adapt after the unlearning process.
A clear trade-off emerged between the percentage of samples unlearned and accuracy loss. As more samples were removed, accuracy decreased. This highlights the balance between preserving overall performance and effectively removing specific information.
This method focuses on minimizing the impact of neurons and synapses highly correlated with the forget set. By targeting only these elements, as represented by the specific connections and nodes in the central network of the image, the approach avoids full retraining. This makes it an approximate method for unlearning in SNNs, offering a balance between effective information removal and computational efficiency.
2.3 Other Approximate Methods
The field of approximate machine unlearning continues to evolve beyond the work of Guo et al. (2019). Several researchers have proposed innovative approaches.
Sekhari et al. (2021) introduced an algorithm utilizing cheap-to-store data statistics. This method is designed for convex loss functions and can unlearn a significant number of samples without full dataset access.
Suriyakumar & Wilson (2022) developed an online unlearning algorithm based on the infinitesimal jackknife method. Their approach reduces computational overhead by inverting the Hessian matrix only once.
Mehta et al. (2022) proposed a parameter selection technique using conditional independence tests. Their L-CODEC and L-FOCI algorithms identify relevant model parameters for unlearning, avoiding full Hessian matrix inversion.
Wu et al. (2022) presented the Performance Unchanged Model Augmentation (PUMA) method. PUMA removes unique characteristics of marked data points while preserving model performance through influence function calculations.
Tanno et al. (2022) introduced a "predictive approach" for identifying causes of model failures in medical imaging. They adapted the Elastic Weight Consolidation method to compute training example influence on failure sets.
Warnecke et al. (2021) proposed a framework for unlearning features and labels instead of entire data points. Their method leverages influence functions to perform closed-form updates on model parameters.
These advancements demonstrate the ongoing efforts to improve machine unlearning efficiency and applicability across various domains and model types.
3. Prompt-Based Unlearning
Prompt-based unlearning techniques aim to make language models "forget" specific information without directly modifying their parameters. These methods are particularly relevant for state-of-the-art (SOTA) large language models (LLMs), where direct access to model parameters is often restricted.
It's important to stress that prompt-based methods do not perform machine unlearning as per our earlier definition; they provide a way to "pretend" to forget information.
How Prompt-Based Methods Work
Guardrails
Thaker et al. (2024) explored the effectiveness of two simple guardrail approaches for unlearning in LLMs:
Prompting: Crafting specific prompts to guide the model's behavior.
Input/Output Filtering: Screening inputs and outputs to prevent unwanted information.
These techniques were tested on three different benchmarks:
In-Context Unlearning (ICUL)
Inspired by in-context learning (Brown et al. 2020), ICUL (Pawelczyk et al., 2024) involves constructing a specific context in a prompt that includes both correctly labeled and mislabeled examples. The process involves three steps.
Relabeling forget points: Choose a number of data points you want the model to forget. For each of these points, change its original label to a different, incorrect label. This creates a list of items, each with its content and a new, incorrect label.
Adding correct examples: Select a number of correctly labeled examples from your dataset. Add these examples to the list created in step 1. Now you have a longer list that includes both the relabeled "forget" points and some correctly labeled examples.
Creating the final prompt: Take the list from step 2 and add your actual query or question at the end. When you submit this to the language model, set the temperature to 0, which makes the model's responses more deterministic (less random).
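The three ICUL steps amount to string construction. In this sketch the `Input: ... Label: ...` template, the random choice of the incorrect label and the function name `build_icul_prompt` are assumptions rather than the exact format used by Pawelczyk et al.; sending the prompt to a model with temperature 0 is left to whatever client you use.

```python
import random

def build_icul_prompt(forget_points, correct_points, query, labels, seed=0):
    """Assemble an ICUL-style prompt (hypothetical template)."""
    rng = random.Random(seed)
    lines = []
    # Step 1: relabel each forget point with a randomly chosen incorrect label.
    for text, label in forget_points:
        wrong = rng.choice([lab for lab in labels if lab != label])
        lines.append(f"Input: {text} Label: {wrong}")
    # Step 2: append correctly labelled examples from the dataset.
    for text, label in correct_points:
        lines.append(f"Input: {text} Label: {label}")
    # Step 3: finish with the actual query, leaving its label for the model.
    lines.append(f"Input: {query} Label:")
    return "\n".join(lines)
```

With binary labels the "incorrect" label is simply the opposite class; with more classes the random choice matters and is worth fixing via `seed` for reproducibility.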
Why Prompt-Based Methods Matter
Guardrails (Thaker et al., 2024)
Achieved unlearning performance comparable to more complex fine-tuning approaches.
Offered a simple, resource-efficient approach to unlearning in LLMs.
In-Context Unlearning (ICUL) (Pawelczyk et al., 2024)
Demonstrated performance equal to or better than some leading unlearning methods that require access to model parameters.
Effectively eliminated the influence of a training point on a model's output in text classification and question-answering tasks.
Used significantly less memory compared to traditional unlearning methods such as gradient ascent.
Both methods address the challenge of unlearning in black-box LLMs without requiring direct access to model parameters. This is crucial given the widespread use of LLMs in various professional and personal contexts, where avoiding outputs like hate speech, toxic behavior, or hallucinations due to data poisoning is essential.
4. Decentralized Machine Unlearning
Decentralized machine unlearning addresses the challenge of removing specific information from AI systems geographically distributed across interconnected devices.
The HDUS (Heterogeneous Decentralized Unlearning framework with Seed model distillation) method tackles this complex issue when knowledge disseminates through a network of collaborating machines (Ye et al. 2023).
This approach is vital in a connected world where decentralized systems like edge computing expand, and data privacy laws demand deletion rights.
How HDUS Works
HDUS uses a dual-model approach for greater precision and efficiency.
Main Model: A high-capacity neural network trained on local data, functioning as the core intelligence.
Seed Model: A lightweight version that distills insights from the main model, enabling peer collaboration.
The seed model, trained on a reference (or synthetic) data sample, mimics the main model’s behavior without revealing sensitive data. In collaborative training, peers share seed model parameters. For inference, each peer generates output using its main model and an ensemble of neighbors' seed models.
The above image shows how client a₁ processes data using its main model, extracts key insights through a seed model, and integrates inputs from neighboring clients (b₁, b₂, …, bₖ) to generate a final output. The diagram highlights the complex connections between processing units, storage, and data paths, reflecting the intricate nature of the HDUS framework.
When a device exits, it triggers unlearning by notifying its neighbors. They discard the departing device's seed model from their ensembles, erasing its influence without full model retraining.
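The inference and exit logic can be sketched with a toy peer. This is a hypothetical illustration, not Ye et al.'s implementation: models are plain callables returning class scores, the ensemble simply averages the main model with the neighbors' seed models (an assumption), and the `Peer` class and its methods are our own names. What it shows is that unlearning a departed peer only removes that peer's seed model; nothing is retrained.

```python
class Peer:
    """Toy HDUS-style peer: a private main model plus neighbours' seed models."""

    def __init__(self, name, main_model):
        self.name = name
        self.main_model = main_model
        self.neighbor_seeds = {}          # neighbour name -> seed model

    def receive_seed(self, neighbor, seed_model):
        self.neighbor_seeds[neighbor] = seed_model

    def predict(self, x):
        """Average class scores of the main model and the seed-model ensemble."""
        models = [self.main_model] + list(self.neighbor_seeds.values())
        outputs = [m(x) for m in models]
        return [sum(col) / len(outputs) for col in zip(*outputs)]

    def on_neighbor_exit(self, neighbor):
        """Unlearning: drop the departing peer's seed model; no retraining."""
        self.neighbor_seeds.pop(neighbor, None)
```

For example, a peer `a1` with neighbors `b1` and `b2` stops reflecting `b1` in its outputs as soon as `a1.on_neighbor_exit("b1")` is called.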
Why HDUS Matters
HDUS offers critical advantages by leveraging decentralized machine unlearning:
1. Efficiency: Fast unlearning without full model retraining, using lightweight seed models for updates.
2. Compatibility: Works across diverse device networks, supporting varied model architectures.
3. Complete removal: Erases a departing device's entire contribution, protecting its privacy.
4. Exact unlearning: HDUS achieves precise data removal in a decentralized setting, maintaining unlearning integrity across distributed nodes.
HDUS overcomes the challenge of propagating unlearning requests in networks where knowledge has already spread across peers. It maintains integrity across dynamic network topologies, vital for edge computing and federated learning environments.
The HDUS approach removes data associated with specific devices, not individual data points. This shift is crucial for privacy management in distributed AI systems, especially in environments like mobile edge computing and IoT, where devices frequently join or leave.
Unlike other approaches, HDUS focuses on removing a party's entire contribution, advancing privacy and data rights in decentralized AI.
TLDR
Machine unlearning transforms AI privacy by erasing data from trained models without full retraining. Companies like Microsoft, Google, and JPMorgan Chase are already exploring this innovation.
Unlearning specific data points can be orders of magnitude faster than full retraining (0.04 versus 15.6 seconds in Guo et al.'s MNIST experiment), cutting computational costs. It removes targeted information, preserves model performance, and enables selective forgetting without losing other critical knowledge.
At Bagel, we fuse breakthrough technology like machine unlearning into an evolving Machine Learning ecosystem, redefining privacy and collaboration in AI.
Four major methods drive machine unlearning: Exact (SISA, SQ), Approximate, Prompt-Based, and Decentralized (HDUS). Each offers a distinct approach to efficient data removal without full model retraining.
Method comparison:
Machine unlearning enables AI to adapt, correct, and protect privacy at scale.
Bagel is a deep machine learning and cryptography research lab, making open-source AI monetizable using cryptography.