1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
| import argparse import torch import os import sys
def _resolve_bias_key(state_dict, weight_key):
    """Return the key of the bias to negate together with ``weight_key``.

    The sibling entry ``<prefix>.bias`` is preferred.  When it is absent, the
    first 1-D tensor of length 1 (other than the weight itself) is used as a
    best-effort guess, but only if the weight has a single output row; a
    length-1 vector cannot be the bias of a layer with several outputs.

    Returns:
        The bias key, or ``None`` when no suitable tensor exists.
    """
    sibling = weight_key.rsplit('.', 1)[0] + '.bias'
    if sibling in state_dict:
        return sibling

    weight = state_dict[weight_key]
    if weight.ndim == 0 or weight.shape[0] != 1:
        return None

    for key, value in state_dict.items():
        if key == weight_key:
            # A 1-D weight of length 1 must not be returned as its own bias,
            # otherwise the caller would negate it twice.
            continue
        if isinstance(value, torch.Tensor) and value.ndim == 1 and value.shape[0] == 1:
            return key
    return None


def find_candidate_last_linear(state_dict):
    """Locate the weight (and bias) key of the final linear layer.

    Strategy:
      1. Collect every 2-D tensor whose first dimension is 1 (the typical
         ``[1, hidden]`` weight of a binary output layer) and take the LAST
         one in iteration order.
      2. If there is none, fall back to name heuristics: keys containing
         ``fc.weight``, ``classifier.weight``, ``out.weight``,
         ``linear.weight`` or ``dense.weight`` (tried in that order), again
         taking the last match for the first hint that matches.

    Args:
        state_dict: Mapping from parameter names to values; non-tensor values
            are ignored.

    Returns:
        ``(weight_key, bias_key)``; ``bias_key`` is ``None`` when no bias can
        be paired with the weight, and both are ``None`` when no candidate
        weight is found.
    """
    row_vector_keys = [
        key for key, value in state_dict.items()
        if isinstance(value, torch.Tensor) and value.ndim == 2 and value.shape[0] == 1
    ]
    if row_vector_keys:
        # Take the last candidate, matching the name-hint branch below.
        # NOTE(review): assumes the checkpoint lists layers in forward order,
        # so that the output layer comes last -- confirm for unusual models.
        weight_key = row_vector_keys[-1]
        return weight_key, _resolve_bias_key(state_dict, weight_key)

    name_hints = ('fc.weight', 'classifier.weight', 'out.weight', 'linear.weight', 'dense.weight')
    for hint in name_hints:
        matches = [
            key for key, value in state_dict.items()
            if hint in key and isinstance(value, torch.Tensor)
        ]
        if matches:
            weight_key = matches[-1]
            return weight_key, _resolve_bias_key(state_dict, weight_key)

    return None, None
def flip_last_layer(input_path, output_path, force=False):
    """Negate the auto-detected final linear layer of a saved model and save it.

    The file at ``input_path`` is loaded on CPU.  If the loaded object is not
    a dict but exposes ``state_dict()``, that method's result is used.  The
    weight (and bias, when one is found) reported by
    ``find_candidate_last_linear`` are negated in a copy of the state_dict,
    which is then written to ``output_path``.

    Args:
        input_path: Path of the checkpoint to read.
        output_path: Path the flipped state_dict is written to.
        force: Accepted for interface compatibility; not used by this function.

    Returns:
        ``True`` after the flipped state_dict has been saved; ``False`` when
        the input file is missing, the loaded object is unusable, or no
        candidate layer could be located.
    """
    if not os.path.exists(input_path):
        print(f"[ERROR] 输入模型文件不存在: {input_path}")
        return False

    # NOTE: torch.load is pickle-based; only feed it checkpoints you trust.
    sd = torch.load(input_path, map_location='cpu')
    if not isinstance(sd, dict):
        print("[WARN] 加载到的对象不是 state_dict(dict)。如果这是整个 model 对象,尝试保存 state_dict 而不是整个模型。")
        if not hasattr(sd, 'state_dict'):
            print("[ERROR] 既不是 state_dict 也不是 model 对象,无法处理。")
            return False
        sd = sd.state_dict()

    wkey, bkey = find_candidate_last_linear(sd)
    if wkey is None:
        guidance = (
            "[WARN] 未能自动定位最后一层线性权重(没有找到形状为 [1, hidden] 的 weight,也未匹配常见名称)。",
            "可选操作:",
            " 1) 手动指定要翻转的权重 key(使用 --weight-key 与 --bias-key)。",
            " 2) 查看 state_dict.keys() 并选择合适的 key。",
            "\nstate_dict keys preview:",
        )
        for message in guidance:
            print(message)
        # Show at most the first 50 entries so the user can pick a key by hand.
        for i, k in enumerate(list(sd)[:50]):
            print(f" {i+1:02d}. {k} shape={tuple(sd[k].shape) if isinstance(sd[k], torch.Tensor) else type(sd[k])}")
        return False

    print(f"[INFO] 自动定位到 weight: '{wkey}' bias: '{bkey}'")

    # Work on cloned tensors so the loaded state_dict itself is left untouched.
    new_sd = {
        name: (value.clone() if isinstance(value, torch.Tensor) else value)
        for name, value in sd.items()
    }

    new_sd[wkey] = -new_sd[wkey]
    if not bkey:
        print(f"[OK] 已将 {wkey} 取负。未找到对应 bias (将仅翻转 weight)。")
    else:
        new_sd[bkey] = -new_sd[bkey]
        print(f"[OK] 已将 {wkey} 与 {bkey} 取负。")

    torch.save(new_sd, output_path)
    print(f"[OK] 已保存翻转后模型为: {output_path} ({os.path.getsize(output_path)/1024:.1f} KB)")
    return True
def quick_self_test(original_path, flipped_path):
    """Compare the original and flipped checkpoints on an all-zero input.

    Requires ``src.parameters.Parameters`` and ``src.model.TextClassifier`` to
    be importable; otherwise the check is skipped with a message.  Both
    checkpoints are loaded into fresh ``TextClassifier`` instances, a
    ``(1, seq_len)`` tensor of zeros (dtype long) is fed to each, and one
    sample value per output is printed together with their sum (a sum close
    to 1 indicates the flip worked).

    Args:
        original_path: Path of the unmodified checkpoint.
        flipped_path: Path of the flipped checkpoint.

    Returns:
        None. All results are reported via ``print``.
    """
    try:
        from src.parameters import Parameters
        from src.model import TextClassifier
    except Exception as exc:
        # Best-effort check: any import problem simply skips the self test.
        print("[SKIP TEST] 未能导入 TextClassifier/Parameters(", exc, "),跳过自检。")
        return

    def unwrap(loaded):
        # A pickled model object is reduced to its state_dict; dicts pass through.
        if not isinstance(loaded, dict) and hasattr(loaded, 'state_dict'):
            return loaded.state_dict()
        return loaded

    def to_scalar(x):
        # Reduce a model output to one float sample; None for non-tensors or empty tensors.
        if not isinstance(x, torch.Tensor):
            return None
        x = x.detach()
        if x.numel() == 0:
            return None
        x = x.squeeze()
        if x.dim() == 0:
            return float(x.item())
        return float(x.view(-1)[0].item())

    params = Parameters()
    model_a = TextClassifier(params)
    model_b = TextClassifier(params)

    loaded_a = torch.load(original_path, map_location='cpu')
    loaded_b = torch.load(flipped_path, map_location='cpu')
    sd_a = unwrap(loaded_a)
    sd_b = unwrap(loaded_b)

    model_a.load_state_dict(sd_a)
    model_b.load_state_dict(sd_b)
    model_a.eval()
    model_b.eval()

    seq_len = getattr(params, 'seq_len', None)
    if seq_len is None:
        print("[SKIP TEST] Parameters 未包含 seq_len,无法构造输入进行自检。")
        return

    inp = torch.zeros((1, seq_len), dtype=torch.long)
    with torch.no_grad():
        out_a = model_a(inp)
        out_b = model_b(inp)

    va = to_scalar(out_a)
    vb = to_scalar(out_b)
    print(f"[SELF TEST] 原模型输出 (sample): {va}")
    print(f"[SELF TEST] 翻转后模型输出 (sample): {vb}")
    if va is None or vb is None:
        return
    print(f"[SELF TEST] va + vb = {va + vb:.6f} (接近 1 则表示翻转成功)")
def main():
    """Command-line entry point for flipping a model's final linear layer.

    Two modes:
      * Manual (``--weight-key`` given): the named weight, and the bias named
        by ``--bias-key`` when given, are negated in a copy of the loaded
        state_dict and saved to ``--output``.  The process then exits with
        status 0.
      * Automatic (no ``--weight-key``): delegates to ``flip_last_layer``.

    With ``--test``, ``quick_self_test`` is run after a successful flip.

    Exits with status 2 when the input file is missing, the loaded object is
    neither a dict nor an object with ``state_dict()``, a requested key does
    not exist, or automatic detection fails.
    """
    parser = argparse.ArgumentParser(
        description="Flip final linear layer (weight & bias) in a PyTorch state_dict to invert outputs (approx 1 - y)."
    )
    parser.add_argument("--input", "-i", required=True, help="输入的 state_dict 文件 (pth)")
    parser.add_argument("--output", "-o", default="flipped_model.pth", help="输出文件名")
    parser.add_argument("--weight-key", help="手动指定要翻转的 weight key(可选)")
    parser.add_argument("--bias-key", help="手动指定要翻转的 bias key(可选)")
    parser.add_argument("--test", action="store_true", help="尝试对原/翻转模型做快速自检(需能导入 src.model)")
    args = parser.parse_args()

    if args.weight_key:
        # Fail with a readable message, as the automatic path does, instead of
        # letting torch.load raise on a missing file.
        if not os.path.exists(args.input):
            print(f"[ERROR] 输入模型文件不存在: {args.input}")
            sys.exit(2)

        # NOTE: torch.load is pickle-based; only feed it checkpoints you trust.
        sd = torch.load(args.input, map_location='cpu')
        if not isinstance(sd, dict):
            if not hasattr(sd, 'state_dict'):
                # Without this guard the membership test below raises TypeError.
                print("[ERROR] 既不是 state_dict 也不是 model 对象,无法处理。")
                sys.exit(2)
            sd = sd.state_dict()

        if args.weight_key not in sd:
            print(f"[ERROR] 指定的 weight key 不存在: {args.weight_key}")
            print("state_dict keys preview:")
            for k in list(sd)[:50]:
                print(" ", k)
            sys.exit(2)
        if args.bias_key and args.bias_key not in sd:
            print(f"[ERROR] 指定的 bias key 不存在: {args.bias_key}")
            sys.exit(2)

        # Clone tensors so the loaded objects are never modified in place.
        new_sd = {
            k: (v.clone() if isinstance(v, torch.Tensor) else v)
            for k, v in sd.items()
        }
        new_sd[args.weight_key] = -new_sd[args.weight_key]
        if args.bias_key:
            new_sd[args.bias_key] = -new_sd[args.bias_key]
        torch.save(new_sd, args.output)
        print(f"[OK] 翻转并保存到 {args.output}")
        if args.test:
            quick_self_test(args.input, args.output)
        sys.exit(0)

    if args.bias_key:
        # --bias-key only takes effect in manual mode; tell the user rather
        # than silently dropping it.
        print("[WARN] 未指定 --weight-key,--bias-key 将被忽略,改为自动定位。")

    ok = flip_last_layer(args.input, args.output)
    if not ok:
        print("[FAILED] 自动翻转失败,请手动指定 --weight-key 和 --bias-key(参考 state_dict.keys())")
        sys.exit(2)

    if args.test:
        quick_self_test(args.input, args.output)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|