"""Local Hugging Face completion playground (Gradio).""" from __future__ import annotations import html import os import threading from typing import Any import gradio as gr import torch from gradio.themes.utils import fonts from gradio.themes.utils.colors import Color from transformers import AutoModelForCausalLM, AutoTokenizer from completion_html import build_completion_html _UZH_BLUE = Color( name="uzh_blue", c50="#BDC9E8", c100="#BDC9E8", c200="#9DADEE", c300="#7596FF", c400="#3062FF", c500="#0028A5", c600="#001E7C", c700="#001452", c800="#001452", c900="#000A28", c950="#000000", ) _UZH_CYAN = Color( name="uzh_cyan", c50="#DBF4F9", c100="#DBF4F9", c200="#B7E9F4", c300="#92DFEE", c400="#4AC9E3", c500="#1EA7C4", c600="#147082", c700="#147082", c800="#0E5A66", c900="#0A3D44", c950="#05282C", ) _UZH_GREY = Color( name="uzh_grey", c50="#FAFAFA", c100="#EFEFEF", c200="#E7E7E7", c300="#E0E0E0", c400="#C2C2C2", c500="#A3A3A3", c600="#666666", c700="#4D4D4D", c800="#333333", c900="#1A1A1A", c950="#000000", ) UZH_THEME = gr.themes.Default( primary_hue=_UZH_BLUE, secondary_hue=_UZH_CYAN, neutral_hue=_UZH_GREY, font=( fonts.GoogleFont("Source Sans 3"), "ui-sans-serif", "system-ui", "sans-serif", ), font_mono=("ui-monospace", "Menlo", "Consolas", "monospace"), ) # Stronger text contrast than default neutral greys (avoid overly light labels / hints) UZH_THEME.set( body_text_color="#1A1A1A", body_text_color_subdued="#4D4D4D", block_label_text_color="#333333", block_info_text_color="#4D4D4D", block_title_text_color="#1A1A1A", input_placeholder_color="#666666", body_text_color_dark="#F0F0F0", body_text_color_subdued_dark="#D0D0D0", block_label_text_color_dark="#EFEFEF", block_info_text_color_dark="#C2C2C2", block_title_text_color_dark="#FAFAFA", input_placeholder_color_dark="#A3A3A3", ) UZH_APP_CSS = """ /* Links: UZH Blue / Blue 3 */ .gradio-container a { color: #0028A5; } .gradio-container a:hover { color: #3062FF; } """ DEFAULT_MODEL_ID = "HuggingFaceTB/SmolLM-135M" 

# Module-level cache: a single model/tokenizer pair, guarded by a lock while loading.
_model_lock = threading.Lock()
_loaded_model_id: str | None = None
_model: Any = None
_tokenizer: Any = None


def select_device() -> torch.device:
    """Pick the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_model(model_id: str) -> None:
    """Load and cache the model and tokenizer; does nothing if model_id is already loaded."""
    global _loaded_model_id, _model, _tokenizer
    with _model_lock:
        if _loaded_model_id == model_id and _model is not None and _tokenizer is not None:
            return
        device = select_device()
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Fall back to EOS as the padding token so generate() gets a valid pad_token_id.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)
        model.eval()
        _tokenizer = tokenizer
        _model = model
        _loaded_model_id = model_id


def _token_piece_text(tokenizer: Any, token_id: int) -> str:
    """Decode a single token id to its text piece, keeping special tokens visible."""
    return tokenizer.decode([token_id], skip_special_tokens=False)


def generate_completion_with_metadata(
    prompt_text: str,
    model_id: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
) -> tuple[str, list[str], list[float], list[list[dict[str, Any]]], list[bool]]:
    """Generate a completion and collect per-token metadata.

    Returns the completion text, the text of each generated token, the
    probability of each chosen token, the top-5 alternatives per step, and
    a flag per step telling whether the chosen token is among those top 5.
    """
    load_model(model_id)
    assert _model is not None and _tokenizer is not None
    device = next(_model.parameters()).device
    tokenizer = _tokenizer
    model = _model

    encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=True)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # A temperature of 0 means greedy decoding; sampling options are only passed when sampling.
    temperature_value = float(temperature)
    use_sampling = temperature_value > 0.0
    generate_kwargs: dict[str, Any] = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": use_sampling,
        "return_dict_in_generate": True,
        "output_scores": True,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if use_sampling:
        generate_kwargs["temperature"] = temperature_value
        generate_kwargs["top_p"] = float(top_p)
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask

    with torch.inference_mode():
        outputs = model.generate(input_ids, **generate_kwargs)

    sequences = outputs.sequences
    scores = outputs.scores
    if scores is None:
        raise RuntimeError("Generation did not return scores; check model.generate arguments.")

    # Drop the prompt tokens; one score tensor is expected per generated token.
    prompt_length = input_ids.shape[1]
    generated_ids = sequences[0, prompt_length:]
    generated_list = generated_ids.tolist()
    if len(generated_list) != len(scores):
        raise RuntimeError(
            f"Score count ({len(scores)}) does not match generated tokens ({len(generated_list)})."
        )

    token_strings: list[str] = []
    chosen_probabilities: list[float] = []
    top5_alternatives: list[list[dict[str, Any]]] = []
    chosen_in_top5_flags: list[bool] = []

    for step_index, token_id in enumerate(generated_list):
        # scores[step] holds the processed logits (after temperature / top-p when sampling),
        # so these probabilities describe the distribution that was actually sampled from.
        logits = scores[step_index][0]
        probabilities = torch.softmax(logits.float(), dim=-1)
        chosen_probability = float(probabilities[token_id].item())

        top_k = min(5, probabilities.shape[-1])
        top_values, top_indices = torch.topk(probabilities, top_k)
        top_token_ids = top_indices.tolist()
        chosen_in_top5 = token_id in top_token_ids

        alternatives: list[dict[str, Any]] = [
            {
                "token_text": _token_piece_text(tokenizer, alternative_id),
                "probability": alternative_probability,
            }
            for alternative_id, alternative_probability in zip(top_token_ids, top_values.tolist())
        ]

        token_strings.append(_token_piece_text(tokenizer, token_id))
        chosen_probabilities.append(chosen_probability)
        top5_alternatives.append(alternatives)
        chosen_in_top5_flags.append(chosen_in_top5)

    completion_text = tokenizer.decode(generated_list, skip_special_tokens=True)
    return completion_text, token_strings, chosen_probabilities, top5_alternatives, chosen_in_top5_flags


def run_generate(
    user_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    original_text: str,
    completion_text: str,
    has_completion: bool,
) -> tuple[Any, ...]:
    del original_text, completion_text, has_completion
    try:
        (
            _completion_full,
            token_strings,
            chosen_probabilities,
            top5_alternatives,
            chosen_in_top5_flags,
        ) = generate_completion_with_metadata(
            user_prompt,
            DEFAULT_MODEL_ID,
            temperature,
            max_new_tokens,
            top_p,
        )
    except Exception as error:
        gr.Warning(f"Generation failed: {error}")
        safe_message = html.escape(str(error))
        return (
            gr.update(
                value=f'