onnx 加载本地模型 tts基础功能完成

This commit is contained in:
yangyakun
2026-04-09 18:01:08 +08:00
parent e4bd908714
commit 63ac5bb7e2
29 changed files with 783 additions and 88 deletions

View File

@@ -5,7 +5,7 @@ ext {
compileSdkVersion: 34, compileSdkVersion: 34,
buildToolsVersion: "29.0.2", buildToolsVersion: "29.0.2",
minSdkVersion : 19, minSdkVersion : 21,
targetSdkVersion : 29, targetSdkVersion : 29,
] ]
dependencies = [ dependencies = [

View File

@@ -68,6 +68,7 @@ dependencies {
implementation(project(":libraries:mogo-tts:tts-iflytek-offline")) implementation(project(":libraries:mogo-tts:tts-iflytek-offline"))
implementation(project(":libraries:mogo-tts:tts-base")) implementation(project(":libraries:mogo-tts:tts-base"))
implementation(project(":libraries:mogo-tts:tts-mogo"))
// ======================集成 Bugly================================================================ // ======================集成 Bugly================================================================
implementation 'com.tencent.bugly:crashreport:4.1.9.2' implementation 'com.tencent.bugly:crashreport:4.1.9.2'

View File

@@ -9,7 +9,6 @@ import com.blankj.utilcode.util.PathUtils
import com.blankj.utilcode.util.TimeUtils import com.blankj.utilcode.util.TimeUtils
import com.bytedance.boost_multidex.BoostMultiDex import com.bytedance.boost_multidex.BoostMultiDex
import com.mogo.tts.common.TtsManager import com.mogo.tts.common.TtsManager
import com.mogo.tts.iflytek.offline.IFlyTekOfflineTts
import com.mogo.tts.utils.PhoneUtilsExtend import com.mogo.tts.utils.PhoneUtilsExtend
import com.tencent.bugly.crashreport.CrashReport import com.tencent.bugly.crashreport.CrashReport
import com.tencent.bugly.crashreport.CrashReport.CrashHandleCallback import com.tencent.bugly.crashreport.CrashReport.CrashHandleCallback

View File

@@ -14,7 +14,8 @@ object TtsManager {
// 暂时换成反射,解决死锁问题 // 暂时换成反射,解决死锁问题
var clazz1: Class<*>? = null var clazz1: Class<*>? = null
try { try {
clazz1 = Class.forName("com.mogo.tts.iflytek.offline.IFlyTekOfflineTts") //clazz1 = Class.forName("com.mogo.tts.iflytek.offline.IFlyTekOfflineTts")
clazz1 = Class.forName("com.k2fsa.sherpa.onnx.MogoOfflineTTS")
} catch (ignored: Exception) { } catch (ignored: Exception) {
} }

View File

@@ -0,0 +1,103 @@
package com.mogo.tts.common.impl;
import android.content.Context;
import android.os.Looper;
import com.elegant.utils.UiThreadHandler;
import com.mogo.tts.common.IGlobalTtsCallback;
import com.mogo.tts.common.IMogoTTS;
import com.mogo.tts.common.IMogoTTSCallback;
import com.mogo.tts.common.LangTtsEntity;
import com.mogo.tts.common.LanguageType;
import com.mogo.tts.common.log.TtsLogManager;
import java.util.HashMap;
/**
 * Common base for TTS engine implementations.
 *
 * <p>Holds the state shared by the engines (context, current entity/content, the
 * per-utterance callback map and the global callback) and makes sure that speak and
 * stop requests are executed on the main thread. Subclasses override
 * {@link #speakMultiLangTTS(LangTtsEntity)} and {@link #realStop()} to drive the
 * actual engine.
 */
public abstract class BaseMogoTTS implements IMogoTTS {
    protected Context context;
    // An explicit stop produces no engine callback, so the interrupted text and its
    // callback are removed from the map at stop time (see realStop()).
    protected volatile String curTtsContent = "";
    protected volatile LangTtsEntity curTtsEntity = null;
    // Keyed by LangTtsEntity.toString(); only touched from the main thread in this class.
    protected HashMap<String, IMogoTTSCallback> speakVoiceMap = new HashMap<>();
    protected IGlobalTtsCallback mGlobalTtsCallback = null;

    /** Log tag; subclasses override it to identify themselves in the logs. */
    protected String getTAG() {
        return "BaseMogoTTS";
    }

    /** Stores the context for later use by the engine. */
    @Override
    public void initTts(Context context) {
        this.context = context;
    }

    /** Speaks {@code tts} as Chinese text without a per-utterance callback. */
    @Override
    public void speakTTSVoice(String tts) {
        speakTTSVoice(tts, null);
    }

    /**
     * Speaks {@code tts} as Chinese text.
     *
     * @param tts      text to speak; {@code null} or empty text is ignored
     * @param callBack optional per-utterance callback, may be {@code null}
     */
    @Override
    public void speakTTSVoice(String tts, IMogoTTSCallback callBack) {
        // Null check added: a null text used to crash with a NullPointerException.
        if (tts == null || tts.isEmpty()) {
            return;
        }
        speakTTSVoice(new LangTtsEntity(tts, LanguageType.CHINESE), callBack);
    }

    /**
     * Registers the callback (if any) under the entity's string form and starts
     * synthesis on the main thread.
     *
     * @param ttsEntity entity to speak; a {@code null} entity is ignored
     * @param callBack  optional per-utterance callback, may be {@code null}
     */
    @Override
    public void speakTTSVoice(final LangTtsEntity ttsEntity, final IMogoTTSCallback callBack) {
        if (ttsEntity == null) {
            TtsLogManager.d(getTAG(), "speakTTSVoice ignored: ttsEntity is null");
            return;
        }
        runOnMainThread(new Runnable() {
            @Override
            public void run() {
                if (callBack != null) {
                    speakVoiceMap.put(ttsEntity.toString(), callBack);
                }
                speakMultiLangTTS(ttsEntity);
            }
        });
    }

    /**
     * Records the entity as the current one and logs it. Subclasses call this and
     * then perform the actual synthesis and playback.
     */
    protected void speakMultiLangTTS(LangTtsEntity ttsEntity) {
        this.curTtsEntity = ttsEntity;
        // Synthesize and play
        TtsLogManager.d(getTAG(), "tts准备合成" + ttsEntity);
    }

    /** Stops the current utterance; the work itself runs on the main thread. */
    @Override
    public void stopTts() {
        runOnMainThread(new Runnable() {
            @Override
            public void run() {
                realStop();
            }
        });
    }

    /**
     * Clears the current entity/content and, if a callback was registered for the
     * current entity, removes it and notifies it through {@code onStopTts}.
     */
    protected void realStop() {
        TtsLogManager.d(getTAG(), "停止tts");
        LangTtsEntity current = curTtsEntity;
        if (current != null) {
            String key = current.toString();
            IMogoTTSCallback stoppedCallback = speakVoiceMap.remove(key);
            if (stoppedCallback != null) {
                stoppedCallback.onStopTts(key);
            }
            curTtsEntity = null;
        }
        this.curTtsContent = "";
    }

    /** Sets (or clears, when {@code null}) the global TTS callback. */
    public void registerTtsListener(IGlobalTtsCallback callback) {
        this.mGlobalTtsCallback = callback;
    }

    /** Runs {@code task} immediately when already on the main thread, otherwise posts it there. */
    private void runOnMainThread(Runnable task) {
        if (Thread.currentThread() == Looper.getMainLooper().getThread()) {
            task.run();
        } else {
            UiThreadHandler.post(task);
        }
    }
}

View File

@@ -18,21 +18,14 @@ import com.iflytek.aikit.core.BaseLibrary
import com.iflytek.aikit.core.CoreListener import com.iflytek.aikit.core.CoreListener
import com.iflytek.aikit.core.ErrType import com.iflytek.aikit.core.ErrType
import com.mogo.tts.common.IGlobalTtsCallback import com.mogo.tts.common.IGlobalTtsCallback
import com.mogo.tts.common.IMogoTTS
import com.mogo.tts.common.IMogoTTSCallback
import com.mogo.tts.common.LangTtsEntity import com.mogo.tts.common.LangTtsEntity
import com.mogo.tts.common.log.TtsLogManager import com.mogo.tts.common.log.TtsLogManager
import com.mogo.tts.common.LanguageType import com.mogo.tts.common.impl.BaseMogoTTS
import com.mogo.tts.common.utils.FileUtils import com.mogo.tts.common.utils.FileUtils
import java.io.File import java.io.File
@Keep @Keep
class IFlyTekOfflineTts : IMogoTTS { class IFlyTekOfflineTts : BaseMogoTTS() {
companion object {
const val TAG = "IFlyTekTts"
}
private var context: Context? = null
private var aiHandle: AiHandle? = null private var aiHandle: AiHandle? = null
private var OUTPUT_DIR :String = "" private var OUTPUT_DIR :String = ""
@@ -40,20 +33,12 @@ class IFlyTekOfflineTts : IMogoTTS {
"e2e44feff" "e2e44feff"
} }
// 由于主动打断不会有回调事件所以主动打断时清掉map中被打断的text和callback override fun getTAG(): String {
@Volatile return "IFlyTekOfflineTts"
private var curTtsContent = ""
@Volatile
private var curTtsEntity: LangTtsEntity? = null
private val speakVoiceMap by lazy {
HashMap<String, IMogoTTSCallback>()
} }
private var mGlobalTtsCallback: IGlobalTtsCallback? = null
override fun initTts(context: Context) { override fun initTts(context: Context) {
this.context = context super.initTts(context)
initSDK() initSDK()
} }
@@ -96,16 +81,16 @@ class IFlyTekOfflineTts : IMogoTTS {
ErrType.AUTH -> { ErrType.AUTH -> {
if (code == 0) { if (code == 0) {
// SDK授权成功 // SDK授权成功
TtsLogManager.d(TAG, "科大讯飞离线语音合成授权成功!") TtsLogManager.d(tag, "科大讯飞离线语音合成授权成功!")
} else { } else {
// SDK授权失败授权码为code // SDK授权失败授权码为code
TtsLogManager.d(TAG, "科大讯飞离线语音合成授权失败码:$code") TtsLogManager.d(tag, "科大讯飞离线语音合成授权失败码:$code")
} }
} }
else -> { else -> {
// SDK状态为type, code // SDK状态为type, code
TtsLogManager.d(TAG, "type:$type, code:$code") TtsLogManager.d(tag, "type:$type, code:$code")
} }
} }
} }
@@ -126,69 +111,19 @@ class IFlyTekOfflineTts : IMogoTTS {
} }
} }
override fun speakTTSVoice(tts: String?) { override fun realStop() {
speakTTSVoice(tts, null) super.realStop()
}
override fun speakTTSVoice(tts: String?, callBack: IMogoTTSCallback?) {
if (tts.isNullOrEmpty()) return
speakTTSVoice(LangTtsEntity(tts, LanguageType.CHINESE), callBack)
}
override fun speakTTSVoice(ttsEntity: LangTtsEntity, callBack: IMogoTTSCallback?) {
if (Thread.currentThread() != Looper.getMainLooper().thread) {
UiThreadHandler.post {
if (callBack != null) {
speakVoiceMap[ttsEntity.toString()] = callBack
}
speakMultiLangTTS(ttsEntity)
}
} else {
if (callBack != null) {
speakVoiceMap[ttsEntity.toString()] = callBack
}
speakMultiLangTTS(ttsEntity)
}
}
override fun stopTts() {
if (Thread.currentThread() != Looper.getMainLooper().thread) {
UiThreadHandler.post {
realStop()
}
} else {
realStop()
}
}
private fun realStop() {
TtsLogManager.d(TAG,"停止tts")
curTtsEntity?.let {
val string = it.toString()
if (speakVoiceMap.containsKey(string)) {
speakVoiceMap.remove(string)?.onStopTts(string)
}
curTtsEntity = null
}
curTtsContent = ""
AudioTrackManager.instance?.stopPlay() AudioTrackManager.instance?.stopPlay()
if (aiHandle == null || aiHandle?.isSuccess == false) { if (aiHandle == null || aiHandle?.isSuccess == false) {
return return
} }
val end = AiHelper.getInst().end(aiHandle) val end = AiHelper.getInst().end(aiHandle)
aiHandle = null aiHandle = null
TtsLogManager.d(TAG,"停止tts:${end}") TtsLogManager.d(tag,"停止tts:${end}")
} }
override fun speakMultiLangTTS(ttsEntity: LangTtsEntity) {
override fun registerTtsListener(callback: IGlobalTtsCallback?) { super.speakMultiLangTTS(ttsEntity)
mGlobalTtsCallback = callback
}
private fun speakMultiLangTTS(ttsEntity: LangTtsEntity) {
curTtsEntity = ttsEntity
// 合成并播放
TtsLogManager.d(TAG, "tts准备合成$ttsEntity")
stopTts() stopTts()
startSpeak(ttsEntity) startSpeak(ttsEntity)
} }
@@ -278,7 +213,7 @@ class IFlyTekOfflineTts : IMogoTTS {
var bytes: ByteArray? var bytes: ByteArray?
for (i in list.indices) { for (i in list.indices) {
bytes = list[i].value ?: continue bytes = list[i].value ?: continue
TtsLogManager.d(TAG, "onResult:handleID:" + handleID + ":" + list[i].key) TtsLogManager.d(tag, "onResult:handleID:" + handleID + ":" + list[i].key)
if (!dir.exists()) { if (!dir.exists()) {
dir.mkdirs() dir.mkdirs()
} }
@@ -295,23 +230,23 @@ class IFlyTekOfflineTts : IMogoTTS {
) { ) {
when (event) { when (event) {
AeeEvent.AEE_EVENT_UNKNOWN.value -> { AeeEvent.AEE_EVENT_UNKNOWN.value -> {
TtsLogManager.d(TAG, "未知错误") TtsLogManager.d(tag, "未知错误")
// handleErrorEvent("未知错误") // handleErrorEvent("未知错误")
aiHandle?.let { aiHandle?.let {
val ret = AiHelper.getInst().end(it) val ret = AiHelper.getInst().end(it)
aiHandle = null aiHandle = null
TtsLogManager.d(TAG, "AIKit_End$ret") TtsLogManager.d(tag, "AIKit_End$ret")
} }
val errorInfo = ResourcesHelper.getString(context, R.string.module_tts_unknown_error) val errorInfo = ResourcesHelper.getString(context, R.string.module_tts_unknown_error)
handleErrorEvent(errorInfo) handleErrorEvent(errorInfo)
} }
AeeEvent.AEE_EVENT_TIMEOUT.value->{ AeeEvent.AEE_EVENT_TIMEOUT.value->{
TtsLogManager.d(TAG, "超时错误") TtsLogManager.d(tag, "超时错误")
// handleErrorEvent("未知错误") // handleErrorEvent("未知错误")
aiHandle?.let { aiHandle?.let {
val ret = AiHelper.getInst().end(it) val ret = AiHelper.getInst().end(it)
aiHandle = null aiHandle = null
TtsLogManager.d(TAG, "AIKit_End$ret") TtsLogManager.d(tag, "AIKit_End$ret")
} }
val errorInfo = ResourcesHelper.getString(context, R.string.module_tts_unknown_error) val errorInfo = ResourcesHelper.getString(context, R.string.module_tts_unknown_error)
handleErrorEvent(errorInfo) handleErrorEvent(errorInfo)
@@ -324,7 +259,7 @@ class IFlyTekOfflineTts : IMogoTTS {
aiHandle?.let { aiHandle?.let {
val ret = AiHelper.getInst().end(it) val ret = AiHelper.getInst().end(it)
aiHandle = null aiHandle = null
TtsLogManager.d(TAG, "AIKit_End$ret") TtsLogManager.d(tag, "AIKit_End$ret")
} }
onSpeakBegin() onSpeakBegin()
AudioTrackManager.instance?.startPlay("${OUTPUT_DIR}/OutPut_mogo.pcm") AudioTrackManager.instance?.startPlay("${OUTPUT_DIR}/OutPut_mogo.pcm")
@@ -333,7 +268,7 @@ class IFlyTekOfflineTts : IMogoTTS {
} }
override fun onError(handleID: Int, err: Int, msg: String?, usrCxt: Any?) { override fun onError(handleID: Int, err: Int, msg: String?, usrCxt: Any?) {
TtsLogManager.d(TAG, "错误码:$err,错误信息:$msg") TtsLogManager.d(tag, "错误码:$err,错误信息:$msg")
// handleErrorEvent("错误码:$err,错误信息:$msg") // handleErrorEvent("错误码:$err,错误信息:$msg")
val errorInfo = ResourcesHelper.getResources(context) val errorInfo = ResourcesHelper.getResources(context)
.getString(R.string.module_tts_ai_handle_error_code, err, msg) .getString(R.string.module_tts_ai_handle_error_code, err, msg)

View File

@@ -32,6 +32,8 @@ dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"]) implementation fileTree(dir: "libs", include: ["*.jar"])
implementation rootProject.ext.dependencies.androidxappcompat implementation rootProject.ext.dependencies.androidxappcompat
implementation(project(":libraries:mogo-tts:tts-base"))
} }
apply from: new File(rootProject.rootDir, "gradle/upload.gradle").toString() apply from: new File(rootProject.rootDir, "gradle/upload.gradle").toString()

View File

@@ -0,0 +1,228 @@
package com.k2fsa.sherpa.onnx
import android.content.Context
import android.content.res.AssetManager
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioManager
import android.media.AudioTrack
import android.util.Log
import com.elegant.utils.ThreadPoolService
import com.elegant.utils.UiThreadHandler
import com.mogo.tts.common.LangTtsEntity
import com.mogo.tts.common.impl.BaseMogoTTS
import com.mogo.tts.common.log.TtsLogManager
import com.mogo.tts.common.utils.FileUtils
import java.io.File
import java.util.concurrent.locks.ReentrantLock
/**
 * Offline TTS engine that synthesizes speech with [OfflineTts] and plays the
 * generated float samples through a streaming [AudioTrack].
 *
 * Requests that arrive while a generation still holds [lock] are parked in
 * [waitPlayInfo]; only the most recent parked request is kept.
 */
class MogoOfflineTTS : BaseMogoTTS() {
    private var tts: OfflineTts? = null
    private var track: AudioTrack? = null

    // Set by realStop()/release(); polled by the generation callback to abort playback.
    @Volatile
    private var stopped: Boolean = false
    private var OUTPUT_DIR: String = ""

    // Latest request received while a generation was holding the lock.
    private var waitPlayInfo: LangTtsEntity? = null
    val lock = ReentrantLock()

    override fun getTAG(): String {
        return "MogoOfflineTTS"
    }

    override fun initTts(context: Context?) {
        super.initTts(context)
        initSdk()
    }

    /**
     * Prepares the working directory and initializes the engine and the audio track.
     * When the `yue_dict` marker file is missing or empty, the espeak-ng data is first
     * copied from assets on a worker thread and initialization continues on the UI thread.
     */
    private fun initSdk() {
        context?.let {
            val workPath = it.filesDir.absolutePath + File.separator + "mogotts" + File.separator + "xtts"
            // NOTE(review): this class never creates OUTPUT_DIR; confirm that
            // GeneratedAudio.save() can write there.
            OUTPUT_DIR = workPath + File.separator + "output"
            val file = File("$workPath/yue_dict")
            if (!file.exists() || file.length() == 0L) {
                ThreadPoolService.execute {
                    try {
                        FileUtils.copyAssetsToLocal(context, "matcha-icefall-zh-en/espeak-ng-data", workPath)
                        UiThreadHandler.post {
                            initEngineAndTrack(workPath)
                        }
                    } catch (e: Exception) {
                        e.printStackTrace()
                    }
                }
            } else {
                initEngineAndTrack(workPath)
            }
        }
    }

    /**
     * Runs [initEngine] and [initAudioTrack], containing any failure so that the
     * posted (asynchronous) path is protected the same way as the direct path.
     */
    private fun initEngineAndTrack(workPath: String) {
        try {
            initEngine(workPath)
            initAudioTrack()
        } catch (e: Exception) {
            e.printStackTrace()
        }
    }

    /** Builds the Matcha model configuration and creates the [OfflineTts] instance from assets. */
    private fun initEngine(workPath: String) {
        // The purpose of such a design is to make the CI test easier
        // Please see
        // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py
        val modelDir = "matcha-icefall-zh-en"
        val acousticModelName = "model-steps-3.onnx"
        val vocoder = "vocos-16khz-univ.onnx"
        val lexicon = "lexicon.txt"
        val assets: AssetManager? = context.assets
        val isKitten = false
        val config = getOfflineTtsConfig(
            modelDir = modelDir,
            modelName = "",
            acousticModelName = acousticModelName,
            vocoder = vocoder,
            voices = "",
            lexicon = lexicon,
            dataDir = workPath,
            dictDir = "",
            ruleFsts = "",
            ruleFars = "",
            isKitten = isKitten,
        )
        tts = OfflineTts(assetManager = assets, config = config)
        TtsLogManager.d(tag, "初始化tts:${tts?.config}")
    }

    /** Creates a mono, float-PCM streaming [AudioTrack] at the engine's sample rate and starts it. */
    private fun initAudioTrack() {
        tts?.let {
            val sampleRate = it.sampleRate()
            val bufLength = AudioTrack.getMinBufferSize(
                sampleRate,
                AudioFormat.CHANNEL_OUT_MONO,
                AudioFormat.ENCODING_PCM_FLOAT
            )
            Log.i(tag, "sampleRate: $sampleRate, buffLength: $bufLength")
            val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
                .setUsage(AudioAttributes.USAGE_MEDIA)
                .build()
            val format = AudioFormat.Builder()
                .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
                .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
                .setSampleRate(sampleRate)
                .build()
            track = AudioTrack(
                attr, format, bufLength, AudioTrack.MODE_STREAM,
                AudioManager.AUDIO_SESSION_ID_GENERATE
            )
            track?.play()
        }
    }

    /**
     * Releases the audio track and the native engine. The references are cleared so
     * that later speak requests become no-ops instead of using released objects.
     */
    override fun release() {
        stopped = true
        track?.release()
        track = null
        tts?.release()
        tts = null
    }

    /**
     * Starts speaking [ttsEntity], or parks it in [waitPlayInfo] when a generation
     * currently holds [lock].
     */
    override fun speakMultiLangTTS(ttsEntity: LangTtsEntity?) {
        super.speakMultiLangTTS(ttsEntity)
        TtsLogManager.d(tag, "1 播放语言:${ttsEntity}")
        // isLocked is intended here: we ask whether ANY thread is generating.
        if (lock.isLocked) {
            waitPlayInfo = ttsEntity
            TtsLogManager.d(tag, "speakMultiLangTTS 正在生成 锁定中")
            return
        } else {
            if (waitPlayInfo != null) {
                TtsLogManager.d(tag, "speakMultiLangTTS 没有被锁 等待播放有值 ${waitPlayInfo}")
                waitPlayInfo = null
            }
        }
        startSpeak(ttsEntity)
    }

    /** Stops whatever is playing and speaks the content of [langTtsEntity], if any. */
    private fun startSpeak(langTtsEntity: LangTtsEntity?) {
        stopTts()
        langTtsEntity?.let {
            curTtsContent = it.ttsContent
            realSpeak(it.ttsContent)
        }
    }

    /**
     * Resets the track and generates audio for [content] on a worker thread, writing
     * each chunk delivered to the callback into the track.
     *
     * The lock is taken before generation and released on the first callback.
     */
    private fun realSpeak(content: String) {
        TtsLogManager.d(tag, "2 realSpeak-开始进行生成播放--${Thread.currentThread().name}")
        track?.pause()
        track?.flush()
        track?.play()
        ThreadPoolService.execute {
            var audio: GeneratedAudio? = null
            stopped = false
            try {
                lock.lock()
                TtsLogManager.d(tag, "3 开始生成 去加锁")
                audio = tts?.generateWithConfigAndCallback(
                    text = content,
                    config = GenerationConfig(sid = 0, speed = 1.0f),
                    callback = object : Function1<FloatArray, Int> {
                        // NOTE(review): the return value presumably tells the native
                        // generator to continue (1) or stop (0) — confirm against sherpa-onnx.
                        override fun invoke(samples: FloatArray): Int {
                            // Only the owning thread may unlock. Checking isLocked here
                            // could see a lock held by a newer generation and throw
                            // IllegalMonitorStateException on unlock().
                            if (lock.isHeldByCurrentThread) {
                                TtsLogManager.d(tag, "4 生成成功 去解锁")
                                lock.unlock()
                            }
                            if (waitPlayInfo != null) {
                                TtsLogManager.d(tag, "生成成功 解锁后 等待播放有值 停止播放去生成")
                                track?.stop()
                                UiThreadHandler.post {
                                    waitPlayInfo?.let {
                                        startSpeak(it.copy())
                                        waitPlayInfo = null
                                    }
                                }
                                return 0
                            }
                            if (!stopped) {
                                track?.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
                                return 1
                            } else {
                                track?.stop()
                                return 0
                            }
                        }
                    }
                )
                TtsLogManager.d(tag, "5 realSpeak-结束播放语音")
            } finally {
                // Release only if this thread still owns the lock (e.g. no callback fired).
                if (lock.isHeldByCurrentThread) {
                    TtsLogManager.d(tag, "realSpeak-finally 中发现还在锁定解锁")
                    lock.unlock()
                }
            }
            val filename = "$OUTPUT_DIR/generated.wav"
            audio?.let {
                val ok = it.samples.isNotEmpty() && it.save(filename)
                if (ok) {
                    UiThreadHandler.post {
                        track?.stop()
                    }
                }
            }
        }
    }

    /** Clears the base-class state, flags the running generation to stop and empties the track. */
    override fun realStop() {
        super.realStop()
        TtsLogManager.d(tag, "mogooffline realStop")
        stopped = true
        track?.pause()
        track?.flush()
    }
}

View File

@@ -0,0 +1,382 @@
// Copyright (c) 2023 Xiaomi Corporation
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * Settings for a VITS model.
 *
 * `getOfflineTtsConfig` in this file fills [model], [lexicon] and [tokens] with paths
 * under its `modelDir` argument and forwards `dataDir` when a model name is given and
 * no voices file is set; otherwise every field keeps its default.
 * NOTE(review): the scale fields are presumably interpreted by the native
 * sherpa-onnx library — confirm their meaning against its documentation.
 */
data class OfflineTtsVitsModelConfig(
var model: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "", // unused
var noiseScale: Float = 0.667f,
var noiseScaleW: Float = 0.8f,
var lengthScale: Float = 1.0f,
)
/**
 * Settings for a Matcha model.
 *
 * `getOfflineTtsConfig` in this file populates it when an acoustic model name is
 * given: [acousticModel], [lexicon] and [tokens] become paths under `modelDir`,
 * while [vocoder] and [dataDir] are passed through unchanged.
 */
data class OfflineTtsMatchaModelConfig(
var acousticModel: String = "",
var vocoder: String = "",
var lexicon: String = "",
var tokens: String = "",
var dataDir: String = "",
var dictDir: String = "", // unused
var noiseScale: Float = 1.0f,
var lengthScale: Float = 1.0f,
)
/**
 * Settings for a Kokoro model.
 *
 * `getOfflineTtsConfig` in this file populates it when a voices file is given and
 * the Kitten flag is off. [lexicon] is kept verbatim when it is empty or contains a
 * comma, and is otherwise prefixed with `modelDir`.
 */
data class OfflineTtsKokoroModelConfig(
var model: String = "",
var voices: String = "",
var tokens: String = "",
var dataDir: String = "",
var lexicon: String = "",
var lang: String = "",
var dictDir: String = "", // unused
var lengthScale: Float = 1.0f,
)
/**
 * Settings for a ZipVoice model.
 *
 * `getOfflineTtsConfig` in this file never sets these fields, so the defaults below
 * apply unless a caller builds the config by hand.
 * NOTE(review): field semantics are not visible in this file — confirm against the
 * sherpa-onnx documentation.
 */
data class OfflineTtsZipVoiceModelConfig(
var tokens: String = "",
var encoder: String = "",
var decoder: String = "",
var vocoder: String = "",
var dataDir: String = "",
var lexicon: String = "",
var featScale: Float = 0.1f,
var tShift: Float = 0.5f,
var targetRms: Float = 0.1f,
var guidanceScale: Float = 1.0f,
)
/**
 * Settings for a Kitten model.
 *
 * `getOfflineTtsConfig` in this file populates it only when its `isKitten` flag is
 * true: [model], [voices] and [tokens] become paths under `modelDir` and [dataDir]
 * is passed through.
 */
data class OfflineTtsKittenModelConfig(
var model: String = "",
var voices: String = "",
var tokens: String = "",
var dataDir: String = "",
var lengthScale: Float = 1.0f,
)
/**
 * Configuration for Pocket TTS models.
 *
 * See https://k2-fsa.github.io/sherpa/onnx/tts/pocket/index.html for details.
 *
 * @property lmFlow Path to the LM flow model (.onnx)
 * @property lmMain Path to the LM main model (.onnx)
 * @property encoder Path to the encoder model (.onnx)
 * @property decoder Path to the decoder model (.onnx)
 * @property textConditioner Path to the text conditioner model (.onnx)
 * @property vocabJson Path to vocabulary JSON file
 * @property tokenScoresJson Path to token scores JSON file
 * @property voiceEmbeddingCacheCapacity Defaults to 50; presumably the maximum number
 * of cached voice embeddings — TODO confirm against the page linked above
 */
data class OfflineTtsPocketModelConfig(
var lmFlow: String = "",
var lmMain: String = "",
var encoder: String = "",
var decoder: String = "",
var textConditioner: String = "",
var vocabJson: String = "",
var tokenScoresJson: String = "",
var voiceEmbeddingCacheCapacity: Int = 50,
)
/**
 * Settings for a Supertonic model. All fields are strings defaulting to empty.
 *
 * `getOfflineTtsConfig` in this file never sets these fields.
 * NOTE(review): the fields look like file paths to model components — confirm
 * against the sherpa-onnx documentation.
 */
data class OfflineTtsSupertonicModelConfig(
var durationPredictor: String = "",
var textEncoder: String = "",
var vectorEstimator: String = "",
var vocoder: String = "",
var ttsJson: String = "",
var unicodeIndexer: String = "",
var voiceStyle: String = "",
)
/**
 * Bundles one configuration object per supported model family together with the
 * settings shared by all of them ([numThreads], [debug], [provider]).
 *
 * `getOfflineTtsConfig` in this file fills [vits], [matcha], [kokoro] and [kitten],
 * and sets `debug = true` and `provider = "cpu"`; the remaining families keep their
 * default (empty) configuration.
 */
data class OfflineTtsModelConfig(
var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(),
var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(),
var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(),
var zipvoice: OfflineTtsZipVoiceModelConfig = OfflineTtsZipVoiceModelConfig(),
var kitten: OfflineTtsKittenModelConfig = OfflineTtsKittenModelConfig(),
var pocket: OfflineTtsPocketModelConfig = OfflineTtsPocketModelConfig(),
var supertonic: OfflineTtsSupertonicModelConfig = OfflineTtsSupertonicModelConfig(),
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
/**
 * Top-level configuration handed to the `OfflineTts` constructor, which forwards it
 * to the native `newFromAsset` / `newFromFile` functions.
 *
 * `getOfflineTtsConfig` in this file sets [model], [ruleFsts] and [ruleFars];
 * [maxNumSentences] and [silenceScale] keep the defaults declared here.
 */
data class OfflineTtsConfig(
var model: OfflineTtsModelConfig = OfflineTtsModelConfig(),
var ruleFsts: String = "",
var ruleFars: String = "",
var maxNumSentences: Int = 1,
var silenceScale: Float = 0.2f,
)
/**
 * Audio returned by the `OfflineTts` generate functions: the float samples and
 * their sample rate.
 */
class GeneratedAudio(
val samples: FloatArray,
val sampleRate: Int,
) {
/**
 * Writes the samples to [filename] through the native `saveImpl` and returns its result.
 * NOTE(review): the output format is decided by native code; a caller in this project
 * passes a `.wav` path.
 */
fun save(filename: String) = saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)
// Implemented in native code.
private external fun saveImpl(
filename: String,
samples: FloatArray,
sampleRate: Int
): Boolean
}
/**
 * Per-request generation options passed to `OfflineTts.generateWithConfig` and
 * `OfflineTts.generateWithConfigAndCallback`.
 *
 * NOTE(review): the reference* fields, [numSteps] and [extra] are interpreted by
 * native code not visible here — presumably only some model families use them;
 * confirm against the sherpa-onnx documentation.
 */
data class GenerationConfig(
var silenceScale: Float = 0.2f,
var speed: Float = 1.0f,
var sid: Int = 0,
var referenceAudio: FloatArray? = null,
var referenceSampleRate: Int = 0,
var referenceText: String? = null,
var numSteps: Int = 5,
var extra: Map<String, String>? = null
)
/**
 * Wrapper around a native offline TTS instance.
 *
 * The native object is created in the constructor, either from the given
 * [AssetManager] or, when none is supplied, from plain files. It is destroyed by
 * [free], [release] or the finalizer, whichever runs first.
 */
class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
) {
    // Handle of the native object; 0 means nothing is currently allocated.
    private var ptr: Long = createNative(assetManager)

    /** Creates the native object from assets when a manager is given, otherwise from files. */
    private fun createNative(assetManager: AssetManager?): Long = when (assetManager) {
        null -> newFromFile(config)
        else -> newFromAsset(assetManager, config)
    }

    /** Destroys the native object if one is allocated and resets the handle. */
    private fun destroyNative() {
        if (ptr == 0L) {
            return
        }
        delete(ptr)
        ptr = 0
    }

    fun sampleRate() = getSampleRate(ptr)

    fun numSpeakers() = getNumSpeakers(ptr)

    /** Generates audio for [text] in one call. */
    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio = generateImpl(ptr, text = text, sid = sid, speed = speed)

    /** Generates audio for [text], handing sample chunks to [callback] along the way. */
    fun generateWithCallback(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Int
    ): GeneratedAudio = generateWithCallbackImpl(
        ptr,
        text = text,
        sid = sid,
        speed = speed,
        callback = callback
    )

    /** Generates audio for [text] using [config], without a callback. */
    fun generateWithConfig(
        text: String,
        config: GenerationConfig
    ): GeneratedAudio = generateWithConfigImpl(ptr, text, config, null)

    /** Generates audio for [text] using [config], handing sample chunks to [callback]. */
    fun generateWithConfigAndCallback(
        text: String,
        config: GenerationConfig,
        callback: (samples: FloatArray) -> Int
    ): GeneratedAudio = generateWithConfigImpl(ptr, text, config, callback)

    /** Re-creates the native object, but only when none is currently allocated. */
    fun allocate(assetManager: AssetManager? = null) {
        if (ptr != 0L) {
            return
        }
        ptr = createNative(assetManager)
    }

    fun free() {
        destroyNative()
    }

    protected fun finalize() {
        destroyNative()
    }

    fun release() = finalize()

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun getSampleRate(ptr: Long): Int

    private external fun getNumSpeakers(ptr: Long): Int

    // The returned array has two entries:
    //  - the first entry is an 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    private external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio

    private external fun generateWithCallbackImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Int
    ): GeneratedAudio

    private external fun generateWithConfigImpl(
        ptr: Long,
        text: String,
        config: GenerationConfig,
        callback: ((samples: FloatArray) -> Int)?
    ): GeneratedAudio

    companion object {
        init {
            // Loads the JNI library backing the external functions above.
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
// Pretrained models can be downloaded from
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
/**
 * Builds an [OfflineTtsConfig] for one model family from a handful of file names.
 *
 * Which arguments to set:
 *  - Matcha: [acousticModelName] and [vocoder]
 *  - Kokoro: [modelName] and [voices]
 *  - Kitten: [modelName], [voices] and [isKitten]
 *  - VITS:   [modelName]
 *
 * When [numThreads] is null, 4 threads are used if [voices] is set and 2 otherwise.
 *
 * @throws IllegalArgumentException when no model is named, when both a VITS and a
 * Matcha model are named, or when a Matcha model is named without a vocoder
 */
fun getOfflineTtsConfig(
    modelDir: String,
    modelName: String, // for VITS
    acousticModelName: String, // for Matcha
    vocoder: String, // for Matcha
    voices: String, // for Kokoro or kitten
    lexicon: String,
    dataDir: String,
    dictDir: String, // unused
    ruleFsts: String,
    ruleFars: String,
    numThreads: Int? = null,
    isKitten: Boolean = false
): OfflineTtsConfig {
    val threadCount = when {
        numThreads != null -> numThreads
        // Kokoro and Kitten models get more threads
        voices.isNotEmpty() -> 4
        else -> 2
    }

    val hasModelName = modelName.isNotEmpty()
    val hasAcousticModel = acousticModelName.isNotEmpty()
    val hasVoices = voices.isNotEmpty()

    require(hasModelName || hasAcousticModel) { "Please specify a TTS model" }
    require(!(hasModelName && hasAcousticModel)) {
        "Please specify either a VITS or a Matcha model, but not both"
    }
    require(!hasAcousticModel || vocoder.isNotEmpty()) { "Please provide vocoder for Matcha TTS" }

    val tokensPath = "$modelDir/tokens.txt"
    val modelPath = "$modelDir/$modelName"
    val voicesPath = "$modelDir/$voices"
    val lexiconPath = "$modelDir/$lexicon"

    val vitsConfig = if (!hasModelName || hasVoices) {
        OfflineTtsVitsModelConfig()
    } else {
        OfflineTtsVitsModelConfig(
            model = modelPath,
            lexicon = lexiconPath,
            tokens = tokensPath,
            dataDir = dataDir,
        )
    }

    val matchaConfig = if (!hasAcousticModel) {
        OfflineTtsMatchaModelConfig()
    } else {
        OfflineTtsMatchaModelConfig(
            acousticModel = "$modelDir/$acousticModelName",
            vocoder = vocoder,
            lexicon = lexiconPath,
            tokens = tokensPath,
            dataDir = dataDir,
        )
    }

    val kokoroConfig = if (hasVoices && !isKitten) {
        // An empty lexicon or a comma-separated list is forwarded untouched.
        val kokoroLexicon = if (lexicon == "" || "," in lexicon) lexicon else lexiconPath
        OfflineTtsKokoroModelConfig(
            model = modelPath,
            voices = voicesPath,
            tokens = tokensPath,
            dataDir = dataDir,
            lexicon = kokoroLexicon,
        )
    } else {
        OfflineTtsKokoroModelConfig()
    }

    val kittenConfig = if (!isKitten) {
        OfflineTtsKittenModelConfig()
    } else {
        OfflineTtsKittenModelConfig(
            model = modelPath,
            voices = voicesPath,
            tokens = tokensPath,
            dataDir = dataDir,
        )
    }

    val modelConfig = OfflineTtsModelConfig(
        vits = vitsConfig,
        matcha = matchaConfig,
        kokoro = kokoroConfig,
        kitten = kittenConfig,
        numThreads = threadCount,
        debug = true,
        provider = "cpu",
    )
    return OfflineTtsConfig(
        model = modelConfig,
        ruleFsts = ruleFsts,
        ruleFars = ruleFars,
    )
}

View File

@@ -0,0 +1,11 @@
# Introduction
Note that if you use Android Studio, then you only need to
copy libonnxruntime.so and libsherpa-onnx-jni.so
to your jniLibs, and you don't need libsherpa-onnx-c-api.so or
libsherpa-onnx-cxx-api.so.
libsherpa-onnx-c-api.so and libsherpa-onnx-cxx-api.so are for users
who don't use JNI. In that case, libsherpa-onnx-jni.so is not needed.
In any case, libonnxruntime.so is always needed.

View File

@@ -0,0 +1,11 @@
# Introduction
Note that if you use Android Studio, then you only need to
copy libonnxruntime.so and libsherpa-onnx-jni.so
to your jniLibs, and you don't need libsherpa-onnx-c-api.so or
libsherpa-onnx-cxx-api.so.
libsherpa-onnx-c-api.so and libsherpa-onnx-cxx-api.so are for users
who don't use JNI. In that case, libsherpa-onnx-jni.so is not needed.
In any case, libonnxruntime.so is always needed.

View File

@@ -0,0 +1,11 @@
# Introduction
Note that if you use Android Studio, then you only need to
copy libonnxruntime.so and libsherpa-onnx-jni.so
to your jniLibs, and you don't need libsherpa-onnx-c-api.so or
libsherpa-onnx-cxx-api.so.
libsherpa-onnx-c-api.so and libsherpa-onnx-cxx-api.so are for users
who don't use JNI. In that case, libsherpa-onnx-jni.so is not needed.
In any case, libonnxruntime.so is always needed.

View File

@@ -0,0 +1,11 @@
# Introduction
Note that if you use Android Studio, then you only need to
copy libonnxruntime.so and libsherpa-onnx-jni.so
to your jniLibs, and you don't need libsherpa-onnx-c-api.so or
libsherpa-onnx-cxx-api.so.
libsherpa-onnx-c-api.so and libsherpa-onnx-cxx-api.so are for users
who don't use JNI. In that case, libsherpa-onnx-jni.so is not needed.
In any case, libonnxruntime.so is always needed.