Files
voice.sori.studio/backend/app/main.py
zenn 7101fdcd65 Initial commit: Korean voice-cloning TTS prototype
FastAPI backend, web UI, CosyVoice3/F5-TTS setup scripts, and handoff docs for GPU PC continuation.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-04 13:36:37 +09:00

171 lines
4.8 KiB
Python

from __future__ import annotations

import asyncio
import shutil
import threading
import uuid
from pathlib import Path

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from backend.app.config import get_settings, project_root
from backend.app.text_preprocess import preprocess_korean
from backend.app.tts.service import TTSService
# Project root as reported by the config module; the static web UI is
# expected under ROOT/web (mounted further below only if it exists).
ROOT = project_root()
WEB_DIR = ROOT.joinpath("web")
# ASGI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="0.1.0",
    title="Korean Voice Cloning TTS",
    description="CosyVoice / F5-TTS 기반 한국어 보이스 클로닝 API",
)

# Fully open CORS: every origin, method and header is allowed.
# NOTE(review): with endpoints that accept server-side file paths, this lets
# any web origin call the API and read its responses — consider restricting
# allow_origins to the web UI's origin before exposing the server.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Lazily created, process-wide TTS service shared by all requests.
_tts: TTSService | None = None
# Guards the first construction of _tts. Sync endpoints such as create_tts run
# in FastAPI's threadpool, so several requests can reach get_tts() at once.
_tts_lock = threading.Lock()


def get_tts() -> TTSService:
    """Return the shared TTSService, constructing it on first use.

    The fast path skips the lock once the service exists. The second check
    inside the lock ensures only one thread ever constructs TTSService, even
    when the first requests arrive concurrently.

    Any exception raised by ``TTSService()`` propagates to the caller, and
    ``_tts`` stays ``None``, so a later call retries construction.
    """
    global _tts
    if _tts is None:
        with _tts_lock:
            if _tts is None:
                _tts = TTSService()
    return _tts
class TTSRequest(BaseModel):
    # Request body for POST /api/tts.

    # Text to synthesize; length 1-5000 is checked on the raw input,
    # before any preprocessing.
    text: str = Field(..., min_length=1, max_length=5000)
    # Reference voice clip. create_tts tries relative paths under samples_dir,
    # then uploads_dir; absolute paths are used as given.
    ref_audio: str | None = Field(
        default=None,
        description="samples/ 또는 uploads/ 기준 상대/절대 경로",
    )
    # Optional transcript of the reference clip, passed through to synthesis.
    ref_text: str | None = Field(default=None)
    # True: run preprocess_korean on the text; False: only strip whitespace.
    preprocess: bool = Field(default=True)
class TTSResponse(BaseModel):
    # Response body for POST /api/tts.
    job_id: str  # id returned by TTSService.synthesize_to_file
    audio_url: str  # relative URL of the WAV, served by GET /api/audio/{job_id}
    model: str  # configured TTS model name (settings.tts_model)
    text_preview: str  # first 80 characters of the text actually synthesized
class HealthResponse(BaseModel):
    # Response body for GET /api/health.
    status: str  # set to "ok" by the health endpoint
    model: str  # configured TTS model name (settings.tts_model)
    samples_count: int  # number of *.wav files directly inside samples_dir
@app.get("/api/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report liveness, the configured model and how many sample WAVs exist.

    Only ``*.wav`` files directly inside ``samples_dir`` are counted; uploads
    are not included.
    """
    settings = get_settings()
    wav_count = sum(1 for _ in settings.samples_dir.glob("*.wav"))
    return HealthResponse(
        status="ok",
        model=settings.tts_model,
        samples_count=wav_count,
    )
def _resolve_ref_audio(ref_audio: str) -> Path:
    """Map a client-supplied ``ref_audio`` string to an existing file.

    Absolute paths are used unchanged. A relative path is tried under
    ``samples_dir`` first, then ``uploads_dir``. If neither contains it, the
    path is used as given, i.e. relative to the server's working directory.

    Raises:
        HTTPException: 404 when the resolved path is not an existing file.
    """
    # NOTE(review): absolute and working-directory-relative paths are accepted,
    # so any readable file on the server can be named here. list_voice_samples
    # and upload_voice_sample hand out absolute paths, so restricting this
    # needs a coordinated change — confirm the intended exposure.
    settings = get_settings()
    path = Path(ref_audio)
    if not path.is_absolute():
        for base in (settings.samples_dir, settings.uploads_dir):
            candidate = base / path
            if candidate.is_file():
                path = candidate
                break
    if not path.is_file():
        raise HTTPException(404, f"ref_audio not found: {ref_audio}")
    return path


@app.post("/api/tts", response_model=TTSResponse)
def create_tts(body: TTSRequest) -> TTSResponse:
    """Synthesize ``body.text`` and return where to fetch the resulting WAV.

    The text is passed through ``preprocess_korean`` when ``body.preprocess``
    is true; otherwise it is only stripped.

    Raises:
        HTTPException: 400 if the text is empty after preprocessing;
            404 if ``ref_audio`` cannot be found or synthesis raises
            FileNotFoundError; 503 if synthesis (or service construction)
            raises RuntimeError.
    """
    text = preprocess_korean(body.text) if body.preprocess else body.text.strip()
    if not text:
        raise HTTPException(400, "text is empty")

    ref_path = _resolve_ref_audio(body.ref_audio) if body.ref_audio else None

    try:
        job_id, _ = get_tts().synthesize_to_file(
            text, ref_audio=ref_path, ref_text=body.ref_text
        )
    except FileNotFoundError as e:
        raise HTTPException(404, str(e)) from e
    except RuntimeError as e:
        raise HTTPException(503, str(e)) from e

    return TTSResponse(
        job_id=job_id,
        audio_url=f"/api/audio/{job_id}",
        model=get_settings().tts_model,
        # Append an ellipsis when truncating so clients can tell the preview
        # is shorter than the synthesized text (previously both branches
        # were "", so truncation was never marked).
        text_preview=text[:80] + ("…" if len(text) > 80 else ""),
    )
@app.get("/api/audio/{job_id}")
def get_audio(job_id: str) -> FileResponse:
    """Serve the WAV produced for ``job_id``.

    The handler looks for ``outputs_dir/<job_id>/output.wav`` first, then
    falls back to ``part_000.wav`` in the same directory.

    Raises:
        HTTPException: 404 if ``job_id`` is not a plain single path component
            or neither file exists.
    """
    # job_id comes straight from the URL ("..", or its %2E%2E encoding, can
    # reach here). Only accept a single, ordinary path component so the
    # lookup cannot leave outputs_dir.
    if job_id in ("", ".", "..") or Path(job_id).name != job_id:
        raise HTTPException(404, "audio not found")

    job_dir = get_settings().outputs_dir / job_id
    for candidate in (job_dir / "output.wav", job_dir / "part_000.wav"):
        if candidate.is_file():
            return FileResponse(
                candidate, media_type="audio/wav", filename=f"{job_id}.wav"
            )
    raise HTTPException(404, "audio not found")
@app.get("/api/voice-samples")
def list_voice_samples() -> dict:
    """List reference WAVs from samples_dir, then uploads_dir.

    Each entry carries the file stem as ``id``, its absolute-or-configured
    path as ``path``, which directory it came from as ``label``, and whether
    a sibling ``.txt`` transcript exists. Files within each directory are
    sorted by path.
    """
    settings = get_settings()
    sources = (
        (settings.samples_dir, "samples"),
        (settings.uploads_dir, "uploads"),
    )
    entries = [
        {
            "id": wav_path.stem,
            "path": str(wav_path),
            "label": origin,
            "has_transcript": wav_path.with_suffix(".txt").is_file(),
        }
        for directory, origin in sources
        for wav_path in sorted(directory.glob("*.wav"))
    ]
    return {"samples": entries, "default_model": settings.tts_model}
def _copy_upload(upload: UploadFile, dest: Path) -> None:
    """Copy the upload's underlying file object into ``dest`` (blocking I/O)."""
    with open(dest, "wb") as out:
        shutil.copyfileobj(upload.file, out)


@app.post("/api/voice-sample")
async def upload_voice_sample(
    file: UploadFile = File(...),
    ref_text: str = Form(""),
) -> dict:
    """Store an uploaded WAV in uploads_dir under a random 10-hex-char id.

    A non-blank ``ref_text`` is saved beside it as ``<id>.txt`` (stripped,
    UTF-8). The response includes the stored path, which clients can pass as
    ``ref_audio`` in a TTS request.

    Raises:
        HTTPException: 400 if the filename does not end in ``.wav``.
    """
    # NOTE(review): only the filename extension is checked; the content is
    # not validated as WAV and there is no size limit.
    if not file.filename or not file.filename.lower().endswith(".wav"):
        raise HTTPException(400, "WAV 파일만 업로드 가능합니다")

    sample_id = uuid.uuid4().hex[:10]
    dest = get_settings().uploads_dir / f"{sample_id}.wav"

    # The copy is blocking disk I/O. Run it in a worker thread so this
    # coroutine does not stall the event loop for other requests.
    try:
        await asyncio.to_thread(_copy_upload, file, dest)
    except Exception:
        # Don't leave a truncated WAV behind; it would show up in listings.
        dest.unlink(missing_ok=True)
        raise

    transcript = ref_text.strip()
    if transcript:
        dest.with_suffix(".txt").write_text(transcript, encoding="utf-8")

    return {
        "id": sample_id,
        "path": str(dest),
        "message": "업로드 완료. TTS 요청 시 ref_audio에 이 path를 사용하세요.",
    }
# Serve the static web UI at "/" (html=True serves index.html for directory
# requests), but only when the web/ directory is present.
# This mount matches every path, so it must stay after all /api routes —
# routes registered after it would be shadowed.
if WEB_DIR.is_dir():
    _web_ui = StaticFiles(directory=str(WEB_DIR), html=True)
    app.mount("/", _web_ui, name="web")
@app.on_event("startup")
def startup() -> None:
    """Create the writable output and upload directories before serving.

    samples_dir is not created here.
    """
    # NOTE(review): newer FastAPI versions deprecate on_event in favour of
    # lifespan handlers; consider migrating.
    settings = get_settings()
    for directory in (settings.outputs_dir, settings.uploads_dir):
        directory.mkdir(parents=True, exist_ok=True)