Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| import os | |
| from miditoolkit import MidiFile | |
| from src.model.generate import batch_performance_render, map_midi | |
| from src.model.pianoformer import PianoT5Gemma | |
# ------------------------------
# Load model
# ------------------------------
def load_model():
    """Build the rendering model and put it in evaluation mode.

    Loads ``PianoT5Gemma`` from the Hub repo
    ``yhj137/pianist-transformer-rendering`` with bfloat16 weights. The auth
    token is read from the ``hf_token`` environment variable (``None`` when
    unset).

    Returns:
        The ``PianoT5Gemma`` instance after ``.eval()`` has been called on it.
    """
    print("Loading model...")
    repo_id = "yhj137/pianist-transformer-rendering"
    access_token = os.environ.get("hf_token")
    pianist = PianoT5Gemma.from_pretrained(
        repo_id,
        token=access_token,
        torch_dtype=torch.bfloat16,
    )
    pianist.eval()
    return pianist


# Loaded once at import time; render_midi reads this module-level global.
model = load_model()
# ------------------------------
# Define inference function
# ------------------------------
def render_midi(midi_file, temperature, top_p):
    """Render an uploaded score MIDI into an expressive performance MIDI.

    Args:
        midi_file: Value Gradio passes for the ``gr.File`` input. Depending on
            the Gradio version this is a filepath string or a tempfile-like
            object exposing the path as ``.name``. It is ``None`` when nothing
            was uploaded.
        temperature: Sampling temperature forwarded to ``batch_performance_render``.
        top_p: Top-p threshold forwarded to ``batch_performance_render``.

    Returns:
        A two-element list ``[raw_path, editable_path]``:
        - ``raw_path`` is the model output as written by ``res[0].dump``.
        - ``editable_path`` is the ``map_midi`` result aligned to the input score.
        If ``map_midi`` fails, ``editable_path`` is ``None`` and a warning is
        shown, so only the raw render is offered for download.

    Raises:
        gr.Error: No file was uploaded, or loading/inference/saving the raw
            render failed. Gradio displays the message to the user.
    """
    import tempfile

    if midi_file is None:
        raise gr.Error("Please upload a MIDI file.")

    # Gradio 4+ passes a path string for gr.File(type="filepath"); Gradio 3
    # passed a tempfile wrapper whose `.name` holds the path. Accept both.
    if isinstance(midi_file, (str, os.PathLike)):
        input_path = os.fspath(midi_file)
    else:
        input_path = midi_file.name

    try:
        midi = MidiFile(input_path)
        # Run inference
        res = batch_performance_render(
            model,
            [midi],
            temperature=temperature,
            top_p=top_p,
            device="cpu",  # change to "cuda" if GPU available
        )
        # Use a fresh directory per request so overlapping requests cannot
        # overwrite each other's output files. The file names themselves are
        # unchanged, so downloads keep the same names.
        out_dir = tempfile.mkdtemp(prefix="pianist_render_")
        # Save raw (unmapped) result
        raw_out_path = os.path.join(out_dir, "raw_render.mid")
        res[0].dump(raw_out_path)
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        raise gr.Error(f"Inference failed: {e}") from e

    # Try to create editable (mapped) version. Failure here is non-fatal:
    # the raw render is still returned.
    editable_out_path = os.path.join(out_dir, "editable_render.mid")
    try:
        mapped = map_midi(midi, res[0])
        mapped.dump(editable_out_path)
    except Exception as e:
        print(f"[Warning] map_midi failed: {e}")
        # A gr.File output must receive a real path or None. Returning an
        # error string there would make Gradio fail to postprocess the whole
        # response, so report the problem as a toast instead.
        gr.Warning(f"map_midi failed: {e}")
        return [raw_out_path, None]
    return [raw_out_path, editable_out_path]
# ------------------------------
# Build Gradio interface
# ------------------------------
def _build_demo():
    """Assemble the Gradio ``Interface`` that wraps ``render_midi``.

    Inputs: the uploaded score MIDI file, then the temperature and top-p
    sliders (in the order ``render_midi`` expects its arguments).
    Outputs: two downloadable files, the raw render and the editable render.
    """
    score_upload = gr.File(
        label="Upload a Score MIDI File (.mid or .midi)",
        file_types=[".mid", ".midi", ".MID", ".MIDI"],
    )
    temperature_control = gr.Slider(0.1, 2.0, value=1.0, step=0.01, label="Temperature")
    top_p_control = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")

    raw_download = gr.File(label="Raw Performance")
    editable_download = gr.File(label="Editable Version")

    # Markdown paragraphs shown under the title, separated by blank lines.
    paragraphs = [
        "Upload a piano score MIDI file and let the Pianist Transformer render it into "
        "a more expressive performance MIDI.",
        "Two versions will be saved:",
        "• **Raw Performance** – directly generated by the model",
        "• **Editable Version** – aligned with the score using our Expressive Tempo Mapping algorithm",
        "If mapping fails, only the raw version will be returned with an error message.",
        "⚠️ **This is only a demo running on limited compute resources. Please do not upload long pieces — "
        "we recommend clips shorter than 1 minute.**",
    ]

    return gr.Interface(
        fn=render_midi,
        inputs=[score_upload, temperature_control, top_p_control],
        outputs=[raw_download, editable_download],
        title="🎹 Pianist Transformer Rendering",
        description="\n\n".join(paragraphs),
    )


demo = _build_demo()
def main():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    main()