语音转文字准确率计算 转写正确率的衡量项:ACC、Corr
H = 正确的字数
D = 删减的错误,“我是中国人” “我是中国”
S = 替换的错误,“我是中国人” “我是中华人”
I = 插入的错误,“我是中国人” “我是中国男人”
N = 总字数(下文代码中取去除标点后的test字符串的长度)
ACC = (H - I)/N
Corr= H/N
思路:
sample "我是中国人学大生,今天要测试,录音结束"
test "中国人民生,今天有个大事情,我想吃饭"
语音转写文字,需要遵从文字的语义,所以不能文字出现就算正确,不考虑各种复杂的因素,
要从test中找到sample中对应的字符,并且顺序要按sample中的字符顺序(排除标点符号)
如:
sample中每个字在test中找到的位置序号(未找到记为-1)依次为:12, -1, 0, 1, 2, -1, 9, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1
那么如何找到正确的文字个数呢?应该就是在这组序号中剔除未找到的-1,然后从剩下的序号中找到最长升序子序列,也就是经典的“最长递增子序列”问题(竟然是个算法问题,丢!)
如图,剔除-1,剩下的就是12,0,1,2,9,4,5,6,最长升序那不就是0,1,2,4,5,6吗,对应的文字就是“中国人生今天”,那么算法如何实现呢?
这里给出个笨办法:
遍历序列,将每个升序数组都保存在list里,如果不是升序,就新建一个list。
如当遇到12,此时还没有任何list,则新建一个list:[12]
当遇到0,0比12小,不满足升序,则需要新建一个list,此时共有两个list:[12]、[0]
另外还需要注意,即使是升序,也不一定就能组成最长
如0 1 2,后面是9,如果加进去了,就只能组成0 1 2 9,长度为4,如果放弃加9,则有机会组成0 1 2 4 5 6,长度为6。所以每次添加一个符合升序规则的数字的时候,我们要提前将原list备份一个,留个机会看能否组成更长的序列。
剩下就是计算删减、替换、添加的文字个数了,这个比较简单,看看代码逻辑就行了。
本文代码没有考虑时间复杂度和内存占用(用于测试),请自行优化。
上代码:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Set;
import java.util.function.IntPredicate;
import java.util.logging.Logger;
/**
 * Scores a speech-to-text transcript against its reference text.
 *
 * <p>Punctuation is stripped from both strings, every reference character is mapped to a
 * position in the transcript, and the longest strictly increasing run of those positions is
 * taken as the set of correctly transcribed characters. The gaps between consecutive correct
 * characters are then split into deletions, substitutions and insertions.
 *
 * <ul>
 * <li>H - characters transcribed correctly</li>
 * <li>D - deletions, e.g. "我是中国人" transcribed as "我是中国"</li>
 * <li>S - substitutions, e.g. "我是中国人" transcribed as "我是中华人"</li>
 * <li>I - insertions, e.g. "我是中国人" transcribed as "我是中国男人"</li>
 * <li>N - total character count (this class uses the transcript length)</li>
 * <li>ACC = (H - I) / N, Corr = H / N</li>
 * </ul>
 */
public class StringCompare {
    private static final String ORIGINAL_STRING = "我是中国人学大生,今天要测试,录音结束";
    private static final String TEST_STRING = "中国人民生,今天有个大事情,我想吃饭";
    private static float mAcc = 0;
    private static float mCorr = 0;
    private static int H = 0; // correctly transcribed characters
    private static int N = 0; // total characters (denominator of ACC and Corr)
    private static int D = 0; // deleted characters
    private static int S = 0; // substituted characters
    private static int I = 0; // inserted characters
    // Declared but never read or written by any code in this class.
    private static List<Character> arrayD = new ArrayList<Character>();
    private static List<Character> arrayI = new ArrayList<Character>();
    private static List<Character> arrayS = new ArrayList<Character>();

    /**
     * Strips every Unicode punctuation character (regex class {@code \pP}) from the sentence
     * and echoes the result to standard output.
     *
     * @param sentence text to clean
     * @return the sentence without punctuation
     */
    private static String removePunctuation(String sentence) {
        String string = sentence.replaceAll("\\pP", "");
        System.out.println(string);
        return string;
    }

    /** Derives ACC and Corr from the current counters and prints all metrics. */
    private static void print() {
        // With N == 0 the float divisions yield NaN/Infinity instead of throwing.
        mAcc = (float) (H - I) / (float) N;
        mCorr = H / (float) N;
        System.out.println("H:" + H + ", N:" + N + ", D:" + D + ", I:" + I + ", S:" + S);
        System.out.println("mAcc" + mAcc + ",mCorr:" + mCorr);
    }

    /**
     * Maps each character of {@code original} to a position in {@code test}.
     *
     * <p>A character takes the first occurrence in {@code test} that no earlier character has
     * already claimed, so repeated characters map to distinct positions.
     *
     * @param original reference text
     * @param test     transcript text
     * @return one entry per character of {@code original}; -1 where no free occurrence exists
     */
    private static int[] getInIndex(String original, String test) {
        char[] charOriginal = original.toCharArray();
        List<Integer> list = new ArrayList<Integer>();
        for (int i = 0; i < charOriginal.length; i++) {
            int tempIndex = test.indexOf(charOriginal[i]);
            System.out.println("tempIndex:" + tempIndex);
            // Skip positions already claimed by an earlier reference character.
            while (tempIndex != -1 && list.contains(tempIndex)) {
                System.out.println("while");
                tempIndex = test.indexOf(charOriginal[i], tempIndex + 1);
                System.out.println("tempIndex:" + tempIndex);
            }
            list.add(tempIndex);
            System.out.println("charOriginal[" + i + "]:" + charOriginal[i] + ", index:" + tempIndex);
        }
        return list.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Finds a longest strictly increasing subsequence of the non-negative entries of
     * {@code index}.
     *
     * <p>Uses the quadratic dynamic-programming formulation, so memory stays linear in the
     * input length. When several subsequences share the maximum length, the one that ends
     * earliest is returned, built from the earliest possible predecessors.
     *
     * @param index transcript position for each reference position, -1 meaning "not found"
     * @return reference position mapped to transcript position, in ascending order of both;
     *         empty when {@code index} has no non-negative entry
     */
    private static Map<Integer, Integer> getMaxSortedIndexs(int[] index) {
        int count = index.length;
        int[] chainLength = new int[count]; // length of the best chain ending at i
        int[] previous = new int[count];    // predecessor of i in that chain, -1 if none
        int bestEnd = -1;
        for (int i = 0; i < count; i++) {
            if (index[i] < 0) {
                continue;
            }
            System.out.println("cur index:" + index[i]);
            chainLength[i] = 1;
            previous[i] = -1;
            for (int j = 0; j < i; j++) {
                if (index[j] >= 0 && index[j] < index[i] && chainLength[j] + 1 > chainLength[i]) {
                    chainLength[i] = chainLength[j] + 1;
                    previous[i] = j;
                }
            }
            if (bestEnd < 0 || chainLength[i] > chainLength[bestEnd]) {
                bestEnd = i;
            }
        }
        Map<Integer, Integer> longest = new LinkedHashMap<Integer, Integer>();
        if (bestEnd < 0) {
            return longest;
        }
        // Walk the predecessor links backwards, then emit in ascending order.
        int[] positions = new int[chainLength[bestEnd]];
        int cursor = bestEnd;
        for (int slot = positions.length - 1; slot >= 0; slot--) {
            positions[slot] = cursor;
            cursor = previous[cursor];
        }
        for (int position : positions) {
            longest.put(position, index[position]);
        }
        System.out.println("longest:" + longest);
        return longest;
    }

    /**
     * Splits one gap between two consecutive correct characters into D, S and I.
     *
     * <p>The overlapping part of the two gaps counts as substitutions; the surplus on the
     * reference side counts as deletions and the surplus on the transcript side as insertions.
     *
     * @param diffKey   unmatched reference characters in the gap
     * @param diffValue unmatched transcript characters in the gap
     */
    private static void scoreGap(int diffKey, int diffValue) {
        if (diffKey > diffValue) {
            D += diffKey - diffValue;
            S += diffValue;
        } else if (diffKey == diffValue) {
            S += diffValue;
        } else {
            I += diffValue - diffKey;
            S += diffKey;
        }
    }

    /**
     * Scores a transcript and prints the metrics.
     *
     * @param args optional: {@code args[0]} is the reference text and {@code args[1]} the
     *             transcript; with fewer than two arguments the built-in sample pair is used
     */
    public static void main(String[] args) {
        long time = System.currentTimeMillis();
        boolean custom = args != null && args.length >= 2;
        String referenceText = custom ? args[0] : ORIGINAL_STRING;
        String transcriptText = custom ? args[1] : TEST_STRING;
        // The counters are static; clear them so repeated invocations do not accumulate.
        H = 0;
        N = 0;
        D = 0;
        S = 0;
        I = 0;

        String original = removePunctuation(referenceText);
        String test = removePunctuation(transcriptText);
        int[] index = getInIndex(original, test);
        for (int i = 0; i < index.length; i++) {
            System.out.print(index[i] + "\t");
        }
        System.out.print("\n");
        Map<Integer, Integer> sortedIndex = getMaxSortedIndexs(index);
        System.out.println("cost:" + (System.currentTimeMillis() - time));

        System.out.println("在original字符串中正确的字符");
        System.out.println("[" + referenceText + "]");
        for (int key : sortedIndex.keySet()) {
            System.out.println("index:" + key + "value:" + original.charAt(key));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("在test字符串中正确的字符\n");
        System.out.println("[" + transcriptText + "]");
        for (int value : sortedIndex.values()) {
            System.out.println("index:" + value + "value:" + test.charAt(value));
        }
        System.out.println("--------------------------------------------------------------");
        System.out.print("计算准确率\n");

        // Starting from -1 makes "position - previous - 1" count the characters before the
        // first correct one without a special case for the first iteration.
        int prevKey = -1;
        int prevValue = -1;
        for (Map.Entry<Integer, Integer> entry : sortedIndex.entrySet()) {
            int key = entry.getKey();
            int value = entry.getValue();
            int diffKey = key - prevKey - 1;
            int diffValue = value - prevValue - 1;
            System.out.println("diffKey:" + diffKey + ", diffValue:" + diffValue);
            scoreGap(diffKey, diffValue);
            prevKey = key;
            prevValue = value;
        }
        System.out.println("tempkey:" + prevKey + ",tempValue:" + prevValue);
        // Trailing characters after the last correct one.
        int tailKey = original.length() - prevKey - 1;
        int tailValue = test.length() - prevValue - 1;
        System.out.println("diffKey:" + tailKey + ",diffValue:" + tailValue);
        scoreGap(tailKey, tailValue);

        H = sortedIndex.size();
        // NOTE(review): N is the transcript length here, while H + D + S always equals the
        // reference length; conventional Corr/Acc scoring takes N from the reference.
        // Confirm which denominator is intended before changing it.
        N = test.length();
        print();
        System.out.println("--------------------------------------------------------------");
    }
}
C#版本
using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using System.Linq;
namespace Calibration.utils
{
    /// <summary>
    /// Scores speech-to-text transcripts against a reference text: H (correct), D (deleted),
    /// S (substituted), I (inserted), N (total), ACC = (H - I) / N and Corr = H / N.
    /// </summary>
    class EsrUtils
    {
        private EsrUtils() { }

        /// <summary>
        /// Removes every character listed in the pattern's character class from the sentence.
        /// The class contains literal spaces, so spaces are removed as well.
        /// </summary>
        /// <param name="sentence">Text to clean.</param>
        /// <returns>The sentence without the listed characters.</returns>
        public static string RemovePunctuation(string sentence)
        {
            return Regex.Replace(sentence, "[ \\[ \\] \\^ \\-_*×――(^)$%~!@#$…&%¥—+=<>《》!!???::•`·、。,;,.;\"‘’“”-]", "");
        }

        /// <summary>
        /// Maps each character of <paramref name="original"/> to a position in
        /// <paramref name="test"/>. A character takes the first occurrence that no earlier
        /// character has already claimed.
        /// </summary>
        /// <param name="original">Reference text.</param>
        /// <param name="test">Transcript text.</param>
        /// <returns>One entry per character of the reference; -1 where no free occurrence exists.</returns>
        public static int[] GetInIndex(String original, String test)
        {
            char[] charOriginal = original.ToCharArray();
            var list = new List<int>();
            for (int i = 0; i < charOriginal.Length; i++)
            {
                int tempIndex = test.IndexOf(charOriginal[i]);
                Console.WriteLine("tempIndex:" + tempIndex);
                // Skip positions already claimed by an earlier reference character.
                while (tempIndex != -1 && list.Contains(tempIndex))
                {
                    Console.WriteLine("while");
                    tempIndex = test.IndexOf(charOriginal[i], tempIndex + 1);
                    Console.WriteLine("tempIndex:" + tempIndex);
                }
                list.Add(tempIndex);
                Console.WriteLine("charOriginal[" + i + "]:" + charOriginal[i] + ", index:" + tempIndex);
            }
            return list.ToArray();
        }

        /// <summary>
        /// Finds a longest strictly increasing subsequence of the non-negative entries of
        /// <paramref name="index"/> using the quadratic dynamic-programming formulation.
        /// When several subsequences share the maximum length, the one that ends earliest is
        /// returned, built from the earliest possible predecessors.
        /// </summary>
        /// <param name="index">Transcript position per reference position, -1 meaning "not found".</param>
        /// <returns>
        /// Reference position mapped to transcript position, inserted in ascending order;
        /// empty when no entry is non-negative.
        /// </returns>
        public static Dictionary<int, int> GetMaxSortedIndexs(int[] index)
        {
            int count = index.Length;
            int[] chainLength = new int[count]; // length of the best chain ending at i
            int[] previous = new int[count];    // predecessor of i in that chain, -1 if none
            int bestEnd = -1;
            for (int i = 0; i < count; i++)
            {
                if (index[i] < 0)
                {
                    continue;
                }
                Console.WriteLine("cur index:" + index[i]);
                chainLength[i] = 1;
                previous[i] = -1;
                for (int j = 0; j < i; j++)
                {
                    if (index[j] >= 0 && index[j] < index[i] && chainLength[j] + 1 > chainLength[i])
                    {
                        chainLength[i] = chainLength[j] + 1;
                        previous[i] = j;
                    }
                }
                if (bestEnd < 0 || chainLength[i] > chainLength[bestEnd])
                {
                    bestEnd = i;
                }
            }

            var longest = new Dictionary<int, int>();
            if (bestEnd < 0)
            {
                return longest;
            }
            // Walk the predecessor links backwards, then insert in ascending order.
            int[] positions = new int[chainLength[bestEnd]];
            int cursor = bestEnd;
            for (int slot = positions.Length - 1; slot >= 0; slot--)
            {
                positions[slot] = cursor;
                cursor = previous[cursor];
            }
            foreach (int position in positions)
            {
                longest.Add(position, index[position]);
            }
            Console.WriteLine("longest:" + string.Join(",", longest.Select(p => p.Key + "=" + p.Value)));
            return longest;
        }

        /// <summary>
        /// Scores every transcript in <paramref name="testString"/> against
        /// <paramref name="originalString"/>.
        /// </summary>
        /// <param name="originalString">Reference text.</param>
        /// <param name="testString">Transcripts to score.</param>
        /// <returns>
        /// One array per transcript laid out as [acc, corr, H, N, D, S, I], or null when any
        /// exception is raised while scoring.
        /// </returns>
        public static List<float[]> GetParameters(string originalString, List<string> testString)
        {
            List<float[]> resultList = new List<float[]> { };
            try
            {
                string original = RemovePunctuation(originalString);
                for (int i = 0; i < testString.Count; i++)
                {
                    string test = RemovePunctuation(testString[i]);
                    int[] index = GetInIndex(original, test);
                    Console.WriteLine("index:" + string.Join(",", index));
                    // Dictionary enumeration order is unspecified, so order the pairs explicitly.
                    List<KeyValuePair<int, int>> matches =
                        GetMaxSortedIndexs(index).OrderBy(pair => pair.Key).ToList();

                    int I = 0;
                    int S = 0;
                    int D = 0;

                    Console.WriteLine("在original字符串中正确的字符");
                    Console.WriteLine("[" + originalString + "]");
                    foreach (KeyValuePair<int, int> pair in matches)
                    {
                        Console.WriteLine("index:" + pair.Key + "value:" + original[pair.Key]);
                    }
                    Console.WriteLine("--------------------------------------------------------------");
                    Console.WriteLine("在test字符串中正确的字符\n");
                    Console.WriteLine("[" + testString[i] + "]");
                    foreach (KeyValuePair<int, int> pair in matches)
                    {
                        Console.WriteLine("index:" + pair.Value + "value:" + test[pair.Value]);
                    }
                    Console.WriteLine("--------------------------------------------------------------");
                    Console.WriteLine("计算准确率\n");

                    // Starting from -1 makes "position - previous - 1" count the characters
                    // before the first correct one without special-casing the first iteration.
                    int prevKey = -1;
                    int prevValue = -1;
                    foreach (KeyValuePair<int, int> pair in matches)
                    {
                        int diffKey = pair.Key - prevKey - 1;
                        int diffValue = pair.Value - prevValue - 1;
                        Console.WriteLine("diffKey:" + diffKey + ", diffValue:" + diffValue);
                        // Overlap of the two gaps = substitutions; surplus on the reference
                        // side = deletions; surplus on the transcript side = insertions.
                        S += Math.Min(diffKey, diffValue);
                        D += Math.Max(0, diffKey - diffValue);
                        I += Math.Max(0, diffValue - diffKey);
                        prevKey = pair.Key;
                        prevValue = pair.Value;
                    }
                    Console.WriteLine("tempkey:" + prevKey + ",tempValue:" + prevValue);
                    // Trailing characters after the last correct one.
                    int tailKey = original.Length - prevKey - 1;
                    int tailValue = test.Length - prevValue - 1;
                    Console.WriteLine("diffKey:" + tailKey + ",diffValue:" + tailValue);
                    S += Math.Min(tailKey, tailValue);
                    D += Math.Max(0, tailKey - tailValue);
                    I += Math.Max(0, tailValue - tailKey);

                    int H = matches.Count;
                    // NOTE(review): N is the transcript length here, while H + D + S always
                    // equals the reference length; conventional Corr/Acc scoring takes N from
                    // the reference. Confirm which denominator is intended before changing it.
                    int N = test.Length;
                    float acc = (float)(H - I) / (float)N;
                    float corr = H / (float)N;
                    float[] result = new float[7];
                    result[0] = acc;
                    result[1] = corr;
                    result[2] = H;
                    result[3] = N;
                    result[4] = D;
                    result[5] = S;
                    result[6] = I;
                    Console.WriteLine("acc" + acc + ",corr: " + corr);
                    resultList.Add(result);
                }
            }
            catch (Exception e)
            {
                Console.WriteLine("error accur:" + e.ToString());
                return null;
            }
            return resultList;
        }

        /// <summary>Placeholder accessor; always throws <see cref="NotImplementedException"/>.</summary>
        internal static object GetInstance()
        {
            throw new NotImplementedException();
        }
    }
}