使用backtrader选股(单均线金叉选股)

数据源: akshare 沪深300成分股票

策略:class StockScreener(bt.Strategy):

1. params 定均线周期天数
2. def _init_(self):
  • 计算平均线,因为有多个股票数据源(self.datas 列表)),那么均线变量也是列表, 使用列表推导式(List Comprehension)来创建一个列表 ,
    均线 =[均线函数(每个股票收盘价, 周期) for 每个股票 in 数据集 ]
    self.smas = [bt.indicators.SimpleMovingAverage(d.close, period=self.params.period) for d in self.datas]
  • 交叉信号,也是根据数据集, 得出交叉信号列表
    方法一:
self.crossovers = [bt.ind.CrossOver(d.close, self.smas[self.datas.index(d)]) for d in self.datas]

然后可以用:

for i in range(len(self.datas)):
    if self.crossovers[i].value > 0:
        # 执行某个操作
        pass

方法二
也可以使用 enumerate 函数来遍历 self.datas 列表,并为每个元素提供一个索引 i 和对应的值 d。
for i, d in enumerate(self.datas)
使用 enumerate 是为了同时获取数据源(self.datas)的索引和对应的元素。具体来说,enumerate 会生成一个索引-值对,这样可以在遍历数据源的同时获取每个数据源的索引。

self.crossovers = [bt.ind.CrossOver(d.close, self.smas[i]) for i, d in enumerate(self.datas)]

然后可以用:

for i, d in enumerate(self.datas):
        if self.crossovers[i] > 0:
        # 执行某个操作
        pass
3. def next(self):

利用self.crossovers大于0或小于0,判断执行加入选股集
self.selected_stocks.add(d._name)
或移除
self.selected_stocks.discard(d._name)

完整代码:

# run_screener 函数的原理是通过 backtrader 框架加载股票数据,运行策略,
# 沪深300个股票
# 并根据策略的逻辑筛选出符合条件的股票代码。
import backtrader as bt
import akshare as ak
import pandas as pd
from datetime import datetime, timedelta

class StockScreener(bt.Strategy):
    params = (
        ('period', 20),
    )

    def __init__(self):
        # 使用列表推导式(List Comprehension)来创建一个列表
        # for d in self.datas:这是一个循环,遍历 self.datas 列表中的每一个数据源 d。
        # self.datas 通常包含多个数据源,每个数据源代表一个股票或资产的历史数据。
        # 对于每一个数据源 d,创建一个简单移动平均线指标。d.close 表示该数据源的收盘价数据,
        self.smas = [bt.indicators.SimpleMovingAverage(d.close, period=self.params.period) for d in self.datas]
        self.selected_stocks = set()
        self.crossovers = [bt.ind.CrossOver(d.close, self.smas[i]) for i, d in enumerate(self.datas)]  # crossover signals

    def next(self):
        # 使用 enumerate 函数来遍历 self.datas 列表,并为每个元素提供一个索引 i 和对应的值 d。
        # self.datas 是一个包含多个数据源的列表,每个数据源代表一个股票或资产的历史数据。
        for i, d in enumerate(self.datas):
            if self.crossovers[i] > 0:
                print('加入: ', d._name)
                self.selected_stocks.add(d._name)
            elif self.crossovers[i] < 0:
                print('移除: ', d._name)
                self.selected_stocks.discard(d._name)
# ========获取300代码===============================
def get_hs300_stocks():
    # 获取沪深300指数的成分股
    hs300_df = ak.index_stock_cons(symbol="000300")
    hs300_stocks = hs300_df['品种代码'] # .tolist()
    hs300_stocks=hs300_stocks.tolist()
    print('完成获取300股票代码')
    return hs300_stocks

# ========数据清洗=====================================
def fetch_data(stock_code, start_date, end_date):
    stock_data = ak.stock_zh_a_hist(symbol=stock_code, start_date=start_date, end_date=end_date, adjust="qfq")
    stock_data.index = pd.to_datetime(stock_data['日期'])
    stock_data = stock_data[['开盘', '最高', '最低', '收盘', '成交量']]
    stock_data.columns = ['open', 'high', 'low', 'close', 'volume']
    return stock_data
# ========遍历300个股票,分别策略分析==========
def run_screener(stock_codes, start_date, end_date):
    cerebro = bt.Cerebro()
    print('加入数据源')
    for stock_code in stock_codes:
        data = fetch_data(stock_code, start_date, end_date)
        
        if len(data) < 20:
            print(f"Warning: Data for {stock_code} is too short. Skipping.")
            continue
        data = bt.feeds.PandasData(dataname=data)
        
        cerebro.adddata(data, name=stock_code)
    cerebro.addstrategy(StockScreener)
    results = cerebro.run()
    print(results)
    
    return results[0].selected_stocks

# ========回测天数,返回开始日期和今天==========
def get_date_range(days_before):
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days_before)
    
    start_date_str = start_date.strftime('%Y%m%d')
    end_date_str = end_date.strftime('%Y%m%d')
    
    return start_date_str, end_date_str

# ======================================================================================
# ======================================================================================
if __name__ == '__main__':
    stock_codes = get_hs300_stocks()
    
    start_date, end_date = get_date_range(50)
    
    selected_stocks = run_screener(stock_codes, start_date, end_date)
    print("Selected Stocks:", selected_stocks)
    
   

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容