我目前正在写有关在线手写识别的学士论文。这不是OCR,因为我掌握了如何将符号写为笔轨迹坐标(x,y)的列表的信息。我称之为hwrt-手写识别工具包。它有一个文档,我的一个朋友在他的计算机上工作了“第一步”。

但是,这是我第一次编写Python程序包,希望其他人可以使用它。我希望获得有关该项目的一般反馈。

该项目托管在GitHub上,并具有以下结构:


.
├── bin
├── dist
├── docs
├── hwrt
│   ├── misc
│   └── templates
└── tests
    └── symbols



我有一些鼻子测试(还不够,我正在努力)。

bin中的一个文件是view.py。它使用户可以查看以前下载的数据(请参阅我的文档中的“第一步”)。

setup.py

try:
    from setuptools import setup
except ImportError:
    from distutils.core import setup

config = {
    'name': 'hwrt',
    'version': '0.1.125',
    'author': 'Martin Thoma',
    'author_email': 'info@martin-thoma.de',
    'packages': ['hwrt'],
    'scripts': ['bin/backup.py', 'bin/view.py', 'bin/download.py',
                'bin/test.py', 'bin/train.py', 'bin/analyze_data.py',
                'bin/hwrt', 'bin/record.py'],
    'package_data': {'hwrt': ['templates/*', 'misc/*']},
    'url': 'https://github.com/MartinThoma/hwrt',
    'license': 'MIT',
    'description': 'Handwriting Recognition Tools',
    'long_description': """A tookit for handwriting recognition. It was
    developed as part of the bachelors thesis of Martin Thoma.""",
    'install_requires': [
        "argparse",
        "theano",
        "nose",
        "natsort",
        "PyYAML",
        "matplotlib",
        "shapely"
    ],
    'keywords': ['HWRT', 'recognition', 'handwriting', 'on-line'],
    'download_url': 'https://github.com/MartinThoma/hwrt',
    'classifiers': ['Development Status :: 3 - Alpha',
                    'Environment :: Console',
                    'Intended Audience :: Developers',
                    'Intended Audience :: Science/Research',
                    'License :: OSI Approved :: MIT License',
                    'Natural Language :: English',
                    'Programming Language :: Python :: 2.7',
                    'Programming Language :: Python :: 3',
                    'Topic :: Scientific/Engineering :: Artificial Intelligence',
                    'Topic :: Software Development',
                    'Topic :: Utilities'],
    'zip_safe': False,
    'test_suite': 'nose.collector'
}

setup(**config)


view.py

#!/usr/bin/env python
"""
Display a recorded handwritten symbol as well as the preprocessing methods
and the data multiplication steps that get applied.
"""

import sys
import os
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.DEBUG,
                    stream=sys.stdout)
import yaml
try:  # Python 2
    import cPickle as pickle
except ImportError:  # Python 3
    import pickle

# My modules
import hwrt
from hwrt import HandwrittenData
sys.modules['HandwrittenData'] = HandwrittenData
import hwrt.utils as utils
import hwrt.preprocessing as preprocessing
import hwrt.features as features
import hwrt.data_multiplication as data_multiplication


def _fetch_data_from_server(raw_data_id):
    """Get the data from raw_data_id from the server.
    :returns: The ``data`` if fetching worked, ``None`` if it failed."""
    import MySQLdb
    import MySQLdb.cursors

    # Import configuration file
    cfg = utils.get_database_configuration()
    if cfg is None:
        return None

    # Establish database connection
    connection = MySQLdb.connect(host=cfg[args.mysql]['host'],
                                 user=cfg[args.mysql]['user'],
                                 passwd=cfg[args.mysql]['passwd'],
                                 db=cfg[args.mysql]['db'],
                                 cursorclass=MySQLdb.cursors.DictCursor)
    cursor = connection.cursor()

    # Download dataset
    sql = ("SELECT `id`, `data` "
           "FROM `wm_raw_draw_data` WHERE `id`=%i") % raw_data_id
    cursor.execute(sql)
    return cursor.fetchone()


def _get_data_from_rawfile(path_to_data, raw_data_id):
    """Get a HandwrittenData object that has ``raw_data_id`` from a pickle file
       ``path_to_data``.
       :returns: The HandwrittenData object if ``raw_data_id`` is in
                 path_to_data, otherwise ``None``."""
    loaded = pickle.load(open(path_to_data))
    raw_datasets = loaded['handwriting_datasets']
    for raw_dataset in raw_datasets:
        if raw_dataset['handwriting'].raw_data_id == raw_data_id:
            return raw_dataset['handwriting']
    return None


def _list_ids(path_to_data):
    """List raw data IDs grouped by symbol ID from a pickle file
       ``path_to_data``."""
    loaded = pickle.load(open(path_to_data))
    raw_datasets = loaded['handwriting_datasets']
    raw_ids = {}
    for raw_dataset in raw_datasets:
        raw_data_id = raw_dataset['handwriting'].raw_data_id
        if raw_dataset['formula_id'] not in raw_ids:
            raw_ids[raw_dataset['formula_id']] = [raw_data_id]
        else:
            raw_ids[raw_dataset['formula_id']].append(raw_data_id)
    for symbol_id in sorted(raw_ids):
        print("%i: %s" % (symbol_id, sorted(raw_ids[symbol_id])))


def _get_description(prev_description):
    """Get the parsed description file (a dictionary) from another
       parsed description file."""
    current_desc_file = os.path.join(utils.get_project_root(),
                                     prev_description['data-source'],
                                     "info.yml")
    if not os.path.isfile(current_desc_file):
        logging.error("You are probably not in the folder of a model, because "
                      "%s is not a file.", current_desc_file)
        sys.exit(-1)
    with open(current_desc_file, 'r') as ymlfile:
        current_description = yaml.load(ymlfile)
    return current_description


def _get_system(model_folder):
    """Return the preprocessing description, the feature description and the
       model description."""

    # Get model description
    model_description_file = os.path.join(model_folder, "info.yml")
    if not os.path.isfile(model_description_file):
        logging.error("You are probably not in the folder of a model, because "
                      "%s is not a file. (-m argument)",
                      model_description_file)
        sys.exit(-1)
    with open(model_description_file, 'r') as ymlfile:
        model_desc = yaml.load(ymlfile)

    # Get the feature and the preprocessing description
    feature_desc = _get_description(model_desc)
    preprocessing_desc = _get_description(feature_desc)

    return (preprocessing_desc, feature_desc, model_desc)


def display_data(raw_data_string, raw_data_id, model_folder):
    """Print ``raw_data_id`` with the content ``raw_data_string`` after
       applying the preprocessing of ``model_folder`` to it."""
    print("## Raw Data (ID: %i)" % raw_data_id)
    print("```")
    print(raw_data_string)
    print("```")

    preprocessing_desc, feature_desc, _ = _get_system(model_folder)

    # Print model
    print("## Model")
    print("%s\n" % model_folder)

    # Print preprocessing queue
    print("## Preprocessing")
    print("```")
    tmp = preprocessing_desc['queue']
    preprocessing_queue = preprocessing.get_preprocessing_queue(tmp)
    for algorithm in preprocessing_queue:
        print("* " + str(algorithm))
    print("```")

    feature_list = features.get_features(feature_desc['features'])
    input_features = sum(map(lambda n: n.get_dimension(), feature_list))
    print("## Features (%i)" % input_features)
    print("```")
    for algorithm in feature_list:
        print("* %s" % str(algorithm))
    print("```")

    # Get Handwriting
    recording = HandwrittenData.HandwrittenData(raw_data_string,
                                                raw_data_id=raw_data_id)

    # Get the preprocessing queue
    tmp = preprocessing_desc['queue']
    preprocessing_queue = preprocessing.get_preprocessing_queue(tmp)
    recording.preprocessing(preprocessing_queue)

    # Get feature values as list of floats, rounded to 3 decimal places
    tmp = feature_desc['features']
    feature_list = features.get_features(tmp)
    feature_values = recording.feature_extraction(feature_list)
    feature_values = [round(el, 3) for el in feature_values]
    print("Features:")
    print(feature_values)

    # Get the list of data multiplication algorithms
    mult_queue = data_multiplication.get_data_multiplication_queue(
        feature_desc['data-multiplication'])

    # Multiply traing_set
    training_set = [recording]
    for algorithm in mult_queue:
        new_trning_set = []
        for recording in training_set:
            samples = algorithm(recording)
            for sample in samples:
                new_trning_set.append(sample)
        training_set = new_trning_set

    # Display it
    for recording in training_set:
        recording.show()


def get_parser():
    """Return the parser object for this script."""
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
    parser = ArgumentParser(description=__doc__,
                            formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("-i", "--id", dest="id", default=292293,
                        type=int,
                        help="which RAW_DATA_ID do you want?")
    parser.add_argument("--mysql", dest="mysql", default='mysql_online',
                        help="which mysql configuration should be used?")
    parser.add_argument("-m", "--model",
                        dest="model",
                        help="where is the model folder (with a info.yml)?",
                        metavar="FOLDER",
                        type=lambda x: utils.is_valid_folder(parser, x),
                        default=utils.default_model())
    parser.add_argument("-l", "--list",
                        dest="list",
                        help="list all raw data IDs / symbol IDs",
                        action='store_true',
                        default=False)
    parser.add_argument("-s", "--server",
                        dest="server",
                        help="contact the MySQL server",
                        action='store_true',
                        default=False)
    return parser

if __name__ == '__main__':
    args = get_parser().parse_args()
    if args.list:
        preprocessing_desc, _, _ = _get_system(args.model)
        raw_datapath = os.path.join(utils.get_project_root(),
                                    preprocessing_desc['data-source'])
        _list_ids(raw_datapath)
    else:
        if args.server:
            data = _fetch_data_from_server(args.id)
            print("hwrt version: %s" % hwrt.__version__)
            display_data(data['data'], data['id'], args.model)
        else:
            logging.info("RAW_DATA_ID %i does not exist or "
                         "database connection did not work.", args.id)
            # The data was not on the server / the connection to the server did
            # not work. So try it again with the model data
            preprocessing_desc, _, _ = _get_system(args.model)
            raw_datapath = os.path.join(utils.get_project_root(),
                                        preprocessing_desc['data-source'])
            handwriting = _get_data_from_rawfile(raw_datapath, args.id)
            if handwriting is None:
                logging.info("Recording with ID %i was not found in %s",
                             args.id,
                             raw_datapath)
            else:
                print("hwrt version: %s" % hwrt.__version__)
                display_data(handwriting.raw_data_json,
                             handwriting.formula_id,
                             args.model)


正如我写的那样,我想获得有关该项目的一般反馈。但是,我对包装没有经验,所以我复制了setup.py。我特别不确定我选择的zip_safe: False是否正确。

我想我到处都遵循PEP8,我使用pylint来改进我的代码。但是,对于view.py,我不理解以下样式警告/不知道如何解决(以一种很好的方式):




 W:115, 4: Redefining name 'preprocessing_desc' from outer scope (line 218) (redefined-outer-name)
W:128, 4: Redefining name 'preprocessing_desc' from outer scope (line 218) (redefined-outer-name)
R:120, 0: Too many local variables (19/15) (too-many-locals)
C:216, 4: Invalid constant name "args" (invalid-name)
C:218, 8: Invalid constant name "preprocessing_desc" (invalid-name)
C:219, 8: Invalid constant name "raw_datapath" (invalid-name)
C:224,12: Invalid constant name "data" (invalid-name)
C:232,12: Invalid constant name "preprocessing_desc" (invalid-name)
C:233,12: Invalid constant name "raw_datapath" (invalid-name)
C:235,12: Invalid constant name "handwriting" (invalid-name)
 



评论

您是否阅读过pylint-messages.wikidot.com

@jonrsharpe:是的,我看过该页面。但是,在“从外部作用域重新定义名称%s”的形式中,它们没有Wiki条目。而且我知道其他警告的出处,但我不知道解决这些警告的好方法。例如,我可以将args更改为ARGS,并且应该将其修复。但是,我从未见过。另外,optparse文档也将其称为args,所以我认为我不应该将其命名为ARGS。

我想如果将__name__ =='__main__':的代码从下面移到函数中,大多数警告都会消失。

@JanneKarila:哦,你是对的!谢谢!

您还可以通过注释禁用特定的警告,例如#pylint:disable =重新定义外部名称

#1 楼

总的来说,我认为这是一个很好的质量,如果代码有点密集,我
表示由于那里有很多功能,所以它不是最容易遵循的代码。我的一个建议是根据主题更多地分组
。所以例如该实用程序可能只有几个
节,用于文件,字符串格式化,调用外部程序。
某种
包装器对象可以处理项目,或者也可以正常工作。

(坦白说,我现在可以只提交拉取请求,这很容易,呵呵。) br />代码

运行测试时,我发现从
open导入future.builtins不起作用(Python 2.7.9),因为没有这样的
模块/ future_builtins也没有open,这意味着
无法安装nntoolkit,并且serve.py:19也会引发错误。我
是什么原因造成的,因为您已经安装了Travis CI,所以我会看到
是否可以找到导致此问题的根本原因。
IMO pickle不是长期数据文件的最佳格式;但是,在这一点上,这对我来说,如果它对您有用,则比起
,为什么不行(尽管您已经有至少一种解决方法,
sys.modules部分,因此请牢记这一点。

为了提高速度,可以使用ujson替代json。 )
。我还希望有一个全局标志来禁用颜色,并且
使用库进行格式化会很好(我看到了colorterm和
termcolor;可能还有其他)。

对于诸如data_analyzation_metrics.py:119之类的东西来说,则不需要"%s" % str(x)

您已经在某些地方进行了此操作,因此,我建议您始终使用
str(如果可能)。

with open("foo") as file:代替self.__repr__()看起来更干净。

repr(self)中,应提取因子features.py:1742
例如像3左右的东西;在
中一般提取常见子表达式(甚至draw_width = 3 if self.pen_down else 2)也可以消除很多代码,因此在这里我不再寻找其他示例。

一般来说,如果您有len(x) ,则不需要if foo: return,只需
删除该缩进即可;也可以早点返回可以消除许多压痕。

对于else,整个方法可以简化为:

def __eq__(self, other):
    return isinstance(other, self.__class__) \
        and self.__dict__ == other.__dict__


HandwrittenData.py:208来自euclidean_distance的也被定义为
preprocessing.py:30,因此,如果在某些情况下,您运行
超过了几个元素,则可以考虑使用它。 br /> scipy.spatial.distance.euclidean

preprocessing.py:497中的for index, point in enumerate(pointlist):应该已经存在于某处...?在可能的情况下使用较少的内存,即仅使用which循环迭代而不是存储结果列表。

包装

selfcheck.py具有一个版本数字,但Git存储库没有
相应的标签/版本。如果您现在开始并添加例如zip是第一个发行版(或类似版本),因此可以更轻松地引用特定的
版本,即从安装脚本中引用。

itertools.izip带有换行符,即在某些情况下可能看起来很奇怪。

关键字看起来不错,除了我怀疑包含for
是否已经以这种方式命名该软件包,以及setup.py可能不需要连字符(需要连字符,请参见注释)。

分类器很好;您还可以列出
Python 3的更多特定版本。

0.1.207没有版本要求。我想说
至少有一些下限(例如,您当前安装的软件包)
将很有用。当然,您可能不完全知道要选择哪个版本,但是对于尝试使其运行的人来说,
还是有帮助的。

我不完全知道外部需求的过程,但是您
只需要指出ImageMagick是一个依赖项即可。

您还修复了PEP8内容,因此现在只有几行太长了剩下;您还可以将long_description添加为预提交钩子,因此
如果不进行所有固定操作就无法签入;我至少将它用于
库代码。 install_requires也是如此;我可能会禁用其中的一些
(并将其添加到Makefile中或作为预提交
钩子再次添加)。

测试

太好了!有点重复,例如requirements.txt被实施了3次。如果可能的话,我会把它移到utils包中(只是另一个),以使其摆脱干扰。

未来的想法

好吧,我喜欢PostgreSQL,所以我认为这会在某个时候出现。如果
没有紧迫的理由专门使用MySQL,那么使用
独立于数据库的库会很酷。

评论


\ $ \ begingroup \ $
哇,太棒了!感谢您的详细反馈!我将花费一些时间来完成此过程,我可能会回来在评论中提出问题,但我只是想让您知道我非常喜欢您的反馈。
\ $ \ endgroup \ $
–马丁·托马
2014年12月16日在21:14

\ $ \ begingroup \ $
很乐意提供帮助。不幸的是,高级架构建议要困难得多,我跳过了这一点,因此也许有人会为此添加些东西。
\ $ \ endgroup \ $
–ferada
2014年12月16日在21:18

\ $ \ begingroup \ $
“'在线'可能不需要连字符”-它确实需要连字符。这是一个非常特定的术语。在这种情况下,在线并不意味着像网络,而是与OCR相反。 (我刚刚通过分解测试助手来改善了测试。我还添加了第一个git标签:-))
\ $ \ endgroup \ $
–马丁·托马
2014-12-16 22:05