I made Qwen2-0.5B run 100× smaller with no matrix multiplications and almost no RAM usage: adaptive sparse selection at inference time, and it actually works.

Contact: Twitter https://x.com/liberal17th, email bogunusov@gmail.com

Status: experimental / work in progress. This is a research test, not a production release, not a compression method, and not a claim of a new architecture. Numbers below are placeholders; real plots and stats will be added once benchmark runs are complete. Current tests report no RAM usage at all, but that measurement may be wrong. With the compression step, size dropped by about 30%, but the main matrices are not multiplied.

What this actually is

During autoregressive generation, this repo tracks a small set of statistical features (mean, std, quantiles, rolling window stats, autocorrelation; 64 features per layer) computed from the input activations hitting each attention/FFN weight matrix in Qwen2-0.5B. A lightweight Bayesian selector then flags which of those features deviate meaningfully from their running distribution at each generation step, instead of treating every feature as equally relevant every time.

The output metric is simple: what fraction of tracked features get flagged as informative per step, averaged over a generation. That's it. It's an exploration of whether activation statistics carry sparse, structured signal during inference, not a finished result and not a benchmark win yet.
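As a rough illustration of the idea, the sketch below flags features that sit more than one running standard deviation from their running mean, then reports the fraction flagged across steps. It is a simplified pure-Python stand-in, not the script's selector: `RunningStats`, `flag_features`, `flagged_fraction`, and the toy feature vectors are assumptions made for this example, and only the 1.0 threshold mirrors terminal_chat_bayesian.py.

```python
import math


class RunningStats:
    """Per-feature running mean and variance (Welford-style update)."""

    def __init__(self, n_features):
        self.n = 0
        self.mean = [0.0] * n_features
        self.m2 = [0.0] * n_features  # running sum of squared deviations

    def std(self, i):
        # Population std with a small floor to avoid division by zero.
        return math.sqrt(self.m2[i] / self.n) + 1e-8 if self.n else 1.0

    def update(self, feats):
        self.n += 1
        for i, v in enumerate(feats):
            delta = v - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (v - self.mean[i])


def flag_features(stats, feats, threshold=1.0):
    """Indices of features more than `threshold` std away from the running mean."""
    if stats.n == 0:
        return list(range(len(feats)))  # nothing to compare against yet
    return [i for i, v in enumerate(feats)
            if abs(v - stats.mean[i]) / stats.std(i) > threshold]


def flagged_fraction(steps):
    """Fraction of (step, feature) pairs that were flagged over a whole run."""
    stats = RunningStats(len(steps[0]))
    flagged = total = 0
    for feats in steps:
        flagged += len(flag_features(stats, feats))
        total += len(feats)
        stats.update(feats)
    return flagged / total
```

A low fraction means most tracked statistics stayed close to their running distribution; the first step always flags everything because there is no history yet.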

Why this might be interesting anyway

Most work on transformer internals looks at weights (pruning, quantization, low-rank decomposition). This script instead asks: at inference time, does the activation stream flowing through each layer have a small, identifiable subset of statistics that matter more than the rest at any given step? If that subset is small and stable, it's a hint (not proof) that there's structure worth digging into: for interpretability, for adaptive compute, or just as a diagnostic tool for understanding what a layer is "paying attention to" numerically.

📦 Model Size Comparison

Qwen2-0.5B (original): 900 MB
Adaptive Sparse (this): 10 MB
Original vs. this model: 90× smaller

🧠 RAM Usage During Inference (tested on Google Colab)

Qwen2-0.5B (original): 4 GB
Adaptive Sparse (this): ~0 GB
Original vs. this model: 4 GB → 0 GB

Files

File: What it does
terminal_chat_bayesian.py: Main experiment. Loads Qwen2-0.5B, hooks every attention/FFN weight's input activations, runs the Bayesian feature selector during generation, prints the fraction of flagged features per response. Requires bayes_analysis.safetensors (see below).
storage_reconstruction_test.py: Secondary test. Splits weight tensors into (mean_scalar, residual_tensor) across a JSON + safetensors file, reconstructs on load. Included for transparency; this is a loading mechanics test, not a result.

Requirements

pip install torch transformers safetensors numpy

CUDA GPU required for terminal_chat_bayesian.py (checks torch.cuda.is_available() and will exit if not found). storage_reconstruction_test.py runs on CPU.

How to run

1. Bayesian feature selector chat (main experiment)

You need a bayes_analysis.safetensors file in the working directory containing precomputed per-layer feature tensors (keys ending in __feat). This file is produced by a separate analysis pass over the model's weights; generate it before running this script, or use the one provided in this repo's Files tab if included.
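Judging from how terminal_chat_bayesian.py reads the file, each key appears to be the parameter name with dots replaced by double underscores, followed by `__feat`. The helper names below are hypothetical and only illustrate that naming convention; they do not produce the analysis file itself.

```python
FEAT_SUFFIX = "__feat"


def param_to_key(param_name: str) -> str:
    """Map a parameter name to the key layout the loader appears to expect."""
    return param_name.replace(".", "__") + FEAT_SUFFIX


def key_to_param(key: str) -> str:
    """Inverse mapping, matching the loader's `sk.replace("__", ".")` step."""
    if not key.endswith(FEAT_SUFFIX):
        raise ValueError(f"not a feature key: {key}")
    return key[: -len(FEAT_SUFFIX)].replace("__", ".")


key = param_to_key("model.layers.0.mlp.up_proj.weight")
# key is "model__layers__0__mlp__up_proj__weight__feat"
```

Note that a parameter name which itself contains a double underscore would not survive this round trip.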

python terminal_chat_bayesian.py

In the chat session:

  • Type normally to talk to the model
  • /stats - shows how many features were flagged vs. total possible in the last response
  • /bayes - shows the top 10 layers by number of features flagged in more than half of the observed steps
  • /clear - resets conversation history
  • /exit - quit
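The per-layer number printed by /bayes comes from `state_summary()`: for each feature it takes marked / (marked + unmarked), with both counters starting at 1, and counts the features whose rate exceeds 0.5. A minimal pure-Python sketch of that calculation, using hypothetical helper names and toy counts:

```python
def flag_rates(marked_counts, unmarked_counts):
    """Per-feature flag rate; counters are assumed to start at 1."""
    return [m / (m + u) for m, u in zip(marked_counts, unmarked_counts)]


def layer_k(marked_counts, unmarked_counts, cutoff=0.5):
    """Number of features whose flag rate is strictly above `cutoff`."""
    return sum(r > cutoff for r in flag_rates(marked_counts, unmarked_counts))


# Three features after 4 observed steps (each counter started at 1):
marked = [5, 1, 3]     # flagged 4, 0 and 2 times
unmarked = [1, 5, 3]   # not flagged 0, 4 and 2 times
k = layer_k(marked, unmarked)  # only the first feature exceeds 0.5
```

Because the comparison is strict, a feature flagged in exactly half of the steps does not count toward k.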

2. Storage/reconstruction test (secondary, not a compression result)

Requires bayesian_features.json and layer_residuals.safetensors in /content/ (paths are hardcoded for Colab; edit json_path / safetensors_path in run_fast_terminal_chat() if running elsewhere).

python storage_reconstruction_test.py

This will strip attention/FFN weights from the loaded model and reconstruct them from the two files, then start a basic chat loop. Reconstruction is exact by construction, since mean + (original - mean) = original; see the note at the top of storage_reconstruction_test.py.
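To see why this split is lossless but saves nothing, consider the toy sketch below. Plain Python lists stand in for a weight tensor, and `split` and `reconstruct` are illustrative names, not functions from the repo. The residual has exactly as many elements as the original, and the reconstruction matches up to floating-point rounding.

```python
import math


def split(weights):
    """Split a flat weight list into (mean scalar, residual list)."""
    mean = sum(weights) / len(weights)
    return mean, [w - mean for w in weights]


def reconstruct(mean, residual):
    """Rebuild the weights as mean + residual."""
    return [mean + r for r in residual]


weights = [0.25, -1.5, 3.0, 0.125]
mean, residual = split(weights)
restored = reconstruct(mean, residual)

# Same element count as the original: splitting saved no storage.
same_size = len(residual) == len(weights)
# Equal up to floating-point rounding.
matches = all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(weights, restored))
```

This is why R2 = 1.0 in this test says nothing about compression quality: all of the information still lives in the residual file.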

3. Compression with Bayes

Shows a 100% score on the R2 test with compression enabled, while reducing size by approximately 30%.

python bayes_compression.py
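As far as the bayes_compression.py listing shows, the "Bayes" step is a byte histogram with add-one smoothing whose entropy decides whether a packet is worth passing to zlib. The sketch below reproduces that decision in pure Python under simplified assumptions: `smoothed_entropy` and `choose_codec` are illustrative names, and the constants copy the ones in the script.

```python
import math
import zlib
from collections import Counter

RAW_ENTROPY_THRESHOLD = 7.90     # bits per byte; at or above this, skip zlib
MIN_COMPRESS_BYTES = 256 * 1024  # packets smaller than this are stored raw
ZLIB_LEVEL = 1


def smoothed_entropy(raw: bytes) -> float:
    """Shannon entropy (bits/byte) of the byte histogram with add-one smoothing."""
    counts = Counter(raw)
    total = len(raw) + 256
    entropy = 0.0
    for value in range(256):
        p = (counts.get(value, 0) + 1) / total
        entropy -= p * math.log2(p)
    return entropy


def choose_codec(raw: bytes):
    """Return ("raw", raw) or ("zlib", compressed), whichever the gate selects."""
    if len(raw) < MIN_COMPRESS_BYTES or smoothed_entropy(raw) >= RAW_ENTROPY_THRESHOLD:
        return "raw", raw
    comp = zlib.compress(raw, ZLIB_LEVEL)
    return ("zlib", comp) if len(comp) < len(raw) else ("raw", raw)
```

The entropy here is an order-0 estimate, so it only sees how often each byte value occurs. A packet with a flat byte histogram is stored raw even if it contains long repeated runs that zlib could have exploited.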

Code

terminal_chat_bayesian.py

import torch
import numpy as np
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import os
import sys

MODEL_NAME       = "Qwen/Qwen2-0.5B"
MAX_NEW_TOKENS   = 200
TEMPERATURE      = 0.7
ANALYSIS_FILE    = "bayes_analysis.safetensors"
SYSTEM_PROMPT    = "You are a helpful assistant."
NUM_FEATURES     = 64
BAYES_EVERY_N    = 8      # compute bayes stats every N tokens instead of every token
BAYES_ENABLED    = True   # can be fully disabled with this flag


def _row_features_torch(x: torch.Tensor, n_features: int = NUM_FEATURES) -> torch.Tensor:
    x = x.float()
    L = x.shape[0]
    mean = x.mean()
    std = x.std(unbiased=False)
    abs_x = x.abs()

    feats = torch.zeros(n_features, dtype=torch.float32, device=x.device)
    feats[0] = mean
    feats[1] = std
    feats[2] = x.max()
    feats[3] = x.min()

    q = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75, 0.05, 0.10, 0.90, 0.95], device=x.device))
    feats[4], feats[5], feats[6] = q[0], q[1], q[2]
    feats[16], feats[17], feats[18], feats[19] = q[3], q[4], q[5], q[6]

    feats[7] = (x > mean + std).sum()
    feats[8] = (x < mean - std).sum()
    feats[9] = abs_x.mean()
    feats[10] = abs_x.median()

    w = 8
    if L >= w:
        wins = x.unfold(0, w, 1)
        feats[11] = wins.mean(dim=1).mean()
        feats[12] = wins.std(dim=1, unbiased=False).mean()
        feats[13] = wins.max(dim=1).values.mean()
        feats[14] = wins.min(dim=1).values.mean()
        feats[15] = x.diff().abs().mean()
    else:
        feats[11], feats[12], feats[13], feats[14], feats[15] = mean, std, x.max(), x.min(), 0.0

    if L > 1 and std > 1e-12:
        a, b = x[:-1], x[1:]
        a_c, b_c = a - a.mean(), b - b.mean()
        denom = torch.sqrt((a_c * a_c).sum() * (b_c * b_c).sum())
        feats[20] = (a_c * b_c).sum() / denom if denom > 1e-12 else 0.0
    else:
        feats[20] = 0.0

    return feats[:n_features]


class BayesData:
    def __init__(self, path: str = ANALYSIS_FILE):
        if not os.path.exists(path):
            print(f"[error] {path} not found. Run the analysis pass first to generate it.")
            sys.exit(1)
        print("[bayes-data] loading from safetensors ...")
        raw = load_file(path)
        self.layers = {}
        names = {k[: -len("__feat")] for k in raw.keys() if k.endswith("__feat")}
        for sk in names:
            param_name = sk.replace("__", ".")
            self.layers[param_name] = {"feat": raw[f"{sk}__feat"].float().numpy()}
        print(f"[bayes-data] loaded {len(self.layers)} layers")

    def get(self, param_name):
        return self.layers.get(param_name)

    def num_features_for(self, param_name) -> int:
        data = self.layers.get(param_name)
        return 1 if data is None else data["feat"].shape[1]


class BayesianFeatureSelector:
    def __init__(self, n_features: int, device):
        self.n_features = n_features
        self.marked_counts   = torch.ones(n_features, dtype=torch.float32, device=device)
        self.unmarked_counts = torch.ones(n_features, dtype=torch.float32, device=device)
        self.running_mean = torch.zeros(n_features, dtype=torch.float32, device=device)
        self.running_var  = torch.ones(n_features, dtype=torch.float32, device=device)
        self.n_seen = 0

    def select(self, feat_vector: torch.Tensor) -> torch.Tensor:
        if self.n_seen == 0:
            return torch.arange(self.n_features, device=feat_vector.device)
        std = torch.sqrt(self.running_var) + 1e-8
        deviation = (feat_vector - self.running_mean).abs() / std
        marked = torch.where(deviation > 1.0)[0]
        if marked.numel() == 0:
            priors = self.marked_counts / (self.marked_counts + self.unmarked_counts)
            marked = priors.argmax().unsqueeze(0)
        return marked

    def update(self, feat_vector: torch.Tensor, marked_idx: torch.Tensor):
        marked_mask = torch.zeros(self.n_features, dtype=torch.bool, device=feat_vector.device)
        marked_mask[marked_idx] = True
        self.marked_counts[marked_mask]    += 1
        self.unmarked_counts[~marked_mask] += 1

        self.n_seen += 1
        delta = feat_vector - self.running_mean
        self.running_mean += delta / self.n_seen
        delta2 = feat_vector - self.running_mean
        self.running_var += (delta * delta2 - self.running_var) / self.n_seen
        self.running_var.clamp_(min=1e-8)


class LayerBayesRegistry:
    def __init__(self, layer_names: list, n_features: int, device):
        self.selectors   = {name: BayesianFeatureSelector(n_features, device) for name in layer_names}
        self.layer_order = layer_names
        self.n_features  = n_features

    def select_for(self, layer_name: str, feat_vector: torch.Tensor) -> torch.Tensor:
        return self.selectors[layer_name].select(feat_vector)

    def observe(self, layer_name: str, feat_vector: torch.Tensor, marked_idx: torch.Tensor):
        self.selectors[layer_name].update(feat_vector, marked_idx)

    def state_summary(self) -> dict:
        out = {}
        for name, sel in self.selectors.items():
            if sel.n_seen == 0:
                out[name] = sel.n_features
            else:
                priors = sel.marked_counts / (sel.marked_counts + sel.unmarked_counts)
                out[name] = int((priors > 0.5).sum().item())
        return out


def build_bayes_registry(model, bayes_data, device) -> LayerBayesRegistry:
    layer_names = []
    n_features  = NUM_FEATURES
    for name, module in model.named_modules():
        param_name = f"{name}.weight"
        if bayes_data.get(param_name) is not None and hasattr(module, "weight"):
            layer_names.append(param_name)
            n_features = bayes_data.num_features_for(param_name)
    print(f"[registry] {len(layer_names)} layers, n_features={n_features}")
    return LayerBayesRegistry(layer_names, n_features, device)


def generate_with_bayes_scalar(model, tokenizer, history, bayes_data, bayes_registry):
    try:
        prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    except Exception:
        prompt = "\n".join(f"{m['role'].upper()}: {m['content']}" for m in history) + "\nASSISTANT:"

    inputs    = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]

    total_scalars_used = 0
    total_elements_all  = 0
    activation_store    = {}
    hooks                = []
    layer_meta = {}

    if BAYES_ENABLED:
        for name, module in model.named_modules():
            param_name = f"{name}.weight"
            data = bayes_data.get(param_name)
            if data is None or not hasattr(module, "weight"):
                continue
            w = module.weight
            dim_in = w.shape[1] if w.ndim >= 2 else w.shape[0]
            n_features = data["feat"].shape[1]
            layer_meta[param_name] = (dim_in, n_features)

            def make_hook(pn, di):
                def hook_fn(module, inp, out):
                    x_in = inp[0]
                    if x_in.ndim == 3:
                        x_t = x_in[0, -1, :]
                    elif x_in.ndim == 2:
                        x_t = x_in[0, :]
                    else:
                        return
                    if x_t.shape[0] == di:
                        activation_store[pn] = x_t.detach()
                return hook_fn

            hooks.append(module.register_forward_hook(make_hook(param_name, dim_in)))

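    # len(tokenizer) counts added special tokens (such as the EOS token), which
    # tokenizer.vocab_size may omit; slicing logits by the smaller value could
    # make EOS impossible to sample.
    vocab_size = min(len(tokenizer), model.config.vocab_size)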
    new_tokens_list = []
    step_counter = 0

    with torch.no_grad():
        past_key_values = None
        cur_input = input_ids

        for step in range(MAX_NEW_TOKENS):
            if step == 0:
                out = model(input_ids=cur_input, use_cache=True)
            else:
                out = model(input_ids=cur_input, past_key_values=past_key_values, use_cache=True)

            past_key_values = out.past_key_values
            logits = out.logits[:, -1, :vocab_size].float()

            torch.nan_to_num_(logits, nan=0.0, posinf=1e4, neginf=-1e4)
            logits.div_(max(TEMPERATURE, 1e-6))

            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs_sorted = torch.softmax(sorted_logits, dim=-1)
            cumprobs     = torch.cumsum(probs_sorted, dim=-1)
            mask = (cumprobs - probs_sorted) > 0.9
            sorted_logits[mask] = -1e9

            probs = torch.softmax(sorted_logits, dim=-1)
            probs.clamp_(min=0.0)
            s = probs.sum(dim=-1, keepdim=True)
            if not (s == 0).any():
                probs.div_(s)
            else:
                probs.fill_(1.0 / probs.shape[-1])

            next_sorted = torch.multinomial(probs, num_samples=1)
            next_token  = sorted_idx.gather(-1, next_sorted)
            next_id = next_token.item()
            new_tokens_list.append(next_id)

            if BAYES_ENABLED and (step_counter % BAYES_EVERY_N == 0) and activation_store:
                for param_name, x_t in activation_store.items():
                    dim_in, n_features = layer_meta[param_name]
                    feat_vector = _row_features_torch(x_t, n_features)
                    marked_idx  = bayes_registry.select_for(param_name, feat_vector)
                    total_scalars_used += marked_idx.numel()
                    total_elements_all += n_features
                    bayes_registry.observe(param_name, feat_vector, marked_idx)
            activation_store.clear()
            step_counter += 1

            if next_id == tokenizer.eos_token_id:
                break

            cur_input = next_token

    for h in hooks:
        h.remove()

    response_text = tokenizer.decode(new_tokens_list, skip_special_tokens=True)
    pct = 100.0 * total_scalars_used / total_elements_all if total_elements_all > 0 else 0.0
    return response_text, total_scalars_used, total_elements_all, pct


BANNER = """
+==================================================================+
|  Qwen2-0.5B x Bayesian Minimal Feature Selection                |
|  /stats  - stats for the last response                          |
|  /bayes  - state of the bayesian models (top 10 by k)           |
|  /clear  - clear history                                        |
|  /exit   - quit                                                 |
+==================================================================+
"""


def chat(model, tokenizer, bayes_data, bayes_registry):
    print(BANNER)
    history    = [{"role": "system", "content": SYSTEM_PROMPT}]
    last_stats = None

    while True:
        try:
            user = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting.")
            break

        if not user:
            continue
        if user == "/exit":
            break
        if user == "/clear":
            history = [{"role": "system", "content": SYSTEM_PROMPT}]
            print("[history cleared]")
            continue
        if user == "/stats":
            if last_stats:
                sc, el, pct = last_stats
                print(f"\n  Scalars flagged      : {sc:,}")
                print(f"  Total possible       : {el:,}")
                print(f"  Fraction flagged     : {pct:.4f}%\n")
            else:
                print("[no data yet - send a message first]")
            continue
        if user == "/bayes":
            summary = bayes_registry.state_summary()
            print("\n  [bayesian state - top 10 layers by k]")
            for name, k in sorted(summary.items(), key=lambda x: -x[1])[:10]:
                print(f"    {name:<55} k={k}")
            print()
            continue

        history.append({"role": "user", "content": user})
        t0 = time.time()

        resp, scalars_used, total_elements, pct = generate_with_bayes_scalar(
            model, tokenizer, history, bayes_data, bayes_registry
        )

        history.append({"role": "assistant", "content": resp})
        elapsed    = time.time() - t0
        last_stats = (scalars_used, total_elements, pct)

        print(f"\nModel ({elapsed:.1f}s): {resp}")
        print(f"\n  +- Bayesian minimal feature selection -----------------+")
        print(f"  |  Flagged      : {scalars_used:>15,}                  |")
        print(f"  |  Total        : {total_elements:>15,}                  |")
        print(f"  |  Fraction     : {pct:>14.4f} %                  |")
        print(f"  +--------------------------------------------------------+\n")


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("[error] CUDA not available. This script is configured for GPU.")
        sys.exit(1)

    device = "cuda"
    print(f"[start] device: {device}")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32       = True
    torch.backends.cudnn.benchmark        = True

    print(f"\n[1/3] Loading {MODEL_NAME} ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    model.eval()

    print("\n[2/3] Loading features from analysis file ...")
    bayes_data = BayesData()

    print("\n[3/3] Initializing bayesian feature selection registry ...")
    bayes_registry = build_bayes_registry(model, bayes_data, device)

    chat(model, tokenizer, bayes_data, bayes_registry)

storage_reconstruction_test.py

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import json
import os
from safetensors.torch import load_file

# ==========================================
# 1. STORAGE / LOADING TEST (NOT A COMPRESSION RESULT)
# ==========================================
# NOTE: reconstruction below is mean + residual, which is mathematically
# exact by construction (mean + (original - mean) = original).
# R2 = 1.0 is expected here and does not indicate compression -
# it indicates the two files together contain the same information
# as the original weight, just split across two files.

class FastBayesianStorage:
    """Weight storage split across two files, for testing a load pipeline"""
    def __init__(self):
        self.base_predictions = {}
        self.layer_residuals = {}
        self.layer_shapes = {}

    def decompress_layer(self, name):
        """Exact reconstruction: mean_val + residual = original (by construction)"""
        shape = self.layer_shapes[name]
        mean_val = self.base_predictions[name]
        residual = self.layer_residuals[name]

        reconstructed = np.full(residual.shape, mean_val, dtype=np.float32) + residual

        return torch.from_numpy(reconstructed).view(shape)

    def load_from_files(self, json_path="/content/bayesian_features.json", safetensors_path="/content/layer_residuals.safetensors"):
        """Loads scalar features from JSON and residual tensors from Safetensors"""
        print(f"\n[Import] Loading features and layer residuals from files...")

        # 1. Load metadata and scalar features
        with open(json_path, "r", encoding="utf-8") as f:
            json_data = json.load(f)

        self.base_predictions = json_data["base_predictions"]
        self.layer_shapes = json_data["layer_shapes"]
        print(f"  -> Scalar features and shapes loaded from: {json_path}")

        # 2. Load residual tensors (convert Torch -> NumPy for reconstruction)
        tensors_dict = load_file(safetensors_path)
        for name, tensor in tensors_dict.items():
            self.layer_residuals[name] = tensor.numpy()
        print(f"  -> Residual tensors loaded from: {safetensors_path}")


# ==========================================
# 2. BUILD HYBRID MODEL FROM SPLIT FILES
# ==========================================

def prepare_fast_hybrid_model(model_name="Qwen/Qwen2-0.5B"):
    start_time = time.time()
    print(f"Loading base model and tokenizer {model_name}...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32, device_map="cpu", low_cpu_mem_usage=True
    )

    target_layers = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
    compressed_layer_names = []

    print("\n[Process] Removing original attention/FFN weight tensors (preparing to load from files)...")
    for name, param in list(model.named_parameters()):
        if any(target in name for target in target_layers) and "weight" in name:
            compressed_layer_names.append(name)

            # remove original weight to simulate a clean storage state
            delattr(model.get_submodule(name.rsplit('.', 1)[0]), 'weight')

    print(f"\n[Done] Structure preparation time: {time.time() - start_time:.2f} sec.")
    return model, tokenizer, compressed_layer_names

# ==========================================
# 3. TERMINAL CHAT
# ==========================================

def run_fast_terminal_chat():
    # Paths to your prepared files
    json_path = "/content/bayesian_features.json"
    safetensors_path = "/content/layer_residuals.safetensors"

    # Build empty model structure
    model, tokenizer, compressed_names = prepare_fast_hybrid_model()

    # Initialize storage and load the already-prepared files (no overwrite)
    storage = FastBayesianStorage()
    storage.load_from_files(json_path=json_path, safetensors_path=safetensors_path)

    # Reconstruct weights from loaded files
    start_restore = time.time()
    print("\n[Info] Reconstructing weight tensors from loaded files...")
    for name in compressed_names:
        restored_tensor = storage.decompress_layer(name)
        submodule = model.get_submodule(name.rsplit('.', 1)[0])
        submodule.weight = torch.nn.Parameter(restored_tensor)
    print(f"[Done] All weights reconstructed (R2=1.0 by construction, see note above) in: {time.time() - start_restore:.2f} sec!")

    print("\n" + "="*50)
    print("  QWEN-0.5B CHAT - RECONSTRUCTED FROM SPLIT FILES")
    print("  Type 'exit' to quit.")
    print("="*50 + "\n")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ['exit', 'quit']:
            break

        if not user_input.strip():
            continue

        messages = [{"role": "user", "content": user_input}]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([text], return_tensors="pt")

        print("Qwen: ", end="", flush=True)
        generated_ids = model_inputs.input_ids

        with torch.no_grad():
            for _ in range(70):
                outputs = model(input_ids=generated_ids)
                next_token_logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                token_str = tokenizer.decode(next_token[0], skip_special_tokens=True)
                print(token_str, end="", flush=True)

                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                if next_token.item() in [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")]:
                    break
        print("\n" + "-"*50)

if __name__ == "__main__":
    run_fast_terminal_chat()

bayes_compression.py

from __future__ import annotations

import hashlib
import json
import os
import struct
import zlib
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import torch
from safetensors import safe_open
from safetensors.torch import save_file

MAGIC = "__BAYES_PACKET_ZLIB__"
VERSION = 4

DEFAULT_PACKET_MB = 8
RAW_ENTROPY_THRESHOLD = 7.90
MIN_COMPRESS_BYTES = 256 * 1024
ZLIB_LEVEL = 1


def _configure_torch() -> None:
    try:
        torch.set_num_threads(1)
    except Exception:
        pass
    try:
        torch.set_num_interop_threads(1)
    except Exception:
        pass


def dtype_to_name(dtype: torch.dtype) -> str:
    return str(dtype).replace("torch.", "")


def name_to_dtype(name: str) -> torch.dtype:
    return getattr(torch, name)


def tensor_to_raw_bytes(t: torch.Tensor) -> bytes:
    t = t.detach().contiguous().cpu()
    # Flatten first so 0-dim (scalar) tensors can also be reinterpreted as bytes.
    return t.reshape(-1).view(torch.uint8).numpy().tobytes()


def raw_bytes_to_tensor(raw: bytes, dtype: torch.dtype, shape: Sequence[int]) -> torch.Tensor:
    if not raw:
        return torch.empty(tuple(shape), dtype=dtype)
    u8 = torch.frombuffer(memoryview(raw), dtype=torch.uint8).clone()
    return u8.view(dtype).reshape(tuple(shape)).contiguous()


def _sha256(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()


def _packetize(raw: bytes, packet_size: int) -> List[bytes]:
    if packet_size <= 0:
        raise ValueError("packet_size must be positive")
    if not raw:
        return [b""]
    return [raw[i:i + packet_size] for i in range(0, len(raw), packet_size)]


def _bayes_features(raw: bytes) -> Dict[str, float]:
    if not raw:
        return {
            "n": 0,
            "mean": 0.0,
            "std": 0.0,
            "min": 0,
            "max": 0,
            "nonzero": 0,
            "entropy": 0.0,
            "top1_mass": 0.0,
            "hist_sha256": _sha256(b""),
        }

    u8 = torch.frombuffer(memoryview(raw), dtype=torch.uint8)
    n = int(u8.numel())

    counts = torch.bincount(u8.to(torch.int64), minlength=256).to(torch.float32)
    posterior = counts + 1.0
    total = float(posterior.sum().item())
    probs = posterior / total

    entropy = float((-(probs * torch.log2(probs.clamp_min(1e-12)))).sum().item())
    top1_mass = float((posterior.max() / total).item())

    f = u8.float()
    mean = float(f.mean().item())
    std = float(f.std(unbiased=False).item()) if n > 1 else 0.0
    mn = int(u8.min().item())
    mx = int(u8.max().item())
    nonzero = int((u8 != 0).sum().item())

    hist_sha256 = _sha256(counts.to(torch.int32).cpu().numpy().tobytes())

    return {
        "n": n,
        "mean": mean,
        "std": std,
        "min": mn,
        "max": mx,
        "nonzero": nonzero,
        "entropy": entropy,
        "top1_mass": top1_mass,
        "hist_sha256": hist_sha256,
    }


def _choose_codec(raw: bytes) -> Tuple[str, bytes, Dict[str, float]]:
    feats = _bayes_features(raw)

    if len(raw) < MIN_COMPRESS_BYTES or feats["entropy"] >= RAW_ENTROPY_THRESHOLD:
        return "raw", raw, feats

    comp = zlib.compress(raw, level=ZLIB_LEVEL)
    if len(comp) >= len(raw):
        return "raw", raw, feats
    return "zlib", comp, feats


def _write_record(out, meta: Dict, payload: bytes) -> None:
    meta_bytes = json.dumps(meta, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
    out.write(struct.pack(">I", len(meta_bytes)))
    out.write(meta_bytes)
    out.write(struct.pack(">I", len(payload)))
    out.write(payload)


def _read_exact(f, n: int) -> bytes:
    data = f.read(n)
    if len(data) != n:
        raise EOFError("Unexpected end of payload")
    return data


def _read_record(f):
    head = f.read(4)
    if not head:
        return None, None
    if len(head) != 4:
        raise EOFError("Corrupted record header")
    meta_len = struct.unpack(">I", head)[0]
    meta = json.loads(_read_exact(f, meta_len).decode("utf-8"))
    payload_len = struct.unpack(">I", _read_exact(f, 4))[0]
    payload = _read_exact(f, payload_len)
    return meta, payload


def _compress_shard_worker(args):
    shard_path, tensor_names, packet_size = args
    shard_name = Path(shard_path).name
    entries = []

    with safe_open(str(shard_path), framework="pt", device="cpu") as f:
        try:
            shard_metadata = f.metadata()
        except Exception:
            shard_metadata = None

        for tensor_name in tensor_names:
            tensor = f.get_tensor(tensor_name)
            raw = tensor_to_raw_bytes(tensor)
            packets = _packetize(raw, packet_size)

            tensor_entry = {
                "name": tensor_name,
                "dtype": dtype_to_name(tensor.dtype),
                "shape": list(tensor.shape),
                "raw_len": len(raw),
                "packet_count": len(packets),
            }

            for packet_index, packet_raw in enumerate(packets):
                codec, payload, feats = _choose_codec(packet_raw)
                packet_meta = {
                    "kind": "packet",
                    "shard_name": shard_name,
                    "tensor_name": tensor_name,
                    "dtype": dtype_to_name(tensor.dtype),
                    "shape": list(tensor.shape),
                    "codec": codec,
                    "packet_index": packet_index,
                    "packet_count": len(packets),
                    "packet_raw_len": len(packet_raw),
                    "sha256": _sha256(packet_raw),
                    "features": feats,
                    "is_last_packet": packet_index == len(packets) - 1,
                    "payload_len": len(payload),
                }
                entries.append((packet_meta, payload))

    return shard_name, shard_metadata, entries


def compress_qwen2_safetensors_fast(
    model_dir: str,
    output_bundle_dir: str,
    packet_mb: int = DEFAULT_PACKET_MB,
    max_workers: int | None = None,
) -> None:
    _configure_torch()

    model_dir = Path(model_dir)
    out_dir = Path(output_bundle_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    shard_files = sorted(model_dir.glob("*.safetensors"))
    if not shard_files:
        raise FileNotFoundError(f"No .safetensors files found in {model_dir}")

    packet_size = max(64 * 1024, packet_mb * 1024 * 1024)
    cpu_count = os.cpu_count() or 1
    if max_workers is None:
        max_workers = max(1, min(cpu_count, len(shard_files), 8))

    manifest = {
        "format": MAGIC,
        "version": VERSION,
        "source_model_dir": str(model_dir),
        "packet_size": packet_size,
        "compression": "zlib",
        "zlib_level": ZLIB_LEVEL,
        "files": [],
    }

    for aux_name in [
        "config.json",
        "generation_config.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "model.safetensors.index.json",
    ]:
        aux_path = model_dir / aux_name
        if aux_path.exists() and aux_path.is_file():
            manifest.setdefault("aux_files", [])
            manifest["aux_files"].append(
                {"name": aux_name, "text": aux_path.read_text(encoding="utf-8")}
            )

    jobs = []
    for shard_path in shard_files:
        with safe_open(str(shard_path), framework="pt", device="cpu") as f:
            tensor_names = list(f.keys())
        jobs.append((str(shard_path), tensor_names, packet_size))

    payload_path = out_dir / "payload.bin"

    if len(jobs) == 1:
        results = [_compress_shard_worker(jobs[0])]
    else:
        results = [None] * len(jobs)
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            future_map = {pool.submit(_compress_shard_worker, job): i for i, job in enumerate(jobs)}
            for fut in as_completed(future_map):
                idx = future_map[fut]
                results[idx] = fut.result()

    with open(payload_path, "wb") as payload_out:
        for shard_name, shard_meta, entries in results:
            shard_entry = {
                "name": shard_name,
                "metadata": shard_meta,
                "records": len(entries),
                "tensors": [],
            }

            tensor_map = {}
            for packet_meta, payload in entries:
                _write_record(payload_out, packet_meta, payload)

                tname = packet_meta["tensor_name"]
                if tname not in tensor_map:
                    tensor_map[tname] = {
                        "name": tname,
                        "dtype": packet_meta["dtype"],
                        "shape": packet_meta["shape"],
                        "raw_len": 0,
                        "packet_count": packet_meta["packet_count"],
                    }
                # Accumulate across packets so raw_len is the tensor's total byte length.
                tensor_map[tname]["raw_len"] += packet_meta["packet_raw_len"]

            shard_entry["tensors"] = list(tensor_map.values())
            manifest["files"].append(shard_entry)

    with open(out_dir / "manifest.json", "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, separators=(",", ":"))

    print("Compression finished.")
    print(f"Payload:  {payload_path}")
    print(f"Manifest: {out_dir / 'manifest.json'}")
    print(f"Packet:   {packet_size / (1024 * 1024):.1f} MB")


def _verify_packet(meta: Dict, raw: bytes) -> None:
    if len(raw) != int(meta["packet_raw_len"]):
        raise ValueError(
            f"Length mismatch for {meta.get('tensor_name')} packet {meta.get('packet_index')}"
        )

    if _sha256(raw) != meta["sha256"]:
        raise ValueError(
            f"SHA256 mismatch for {meta.get('tensor_name')} packet {meta.get('packet_index')}"
        )

    feats = _bayes_features(raw)
    exp = meta["features"]

    if feats["hist_sha256"] != exp["hist_sha256"]:
        raise ValueError(
            f"Histogram signature mismatch for {meta.get('tensor_name')} packet {meta.get('packet_index')}"
        )

    if feats["n"] != exp["n"]:
        raise ValueError(
            f"Feature length mismatch for {meta.get('tensor_name')} packet {meta.get('packet_index')}"
        )

    if int(feats["min"]) != int(exp["min"]) or int(feats["max"]) != int(exp["max"]):
        raise ValueError(
            f"Range feature mismatch for {meta.get('tensor_name')} packet {meta.get('packet_index')}"
        )


def decompress_qwen2_safetensors_fast(
    bundle_dir: str,
    restored_model_dir: str,
) -> None:
    _configure_torch()

    bundle_dir = Path(bundle_dir)
    restored_model_dir = Path(restored_model_dir)
    restored_model_dir.mkdir(parents=True, exist_ok=True)

    manifest_path = bundle_dir / "manifest.json"
    payload_path = bundle_dir / "payload.bin"

    if not manifest_path.exists():
        raise FileNotFoundError(f"Missing manifest.json: {manifest_path}")
    if not payload_path.exists():
        raise FileNotFoundError(f"Missing payload.bin: {payload_path}")

    with open(manifest_path, "r", encoding="utf-8") as f:
        manifest = json.load(f)

    for aux in manifest.get("aux_files", []):
        (restored_model_dir / aux["name"]).write_text(aux["text"], encoding="utf-8")

    shard_meta_map = {entry["name"]: entry.get("metadata") for entry in manifest["files"]}

    current_shard_name = None
    current_state_dict = {}
    current_tensor_parts: Dict[str, List[bytes]] = {}
    current_tensor_meta: Dict[str, Dict] = {}

    def flush_current_shard():
        nonlocal current_state_dict, current_tensor_parts, current_tensor_meta, current_shard_name
        if current_shard_name is None:
            return

        for tensor_name, parts in current_tensor_parts.items():
            meta = current_tensor_meta[tensor_name]
            raw = b"".join(parts)
            dtype = name_to_dtype(meta["dtype"])
            shape = tuple(meta["shape"])
            current_state_dict[tensor_name] = raw_bytes_to_tensor(raw, dtype=dtype, shape=shape)

        out_shard = restored_model_dir / current_shard_name
        save_file(current_state_dict, str(out_shard), metadata=shard_meta_map.get(current_shard_name))

        current_state_dict = {}
        current_tensor_parts = {}
        current_tensor_meta = {}
        current_shard_name = None

    with open(payload_path, "rb") as payload_in:
        while True:
            meta, payload = _read_record(payload_in)
            if meta is None:
                break

            codec = meta["codec"]
            if codec == "zlib":
                raw = zlib.decompress(payload)
            elif codec == "raw":
                raw = payload
            else:
                raise ValueError(f"Unknown codec: {codec}")

            _verify_packet(meta, raw)

            shard_name = meta["shard_name"]
            tensor_name = meta["tensor_name"]

            if current_shard_name is None:
                current_shard_name = shard_name
            elif current_shard_name != shard_name:
                flush_current_shard()
                current_shard_name = shard_name

            if tensor_name not in current_tensor_parts:
                current_tensor_parts[tensor_name] = []
                current_tensor_meta[tensor_name] = meta

            current_tensor_parts[tensor_name].append(raw)

    flush_current_shard()
    print(f"Restored to: {restored_model_dir}")


if __name__ == "__main__":
    source_model_dir = "/content/Qwen2-0.5B"
    bundle_dir = "qwen2_0_5b_bayes_zlib_bundle"
    restored_dir = "qwen2_0_5b_restored"

    compress_qwen2_safetensors_fast(
        source_model_dir,
        bundle_dir,
        packet_mb=8,
    )
    decompress_qwen2_safetensors_fast(bundle_dir, restored_dir)
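
The per-packet integrity check used during decompression can be tried in isolation. The following is a minimal sketch, not the script's actual worker: `make_packet` and `open_packet` are hypothetical helper names, the default zlib level of 6 is an assumption (the script uses its own `ZLIB_LEVEL`), and the rule "store raw bytes when zlib does not shrink the packet" is an assumed policy suggested by the `"raw"` codec branch above. The histogram features that `_verify_packet` also checks are left out.

```python
import hashlib
import zlib
from typing import Dict, Tuple


def make_packet(raw: bytes, level: int = 6) -> Tuple[Dict, bytes]:
    """Compress one packet and record the fields needed to verify it later."""
    compressed = zlib.compress(raw, level)
    # Assumed policy: keep the raw bytes when zlib does not make them smaller.
    if len(compressed) < len(raw):
        codec, payload = "zlib", compressed
    else:
        codec, payload = "raw", raw
    meta = {
        "codec": codec,
        "packet_raw_len": len(raw),
        "sha256": hashlib.sha256(raw).hexdigest(),
    }
    return meta, payload


def open_packet(meta: Dict, payload: bytes) -> bytes:
    """Decode one packet and raise ValueError if length or hash do not match."""
    if meta["codec"] == "zlib":
        raw = zlib.decompress(payload)
    elif meta["codec"] == "raw":
        raw = payload
    else:
        raise ValueError(f"Unknown codec: {meta['codec']}")

    if len(raw) != int(meta["packet_raw_len"]):
        raise ValueError("Length mismatch")
    if hashlib.sha256(raw).hexdigest() != meta["sha256"]:
        raise ValueError("SHA256 mismatch")
    return raw
```

Because the hash is taken over the uncompressed bytes, the same check applies to both codecs, and a corrupted payload or a tampered manifest entry fails before any tensor is rebuilt.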

Results

Placeholder: to be filled in with real numbers from benchmark runs.

  • Fraction of features flagged per layer, averaged across a test set of prompts
  • How the flagged fraction changes over the course of a generation (early tokens vs. late tokens)
  • Per-layer comparison: which layers have consistently high vs. low flagged fractions
  • Any correlation (or lack of one) between flagged fraction and output quality; this is the test that would actually justify calling the flagged subset "informative"
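
The first three measurements in the list above could be computed along these lines. This is a sketch under assumptions: it supposes the selector's output is available as `flags[step][layer]`, a list of booleans with one entry per tracked feature, and both function names are hypothetical. The repo may store its flags differently.

```python
from statistics import mean
from typing import List, Tuple

# Assumed layout: flags[step][layer] -> one boolean per tracked feature.
Flags = List[List[List[bool]]]


def flagged_fraction_per_layer(flags: Flags) -> List[float]:
    """Average fraction of flagged features for each layer over all steps."""
    n_layers = len(flags[0])
    return [
        mean(sum(step[layer]) / len(step[layer]) for step in flags)
        for layer in range(n_layers)
    ]


def early_vs_late(flags: Flags, layer: int) -> Tuple[float, float]:
    """Flagged fraction for one layer in the first and second half of a generation."""
    half = len(flags) // 2  # needs at least two steps

    def fraction(steps: Flags) -> float:
        return mean(sum(s[layer]) / len(s[layer]) for s in steps)

    return fraction(flags[:half]), fraction(flags[half:])
```

Averaging the per-layer values over a set of prompts gives the first bullet, and comparing the two halves per layer gives a coarse view of the second.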

Open questions / next steps

  • Does the flagged feature subset stay stable across different prompts, or does it change drastically session to session?
  • Is there a relationship between which features get flagged and attention patterns in the same layer?
  • Right now BAYES_EVERY_N = 8 and the deviation threshold (> 1.0 std) are picked without tuning; sweeping these would show whether the flagged fraction is a real signal or just a threshold artifact.
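
A threshold sweep of the kind mentioned in the last item might look like the sketch below. It assumes each feature's deviation is available as a z-score (deviation in units of its running standard deviation), stored as `z_scores[step][feature]`; that layout and both function names are hypothetical. The Gaussian baseline is only a reference point: it is the fraction that would exceed the threshold if the z-scores were standard normal noise.

```python
import math
from typing import Dict, List


def sweep_thresholds(z_scores: List[List[float]], thresholds: List[float]) -> Dict[float, float]:
    """Fraction of (step, feature) entries with |z| above each threshold."""
    total = sum(len(step) for step in z_scores)
    return {
        t: sum(abs(z) > t for step in z_scores for z in step) / total
        for t in thresholds
    }


def gaussian_baseline(threshold: float) -> float:
    """P(|z| > threshold) for a standard normal variable."""
    return math.erfc(threshold / math.sqrt(2))
```

If the observed flagged fraction tracks `gaussian_baseline` closely at every threshold, the flags are hard to distinguish from a threshold artifact; a clear and consistent gap would be more interesting.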