From eb72d29a15a24cc1b68f66161293299cf06f0cc3 Mon Sep 17 00:00:00 2001 From: Jan Buethe Date: Mon, 24 Jul 2023 17:13:49 -0700 Subject: [PATCH] Support for dumping LinearLayer in weight-exchange --- dnn/torch/rdovae/export_rdovae_weights.py | 43 ++- .../libs/wexchange-1.0-py3-none-any.whl | Bin 7153 -> 0 bytes .../libs/wexchange-1.2-py3-none-any.whl | Bin 7794 -> 0 bytes dnn/torch/rdovae/requirements.txt | 3 +- dnn/torch/weight-exchange/setup.py | 2 +- .../wexchange/c_export/c_writer.py | 6 +- .../wexchange/c_export/common.py | 323 +++++++++--------- dnn/torch/weight-exchange/wexchange/tf/tf.py | 22 +- .../weight-exchange/wexchange/torch/torch.py | 16 +- 9 files changed, 214 insertions(+), 201 deletions(-) delete mode 100644 dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl delete mode 100644 dnn/torch/rdovae/libs/wexchange-1.2-py3-none-any.whl diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index b6fbaa4b..f9c1db81 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -29,6 +29,9 @@ import os import argparse +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange')) parser = argparse.ArgumentParser() @@ -83,20 +86,30 @@ def c_export(args, model): message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" - enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message) - dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message) - stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message) - constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True) + enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message, model_struct_name='RDOVAEEnc') + dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec') + stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False) + constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False) # some custom includes - for writer in [enc_writer, dec_writer, stats_writer]: + for writer in [enc_writer, dec_writer]: writer.header.write( f""" #include "opus_types.h" +#include "dred_rdovae.h" + +#include "dred_rdovae_constants.h" + +""" + ) + + stats_writer.header.write( +f""" +#include "opus_types.h" + #include "dred_rdovae_constants.h" -#include "nnet.h" """ ) @@ -111,9 +124,9 @@ f""" ('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH') ] - for name, export_name, activation in encoder_dense_layers: + for name, export_name, _ in encoder_dense_layers: layer = model.get_submodule(name) - dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True) + dump_torch_weights(enc_writer, layer, name=export_name, verbose=True) encoder_gru_layers = [ @@ -122,15 +135,15 @@ f""" ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH') ] - enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True, input_sparse=True, dotp=True) - for name, export_name, activation in encoder_gru_layers]) + enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True) + for name, export_name, _ in encoder_gru_layers]) encoder_conv_layers = [ ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR') ] - enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_conv_layers]) + enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in encoder_conv_layers]) del enc_writer @@ -148,9 +161,9 @@ f""" ('core_decoder.module.output' , 'dec_final', 'LINEAR') ] - for name, export_name, activation in decoder_dense_layers: + for name, export_name, _ in decoder_dense_layers: layer = model.get_submodule(name) - dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True) + dump_torch_weights(dec_writer, layer, name=export_name, verbose=True) decoder_gru_layers = [ @@ -159,8 +172,8 @@ f""" ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH') ] - dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True, input_sparse=True, dotp=True) - for name, export_name, activation in decoder_gru_layers]) + dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True) + for name, export_name, _ in decoder_gru_layers]) del dec_writer diff --git a/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl deleted file mode 100644 index cfeebae5bafd31e7e674e73120d45c76ab177d3a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7153 zcmaKx1yoeu+Qx?%V(6hkNn*IS!b=;zxD2O_WPV??^gwaf=UDc05AZA-WJe6X`cJUH&-M#?dRFm z)Xmt!(9YbHOJCp8&JwP#&k1vfu=NZ;2cLE!IJ&r4Rb``;O94Cpm47(*11+HAO>z~A zH~n318uf1+V|`OMn7t$X{{taWTwFcuavW;^3x?sMJ2ec`R-HC9oVs&21mqx^VonZo zhELKudPeV36^hAQ={c|#t2p}TB^!c@wY04q<$MTk1)}b60sVMJx8r~W0DMCIkAWEL zyEne6#%xPoF-2?2Q@2U zOEo5jZPBa7YcU%g+dCGi4EamrZM`7`xP4B@K2R+bV+qX9Ul2KOu?mBvJ_&TywLN&R zD}e2a8&9l5Tp@0KN|CsuV~B*a55gId_ujBLQqCpP>hvq(rpeY-ucnNd z!^E(aMKXcxW96;}r?pnpESF_x&QSBzyEUMj)pU36OM915zX0Y_ye>*BKEnn_p@zvp}``!xzuJvrN#YCNjSvFDIELrlx@qSDSeoS7Dh9UR| zJyHn65xT_uk~TPIVAre4G+SA6@t%j7v$=fJ#E#?Iaq|`S_W3IrO?+)ub(Thu|7PyG zaCB==Sp}zZMcmZF-la5`g;kOUt2FdfEpVZdUS{j{z6?&*Jv!f4tIXVIimY*$=3wGQ z#AvA`Xj|Uq>rU&qug@08qID0Ni|U&AxwTjj_G0t-amNqw3T0 zw4ahExz1E}NU#b@7h+h_u#$?8zXJn94AiC>`j|q$w?@AMrdm_(gTGy7@Gc4r#*Aq+ z#>=lb^xOA)_P^68sl7Ob%!!ITp^Y>V6teM}{hVpMc2>R(-%tl+DLBGWHxTYe&7iFm zo6TU=xmm8e6UVK4ofR15@2Fo~-6i8`aoq7x+C?UOoD>8SX&9nELAw6Vbyp(!)Yx?) z>Q0+VD<&C2o7*WjG}4gz1mYSIxHc5F8$#C)-+0_GixS~wzX^US%iJMgtQw0|rw`^p z^^NJ`+y5-vfI*L;wUqVt#m3AzWGcLt5qsd(48c>_G>&1o!Hd`F-Og;F7x$tHqJ)jb zyGaP;WqZIN#fyh1nW~}QF_f_8>~pN!_qC2CR5MDR1QEG zwVPO9kxN|7^h*lR1Y7t#H9d5#+t+dItOhrp&nBIn2H}u4E9vrbs6_IYQPdS>{@BW} zBtT^P7LzzmSb1qauoF9Xv9wA6J|aBaHkFZ)5@}P%(Vbui!gI&g>zx;$AxR z>}>F>O+5;tY3avo=3TQg`;Po(qiA+$i_&K)*^UxEqdmYtOUwx!8P~mxvT8H^*0iZj z5gly%)yi4Dxuo{HpTyk9tml@&IO9Q6>${Qz9^Sd7Qyv2Vqy$Y8r3Z0pkGAmByC2wo zFBq?q%eNVoAk6Wj0=|Wm<}dpB)!0+8=A{ULl7WUCTh=t|Fq$bV%||kqdCt8{eG=Sj zIU3(a-!t&1v~OhU)Y%=|;`Dnq^9wZOCu@2@s$0a~m^h=eDokC5Mwb0|Lfyw3dXVfOI+1Z&33>igIaD*Qa7 z?#xUd-@-lz_@G&x%ilXS_J{yNUP*?emOB&Ze$Xo#vx*NFPCXY#oi*b3MY@M-uUDji z?2C3Zw95jc?~xK8e!nszZ;D__vMVFSNwgt4_2*L+iy+!fBWJo0^hmCboif=FR>gVa z!}s`MJNu{7=wUaw5jqx=fE;k{@%D?SVZ>+m0P_sgqbMYzLQYI_duao9c;QR!>}uZ#FX3|DWgPM{af%68d1FRT zP0>b-L%1pCTjpp5su73RqK3{#%1ImNbAHo|nmA1qk@Kyse)BT;#?{P6`m ztxo%FPYx{XZ3z-nP~4-Fc~%%X7KVI!R64mFF)xyMGHg~N8Cj~FUJ98sy0nU|POF87ds7ADEJU*b^ey@HNa2lH*c>>t&SaRTa4VEX2s3OL@$e^+? zAZN=TxE!UqoO^QBa$LmXef=f+xT3$o^Rr@ie)nFM*tKTQuCoj-wic0ylvB*badroR z3HcQer|z2|drx(Ng55p=g>?-jM_@p=b==K;$d+i~6Nr;Mw9+|0rVMZD&tx#6`&d>f z3wn3=*Z}34$%T=GP39HFOpaDy!1gDjqmSt-K~~3`bai)Z^=jQvL|2S z%HavhmB&Q)#GIY|=+e>toUZaR!=`fv196$K!rX5~xsOi6@UpdikdW9z(8DB`M$(+9 z0Ww<)ndHWG0*O8jut|@ZFSRP^l!p2j=V!WldZY$ZZH}wQav6nmxGI*8tYUq+f{)pc zK}m}J7jwGH($)hk#l=|F$uDi&+Ro=smG!Q=>TQ~GUfyeb-ic`MHQi{c5DTBqG8;x> zu_hogM<*`|^Zuv-ng{27gdKD$;R|w2%kDh4>2#q(sGqD49ojhErK3iJ;WP_-yj|9< z$P_9aEc&E%k43@kN~=BAExLDEO+X!k*Kq#5+bS!MJ%#gFwiZSSrNoC0d9z-gVZIes z{LV?0Wc89UdFexATSxa05$bxACIWgro@E1hJ8B4Hf{}MDmQk%?EkXH$u=s}s`!YjO z09(meqI?jlJa&Aw2m2+xQ|HOO?Zb-dGe5fT9txrA`UXWf4DJs)NZSX8uS-iIuvOx% z%HXaHhsi(={`pd^h-e6Do*sYa4kD2Xyz+>=+Ouafh4%xCW$~SY6vjUEsvj})lx+Kk zhih)sku(-T`MI#;5VCGnGrtAT%8@lZKwwh zRY7$LlZI0;4a%F`WG~%}8GEnS-EB*AXIoQikf6bVD1r+q24Ct%*-9~ugWd7$m^tBR z;1IzfEbq#W2Ax(@iOOZM{MNG^WPTk~ zBuN4nv&qUl(=Orro9X}o%@2JwhkIx&JKnzGc?-G&B&w1mIO$>BB4@d~>^;}#IHvri zX<;MLd@vaWukeDN^VT^e8o(SnS$gZ^7r=7rvE_+Y-Rl!WlXN8t|0G&pqFMKW(WzOc zqq;jlxrZEQFcmK^>FS3WhhCnc_NPML#);trWR>l59WR?}QbkV)n9td3Ak}pwHTSF` zl&5OE&5O9jjGuf)DBpCvx9cRzYpTV1wtjifc0Y{WzU5F2>g#@MbuoqI2VqSZ|Ci*>64xZg$&OJ0rwFe%#nnL(Olxkfe5{q!o3r3RrAM zzc)H$zp%O9W#TtX;*(fmRYX!Q6F2^n_l!3<)p4r!6|5_u4sU2)Ha~Xk!=}+fJie3M zk5+|kIjc9znQIb(v%9xU_J)onblPOAk{1O4(8dPQH($E_bMek-#B+=+e+$r0l(>S1)Ct%?r^t zO!&er-X?+3rp`_4mzNqbQb;8UZv7mNfird#zr2N#l9j(#w0#P{-b@`AEGtJfyT2OK zx)W{v6ew>q_?EQB&O9YkOvtjmK0a$jSok!-)h$F12w$QHmm|nD{KiNROaq*dD~G;P zof$+KaSmeTU_Y`%A9(hOx7MBMyVFD6OI2-L2fd7K(vl|*XHRwd_N;h(KX{a-$Ds)Z z0!;ncj~QnP(8+Sc7Ov;wpmE6w$AK3}>O0jJ+|G}lHVG4@XbqYG;EI(sj%Db}e61bG z8xhO1d3R)X__qQb+6))IHV1{aUY83(U-$kEz3uf0E6SGp8dCpcYTJa@@% z5nU-6Yg39+?_w$p(fUq;R5QR{MUfpj(PU_mMq?O+@?`{@jl*h$`(wmwbu^X)D-XLe zhnCjFGs%|Cg@y)CyoXMBeIjsu`i4B^bvT8hZ$7Trh)AaW!r*gEtpV-zWRV1?C|HN4ugPv|~1@w#f3Wkg{fc1=c5 z2g|Fb1Bh%mo|No$Hq-hOKB1)#AbpK=+?UsHL|78%FrYuRr6|Zbg z(KN$qU(w)^*o6~o)8&01LRR?3c4^_@8!X>fpbtpDr+jf4zl{D(YH8i1Jmv2xX76Zh zaXagS5ZXVpUV6w^a|jqZ7aq>P4k%!uf}aP_GNYZJbMAaWCiUqE|DP27D=zrY-x~pF*6LN7;U@aJ$~H5JK$TXVyld)@?=2e#%-Oo{7?T>- z+Yz<-!vF47bKQX>bu7Ac1(})1@yWAqp4(ymdE|=^joPYYPx;M76gKc`WUI|z0CSQ} zKc(fbVJI60tW(bnhTNl4|D<=h!hpV=6+xJC?x}Zzelf!-B?g2iN<5(%$aU zfa#qcFgu4~lDTB$s^AwGxD*q}tn{(blxPzRtDELBAPA5ghz&FCUyed{rxx1rrHT%P zVK}uhix!4*ozS)p8_Hxdbd4A?R$pw=h&B=|2`aeMQJMvLjX@_Ti5b?#V;;&5R-vN4 zbdSnf9`ut#i=^eb8=&&qm_p)tbFnydauhG5X|Cp5g^(Wc)X)QmLmsO1hTIW{!5>%q zKc>gUc@|mvn@jCdE)e2k+lyNF-B(SoXt8Eu?#~yTAeoS+w6qZtGz?HTdFuInmEyOz zSOHyN2xIoaGT!-;Wj}D4b_u`E=j_X7$lpj4jN2b^pTP4k30jHbgLJNCYL21E=8)bF)R#+rSF-9(oO{J$`|2 z_T^kV{Ix4lRT~s4RIbAvDO&Nn>Px9q2)DLF5e3|M_vjXD@+O*jXk_=>-O$cEo7p^R zc6pP9K;udBVn#w!WRi%dGHmyP+f?f55(#^#W1XFg*5V~g%U;)p>la$pW)7DG7uw`I z_DVArxSTfgJ?8W|;5;HZ^6ipLc^yKqCtz=aw7-h&qS$S7*v*z0f%UBsfv}9{V@-CGD?d>8;Cy;8X=G(+kVIYPDGz?Xt`AM~!ub&zpum5C_(=9OmGc zhd)^-Z?=FYn%C$*pZNNPzC~*yOKu}rrIkK0prIs68&UR{|EM_4fyuwz_3^z)8P4tJ z+woqSQ;ke4pSQL?XqA8OPG%5&s1MbZ$#`&wkR?n1M5Xn?YQ)ZhZG+*n9S7BQ(`U>& z_BmanzI3!xQ-R4e114r9L_W*xZ^7#`m!-PpP*B=cs>4_(N>QJvgr@D|f+9wx^}GuCVoLT}%Z}z}PH+oUYw$p% zOE3~`K3|o#1clsPotE_%?2nou*1_>b!9MTch{V%^F2m{QH17)Eszpye|7jCD9o4m* z<|BYU7T53NT(_X`dv-IvpN;Cjj8juqT3P{XKvuDD2nWK!+K(8hR^uM!opV&~gK)^R zcXL&#$;x-LbB1xD7pV?%5AyO)@QyBldIq?Lj@ZX>m^tM8x)rL_m|0jAmqBt0^=iz8 zeB*t+1CTLAM9IQ0_sjMdDJlq|XDh5CuJom!-0f;2UeF!i+Xfq>e z2a;RZz^choj=mF0r*z1RBnpW_6{-lA>-#_U8c=wD*fzVe-U{a?djfEsr_UzbMZcK%gh^58tAbG&WkL~tj%7K&L9huv^#U8kQzh{ zM7Y)|+gPmBvMGJD#lzAJT4l{*c(9=Ob{m%Dict%EqU^hpb-(yP(mz8Z1-0*LasE&` zhVM$tF8Li8txhk`7&fDDEGi^5r8ki$p`e!Z+Bs?EwLJHz%aJ@((aU$D>>zb7YO2cL z?eln&$d3u_=x`})k1QJ^&p{&~-LWG?7MUC2yR_rA2hOV>eyEJ6Io&xDRF6?2&7uB}G#De0r<2)7{70^*mVf{PW z-sJU8O&r1+F{-``7(YA-HJw8Z@P@#zPI@y6&5Jum(3U7yml%m2uDx2I`$FXcJ;+1R ztk?-s5n?_&Qrg#@1_QrN5p;Z3R)HWP6QTU)Qp3$0{@16*|L5tig@@aSzf~Fj4FCXY z{AGVb{Ih#}tB7%%^|wO7FV@iCS+~jtw^@JN$Nt6gxk<5KL;EMLe|EKRlm52S`$a0i zx=s3h-FKVvxAoaC3jd9S{EFv>@~_p}?fyTb_|ts;>Yt4FTmL_;=xxy7+~O|~&&|^D zRv^E5$J>;@y?K67*lwoRKa@X|>;K+Ef71Tc_g^%^n+n)}qWw$l|3v+%$-hv={{{7% mLjQ^UQYYzSVJsa z?lVi(6mblBgtd5q&tL^mmpqjxE-$riE!1z5tY^&W#;L>Rl`m(!Yv-j#f&0FJvfAcx zxYWjNOjtifmAc;NRfq@km#?n#b?eFj5~~h^W>dUTT?qi|0#V>tNUs}x1y4cqG$UR* zP2gi@{33os`5?IGM0C&8nZ9@K2oK_4B5C?uQW89|E9!cU{s_;K=6O(0n;42wU??$4 zpj|{j57@~*?wzVB_73ZHtR}iE+9+S3Lq?0=We#~8Qnjcox^}2*FZsdh`@wxqF5=#5 zxlg^G^*Ai!X5L!)ucHS05*Sj5Z`ekIx@LDiXpi4!pU{j(}EPx3oD~VZ2lALP4&St92=efSHn?pB<1< zLWLN5ID+|yitbjtn;sptwY5TezVswtS&Y4E8zj;i&YVpuDe zY8fq07b+tN1?p-}$I+j3JTTdieCYJjB4a-LA`)tezhd)|p%pZ9KG316g#PscMwYiH zVUHZ1ey$$UeZD_26|Z8)PaOXu%R0#Bf%8nQ(v6c@yi#<}^+kydwMWIfrSkW5;N=m; z;d&FhV1IAk32f~2M_Rozj}_n+iDKv_7WuUSYF$JNc$S5EyS(c2=3#6`o1x#Zb8I$O zfZ@COPbR)7y!o8yKniIosFMrQX$sCxX?gaP$jqh zt91m>M2B^0f{hXL6f3f=wK0Qi_@QZOZh3!cp5`0H8CE<*L!OL`c_w_)VN)dxHvJ}s zH;Db~-x4xEO~73JUg;W}@VvkcrqtP?Tf|n zIfw*_PAk!8U_VAWU$Efi`%Stvml_lmMkir~~vEp1SZ3DGr;1Kl7`t+ps{ z5FQhurLl{LXF593RCdnZZ{64zZF*N?z$ME@e|QF<<06<9J^%y zG*&>@q_xF}r=6i;0JAU*{#p)hRM)g;w4mhiP7{H`a-<&rae><%as!@m7*DK~pU@mu zi;NcFbc3A;KQEsW1KICeEn&jf5nWt49CO}Q8{@;n0QD6O*1(K|mE$0Ir21?qMe;)p(j?fXHwHWbj{&|^6|qkT>vIksDPy`8n#{VORn1*(}I zE6XA_VU2A2n%wlh5w&ZCGYMJhhV7!=ydRotp}%|?HMP*p|W-B+y4twH|8 zt$Tfy59L}_{AfrnVm(Q`HrhQ=Hjb7~S3oN(j$Z^3!5yow_vsF5|8Z=g(Md|KbL`%8 zs~97X5}rtncH0ZV#_E+Ky@-;<(ENzJt+n%DwFl*1VGq6B=_6M*N2g&1Ybj076d0*J zNPI&HxN~ajE>jomHoOmP8j=JTM`1)PTZyqALW>Bz*yp5<>}5jyX%=3iC}mBBF-@N? z%+d1*z8~OK$kPwhc~c?MY^*aN9ahc)>ALISGdxj(Jh)GMyP-DiU^yLShS_Qy8$bo=gjEig)k=j`3=weBv> z!e?%v0Q@$ZzQM5qL>-4w+&Y%Ji z-l>P#+m|=#*b8pdbu7l3=GySwopb_G1=>=#5rS3z8~EOWZL|$}8dd}?XO)@h6EfJF zJm=TQ6JA*lMwRp*D`%ndu-%+PDlvD~k+b=ypC`}p*EJRMX4RtXM=4J`4%?HI2dzUt zjJ(`<0TvV00^?jRgv6dr9#OvCPa54hN;{C*IpAWTa9~&1iUcNY$9X%O7pBn2K1X`2 zj=2y<$keQmfY>{nEhM%g+%D<8ZWgZ{{zz+ex@gCGB(m9)sEh6RV81%`3(Uw(uQ~|# z*$t5Xapzbq2L)G)0XWIU^gY|#{`k(Q6(wl3zy_{q1I;m7<&+(={=fOsN$tW?8C zJ08z$-a>XtwfJb|*b}RibWGxmBx1C(^OP~ZC8{_04dk5B{YW=C9?(|s{zAg?e6okrBIM0U= z+Vu$X1SQB37IqxMmNPjX<7|o$Ydp2s-GnWt@F*cpDkg6tG@god7wCNM-5>cTTzOU7 zXIY%1Q)-x{;S3#1Rzkw$`+nz_gbsf7;7&+)rVLiU`)TiX)U`&R_RL81wcYd#f3YSP z+&&(|!fVPyJBy2_hK{ANQwpAG%&Hc1%i-*kz+U%Pxwax1xG*0liH`h9G{L^e+--8I z_S}Lx^h&O&8ljCxAMhzI>Ga~DF+!@3u(k%K$`?$Is>M&VQ_BsZ@LUT(I5BLShACES zCKsg_dZ60YORdQXA@ahv-)`uVzO%fN7oXL4T;*-ovePr!{?j@w*%x4D6Khs z@d>xh=K4iwULthUS<|3J;X8N#BKd6fqvA z7t8zZ%py@xnFT#^SLe0y76`eO*Dh-%@b(+EeV47k;$z)smN=B!*?F0|7;2^9!A3`j z_41(GNZlKC7VG#);F30^HE+X3H03}d^CiohamyXMkFd7mM)Ip!mMp@N=rw|L!e*8nVi7dFJJr02b?kr2yY`x7^q7;5072a@CMB?EAh zUnQlK^w5f#?5FnHobyBN?!MCVnr85bZ!#;vsZ@)Z&lMVQzsPo(Z;W>u2!4Y(zA9Uk zJl(l(ww8+JDtFqYcpzs}vXQqf9y|lumz&Ax=$jhp4ZO338%sWN7p~oFBF}qYmmD03C9$XWlGB- zLDNL)CUK$2x#%FyO2a2abtaaYy44JoWB)Gll!i5N&{jYwWjdOmjycMuTLJR0L@DTF z+QaOW92kAm*?L|uGmV)h%AjT|emErFSWgV`Nd(8d)2VN^t}dR%Lm6UzMJJxv%Pjk;{dHbW4-quAvly2al)b50trMRk@HRhC&JrNWz zTtdv)_U8w0r;Cuo#k_WQ&P?%?;bJCfPQ*=R-e@bTP8KTi?~pKF1rAXro)40&qT&>- zD?blnfsd3tTrQx@vOp$&NyUH6@7r9$i*r+25ARdbpM908lXzLjVmY*XdvGaT5;292DO6bB3eV6(%q8f0bG$&WvCWMLs(Vc8qE#O7)SAL4O3ds*$L!O zx()etqvL|+DVde2#<;Yi3LNA%-WymbmiJyIL}W-6yrT^o!)#xE+&I5>uD2Ne?t-9^XriXPT@{L!Yq?uF_-l? zT}MkcJ)ogr$XiwV*XzZYJ_DIc>e)slyy}#+X)x)TckevgoO2GPX9F!O?IA+3UMD7= zY0AozIwY$rt^N2O0q6tJA_N0xL!%mcVr~A$S7Y}6#Whn8Rsi56o0-iqJBnsp_ZsVy zH&42&-*cGX=}%T%ZLOI3w}KqnYe1O1~dkCb_`%j>eX^-YH!e!fG*)z?989$yIvaSVL^;YokHlywq1%3 z#IQ%hJ-uif2hPmN6<+7@JdI1T3irEFnRk5X&g5|T=0!2r#phzt*srVlVOC4u+_G-+ zp58q%X%DCii2f3vrICjS!%X_ZMnSO~jLGkX{-7hzK?gTOKKQol3c(8tem8S8Sa6AU zo#-1`S$fo~&XWbHrgJ3D>V!~&kuVx!9%alNo$#KPq27S-p1r$o)$m~rerI~BM0;q< z7^rvfzpG$*HLr~Ez1lpxS3%O>RmaiA9Q>;ghQV}y2%*$?fX+B5VkIhycLz|+L=L$M zreZ|Ax?(?^gB2DJe)Pm=bgo7vq4Q&#clu;ChJ?6fOFaD^EyB9C6T^g+k*j7x&A}4G!u@~h{I9qmKYs5uvHjVQ!zW(6 z#XXj}k1$58P;XMvXey_9_%cGVd3KO|D@h$iBD|c}KH|mM_k5%iB9WY~h6Fkc%+ZbZ zEedDF!6@g*jIY9Gj&$qR<4?v~D~+%SBD8bUUA3Vs!X=<$E96O{+~GpYgZ=>i#4ZF) zXK@SPO*h|ejvG-#%`PWKi{l;k2J_r>u{bU%qHHV}ybjftc+_Mg00T~40uI_7x*eo; zq6pR+dZw9y7V%^?TGgC5Iko+4*+Oh(fLYXDM8@)(Z%5|L2iJf+ityv}L*2Rs@Wf?!NV-&JHyzcSmLoc+ERT-Q%+PsHk;A^P=mgo-@)|3XkT$vRXm;rB9eA&s(Nm%wgTk*y1(}( zoFiP@RVURA0sN*7?6zn?{&fPTbCq+HOCs>nhS`JG+|ku@xu-%f`Qs|~OM!C@^JYWY zCRNgTmoZI!!9kE6xFks1fhVlML57_#MTB8*6U-$aPO(U9D-uvpe&VfHg9t6q8f*;F zQ@r^)I&uB>d=dAf?bIDThqW|?Rp0VrDR3YiRoQUR_gLY1IONGovTLW&Pug-N!BR|) z`=$Q4~xh4@> zHmbc*&*FeA`HQ&dcRbhLK{K@^^r6<4(1|6POrG1>!^LM)Oiwd9v%2T-E{n03AIHvc zUs`I~pKS+V_I_O%JvYk^VJ{sGh0hhht0Z^(^c>@+Zb)Z6h6V2WbfZS>TOqVmjsixp>izoxj7`;a6q5@&4rT)kIhD<73;T>3LsWy?;bz;_$SG~({I;~b$ti*ce<98r%%7U@ z<^JH$%Rtty@m`~@?g!PcTE)i6&c*)R+7-fP?O^H1p(Ld#CMl*VrW7lt-1`Vi>dZB& zC6KmJp$6a<6!-#FuQm5!XPufIp;aYMPA;;V$<@;Am3}D%YPOUhlLy^!YlP~-v#K<| zRyM=DOs%j{s-XjZXxO0jJ)APo5k9Bq2qxVIN-hm&x$HeN>-Xv7+?Y(DZ&iWYWWD+B zR7$zGWvWoax zN*kzhicRW_r93mkL!%1QO0bI0Yw*9+sy|jDm|zI?!F@RQr?3FJ{}L9&(aG4(!rj7- z9pVXLQK^z0k%cL!Z2XBMHsNA_3=aUjxflH341+%wR$WR$MP1TuTrN3F6%VxHcSpb@ zOmCV9;zCA7MWt~t8pp&+kn>c$y)0CpW>T-?^SGE6*oylWI#Ju&|C9<={y?JH!4j$< z5w+pJ^rb?F;;7jwQD(kpu0~O98&N_{A+yI_;Ih6{W4l zq~*L?sW}cU@ZiUSXXfD`oAf7(nGS;RV}Q37d-atSpmN}1Wb1Lc&C6q?q!-$3?=+|A zqtOpQFt6bnQX4W~B?;gtN;W7e*<;bvYkb|+525_F#jTdzJmXAiJhj;qk%M!F{^n%N z>`qkxNcZ*5n_pMf-;#zjKZLKa!=(Xzb9|PcOzByURS{19S~#YiD9%&1U=V9V;VwbZ z;lzoUP?3~QBWGF|V)76QBdx z|DC!;q+>hIt&k%rJ57B9KQ99Cx#?6cF^(aV>WTlog#ZVShwz`huKP{=ug~kCAD6#6 zU%w*$W^?@;006uWlKlGR}pThhP63*+tApb4XstO>)AOC-F|8=<+_Fa-6cmD$;5qvBF diff --git a/dnn/torch/rdovae/requirements.txt b/dnn/torch/rdovae/requirements.txt index 8afdcda3..9225ea84 100644 --- a/dnn/torch/rdovae/requirements.txt +++ b/dnn/torch/rdovae/requirements.txt @@ -1,5 +1,4 @@ numpy scipy torch -tqdm -libs/wexchange-1.2-py3-none-any.whl \ No newline at end of file +tqdm \ No newline at end of file diff --git a/dnn/torch/weight-exchange/setup.py b/dnn/torch/weight-exchange/setup.py index bf08db19..e590aad6 100644 --- a/dnn/torch/weight-exchange/setup.py +++ b/dnn/torch/weight-exchange/setup.py @@ -39,7 +39,7 @@ with open(os.path.join(lib_folder, 'requirements.txt'), 'r') as f: print(install_requires) setup(name='wexchange', - version='1.4', + version='1.5', author='Jan Buethe', author_email='jbuethe@amazon.de', description='Weight-exchange library between Pytorch and Tensorflow', diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py index 8601d7df..36050881 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py @@ -35,8 +35,8 @@ class CWriter: filename_without_extension, message=None, header_only=False, - enable_binary_blob=False, create_state_struct=False, + enable_binary_blob=True, model_struct_name="Model", nnet_header="nnet.h"): """ @@ -78,7 +78,7 @@ class CWriter: self.layer_dict = OrderedDict() # for binary blob format, format is key=, value= - self.weight_arrays = set() + self.weight_arrays = [] # form model struct, format is key=, value= self.state_dict = OrderedDict() @@ -134,6 +134,8 @@ f""" if self.enable_binary_blob: # create weight array + if len(set(self.weight_arrays)) != len(self.weight_arrays): + raise ValueError("error: detected duplicates in weight arrays") self.source.write("\n#ifndef USE_WEIGHTS_FILE\n") self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n") for name in self.weight_arrays: diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py index ae2c39a1..d8b3f7e7 100644 --- a/dnn/torch/weight-exchange/wexchange/c_export/common.py +++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py @@ -29,27 +29,49 @@ import numpy as np from .c_writer import CWriter -def print_vector(writer, vector, name, dtype='float', dotp=False, static=True): +def print_vector(writer, vector, name, dtype='float', reshape_8x4=False, static=True, debug_float=False): + + if isinstance(writer, CWriter): + f = writer.source + binary_blob = writer.enable_binary_blob + else: + f = writer + binary_blob = False + + dtype_suffix = { + 'float' : 'float', + 'opus_int8' : 'int8', + 'opus_uint16' : 'uint16', + 'opus_int16' : 'int16', + 'int' : 'int', + 'qweight': 'qweight' + } - f = writer.source - binary_blob = writer.enable_binary_blob if binary_blob: f.write( f''' #ifndef USE_WEIGHTS_FILE -#define WEIGHTS_{name}_DEFINED -#define WEIGHTS_{name}_TYPE WEIGHT_TYPE_{"qweight" if dotp else "float"} ''' ) - writer.weight_arrays.add(name) + writer.weight_arrays.append(name) - if dotp: + if reshape_8x4: vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8)) vector = vector.transpose((2, 0, 3, 1)) v = np.reshape(vector, (-1)) + if debug_float: + f.write('#ifndef DISABLE_DEBUG_FLOAT\n') + if binary_blob: + f.write( +f''' +#define WEIGHTS_{name}_DEFINED +#define WEIGHTS_{name}_TYPE WEIGHT_TYPE_{dtype_suffix[dtype]} +''' + ) + if static: f.write('static ') @@ -70,6 +92,8 @@ f''' f.write(" ") f.write('\n};\n\n') + if debug_float: f.write('#endif /*DISABLE_DEBUG_FLOAT*/\n') + if binary_blob: f.write( f''' @@ -81,19 +105,48 @@ f''' -def print_sparse_vector(writer, A, name, have_diag=True): - f = writer.source +def extract_diagonal(A): + """ input shape is (N, k*N) """ + + N, M = A.shape + B = A.copy() + assert M % N == 0 + k = M // N + + diags = [] + for l in range(k): + diag = np.diag(B[:, l * N : (l+1) * N]).copy() + B[:, l * N : (l+1) * N] -= np.diag(diag) + diags.append(diag) + + diag = np.concatenate(diags) + + return diag, B + +def quantize_weight(weight, scale): + Aq = np.round(weight / scale).astype('int') + if Aq.max() > 127 or Aq.min() <= -128: + raise ValueError("value out of bounds in quantize_weight") + Aq = np.clip(np.round(weight / scale).astype('int'), -128, 127) + return Aq + + +def print_sparse_weight(writer, A, name, scale=1/128, have_diag=True, quantize=False): N = A.shape[0] M = A.shape[1] W = np.zeros((0,), dtype='int') W0 = np.zeros((0,)) + if have_diag: - diag = np.concatenate([np.diag(A[:,:N]), np.diag(A[:,N:2*N]), np.diag(A[:,2*N:])]) - A[:,:N] = A[:,:N] - np.diag(np.diag(A[:,:N])) - A[:,N:2*N] = A[:,N:2*N] - np.diag(np.diag(A[:,N:2*N])) - A[:,2*N:] = A[:,2*N:] - np.diag(np.diag(A[:,2*N:])) + diag, A = extract_diagonal(A) print_vector(writer, diag, name + '_diag') - AQ = np.minimum(127, np.maximum(-128, np.round(A*128))).astype('int') + + if quantize: + Aq = quantize_weight(A, scale) + else: + Aq = A + + # extract blocks idx = np.zeros((0,), dtype='int') for i in range(M//8): pos = idx.shape[0] @@ -101,7 +154,7 @@ def print_sparse_vector(writer, A, name, have_diag=True): nb_nonzero = 0 for j in range(N//4): block = A[j*4:(j+1)*4, i*8:(i+1)*8] - qblock = AQ[j*4:(j+1)*4, i*8:(i+1)*8] + qblock = Aq[j*4:(j+1)*4, i*8:(i+1)*8] if np.sum(np.abs(block)) > 1e-10: nb_nonzero = nb_nonzero + 1 idx = np.append(idx, j*4) @@ -109,102 +162,125 @@ def print_sparse_vector(writer, A, name, have_diag=True): W0 = np.concatenate([W0, block.reshape((-1,))]) W = np.concatenate([W, vblock]) idx[pos] = nb_nonzero - f.write('#ifdef DOT_PROD\n') - print_vector(writer, W, name, dtype='qweight') - f.write('#else /*DOT_PROD*/\n') - print_vector(writer, W0, name, dtype='qweight') - f.write('#endif /*DOT_PROD*/\n') - print_vector(writer, idx, name + '_idx', dtype='int') - return AQ + if quantize: print_vector(writer, W, name + '_int8', reshape_8x4=False, dtype='opus_int8') + print_vector(writer, W0, name + '_float', reshape_8x4=False, dtype='float', debug_float=quantize) + print_vector(writer, idx, name + '_idx', reshape_8x4=False, dtype='int') + + return Aq + + +def qn(string): + if string == "NULL": return string + else: return '"' + string + '"' + +def print_linear_layer(writer : CWriter, + name : str, + weight : np.ndarray, + bias : np.ndarray, + scale : np.ndarray = None, + sparse : bool = False, + diagonal : bool = False, + quantize : bool = True): + + """ prints linear layer + + Parameters: + ----------- + name : str + layer name + weight: np.ndarray + ... + scale: np.ndarray or None + If None auto scaling will be applied. Otherwise, output channels will be multiplied by scale (the usual broadcasting rules apply). + + + """ + + if len(weight.shape) != 2: + raise ValueError('expecting 2-dim weight array in print_linear_layer') + + + bias_name = "NULL" if bias is None else name + "_bias" + subias_name = name + "_subias" if quantize else "NULL" + scale_name = name + "_scale" if quantize else "NULL" + idx_name = name + "_weights_idx" if sparse else "NULL" + float_weight_name = name + "_weights_float" + int_weight_name = name + "_weights_int8" if quantize else "NULL" + diag_name = name + "_weights_diag" if sparse and diagonal else "NULL" + + nb_inputs, nb_outputs = weight.shape + + if scale is None: + raise ValueError("None scale case not implemented yet.") + + + + if sparse: + weight_q = print_sparse_weight(writer, weight, name + "_weights", scale=scale, have_diag=diagonal, quantize=quantize) + else: + if quantize: + weight_q = quantize_weight(weight, scale) + print_vector(writer, weight_q, name + "_weights_int8", dtype='opus_int8', reshape_8x4=True) + + print_vector(writer, weight, name + "_weights_float", dtype='float', reshape_8x4=False, debug_float=quantize) + + if quantize: + subias = (np.zeros(nb_outputs) if bias is None else bias) - np.sum(weight_q * scale, axis=0) + print_vector(writer, subias, name + "_subias") + + final_scale = scale / 127 * np.ones(nb_outputs) + print_vector(writer, final_scale, name + "_scale") + + if bias is not None: + print_vector(writer, bias, name + "_bias") + + + init_call = f'linear_init(&model->{name}, arrays, {qn(bias_name)}, {qn(subias_name)}, {qn(int_weight_name)},' \ + + f'{qn(float_weight_name)}, {qn(idx_name)}, {qn(diag_name)}, {qn(scale_name)}, {nb_inputs}, {nb_outputs})' + + writer.layer_dict[name] = ('LinearLayer', init_call) -def _check_activation(activation): - if not activation in {"TANH", "SIGMOID", "LINEAR", "SWISH", "RELU", "SOFTMAX"}: - raise ValueError(f"error: unknown activation {activation}") def print_dense_layer(writer : CWriter, name : str, weight : np.ndarray, bias : np.ndarray, - activation: str, - format : str = 'torch'): - - _check_activation(activation) + scale=1/128, + format : str = 'torch', + sparse=False, + diagonal=False, + quantize=False): if format == 'torch': weight = weight.transpose() - print_vector(writer, weight, name + "_weights") - print_vector(writer, bias, name + "_bias") + print_linear_layer(writer, name, weight, bias, scale=scale, sparse=sparse, diagonal=diagonal, quantize=quantize) writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[1]}\n") - if writer.enable_binary_blob: - init_call = f'dense_init(&model->{name}, arrays, "{name}_bias", "{name}_weights", {weight.shape[0]}, {weight.shape[1]}, ACTIVATION_{activation})' - writer.layer_dict[name] = ('DenseLayer', init_call) - else: - writer.source.write( -f""" - -const DenseLayer {name} = {{ - {name}_bias, - {name}_weights, - {weight.shape[0]}, - {weight.shape[1]}, - ACTIVATION_{activation} -}}; - -""" - ) - - writer.header.write(f"\nextern const DenseLayer {name};\n\n") - - - - def print_conv1d_layer(writer : CWriter, name : str, weight : np.ndarray, bias : np.ndarray, - activation: str, - format : str = 'torch'): + scale=1/128, + format : str = 'torch', + quantize=False): - _check_activation(activation) if format == "torch": # convert to channels last weight = np.transpose(weight, (2, 1, 0)) - print_vector(writer, weight, name + "_weights") - print_vector(writer, bias, name + "_bias") + lin_weight = np.reshape(weight, (-1, weight.shape[-1])) + print_linear_layer(writer, name, lin_weight, bias, scale=scale, sparse=False, diagonal=False, quantize=quantize) writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {weight.shape[2]}\n") + writer.header.write(f"\n#define {name.upper()}_IN_SIZE {weight.shape[1]}\n") writer.header.write(f"\n#define {name.upper()}_STATE_SIZE ({weight.shape[1]} * ({weight.shape[0] - 1}))\n") writer.header.write(f"\n#define {name.upper()}_DELAY {(weight.shape[0] - 1) // 2}\n") # CAVE: delay is not a property of the conv layer - if writer.enable_binary_blob: - init_call = f'conv1d_init(&model->{name}, arrays, "{name}_bias", "{name}_weights", {weight.shape[1]}, {weight.shape[0]}, {weight.shape[2]}, ACTIVATION_{activation})' - writer.layer_dict[name] = ('Conv1DLayer', init_call) - else: - - writer.source.write( -f""" - -const Conv1DLayer {name} = {{ - {name}_bias, - {name}_weights, - {weight.shape[1]}, - {weight.shape[0]}, - {weight.shape[2]}, - ACTIVATION_{activation} -}}; - -""" - ) - - writer.header.write(f"\nextern const Conv1DLayer {name};\n\n") - return weight.shape[0] * weight.shape[1] @@ -214,17 +290,16 @@ def print_gru_layer(writer : CWriter, recurrent_weight : np.ndarray, bias : np.ndarray, recurrent_bias : np.ndarray, - activation: str, format : str = 'torch', - dotp : bool = False, + quantize : bool = False, input_sparse : bool = False, - reset_after : int = 0 + recurrent_sparse : bool = False, + scale=1/128, + recurrent_scale=1/128 ): - _check_activation(activation) - if format == "torch": - # transpose weight matrices and change gate order from rzn to zrn + # change gate ordering from rzn to zrn N = weight.shape[0] // 3 for x in [weight, recurrent_weight, bias, recurrent_bias]: @@ -234,80 +309,14 @@ def print_gru_layer(writer : CWriter, weight = weight.transpose() recurrent_weight = recurrent_weight.transpose() - - - # input weights - if input_sparse: - qweight = print_sparse_vector(writer, weight, name + '_weights', have_diag=False) else: - qweight = np.clip(np.round(128. * weight).astype('int'), -128, 127) - - if dotp: - writer.source.write("#ifdef DOT_PROD\n") - print_vector(writer, qweight, name + '_weights', dtype='qweight', dotp=True) - writer.source.write("#else /*DOT_PROD*/\n") - - print_vector(writer, weight, name + '_weights') - - if dotp: - writer.source.write("#endif /*DOT_PROD*/\n") - - - # recurrent weights - recurrent_qweight = np.clip(np.round(128. * recurrent_weight).astype('int'), -128, 127) - - if dotp: - writer.source.write("#ifdef DOT_PROD\n") - print_vector(writer, recurrent_qweight, name + '_recurrent_weights', dtype='qweight', dotp=True) - writer.source.write("#else /*DOT_PROD*/\n") - - print_vector(writer, recurrent_weight, name + '_recurrent_weights') - - if dotp: - writer.source.write("#endif /*DOT_PROD*/\n") - - - # corrected bias for unsigned int matrix multiplication - subias = bias - np.sum(qweight / 128., axis=0) - recurrent_subias = recurrent_bias - np.sum(recurrent_qweight / 128., axis=0) - - print_vector(writer, np.concatenate((bias, recurrent_bias)), name + "_bias") - print_vector(writer, np.concatenate((subias, recurrent_subias)), name + "_subias") + N = weight.shape[1] // 3 + print_linear_layer(writer, name + "_input", weight, bias, scale=scale, sparse=input_sparse, quantize=quantize) + print_linear_layer(writer, name + "_recurrent", recurrent_weight, recurrent_bias, scale=recurrent_scale, sparse=recurrent_sparse, diagonal=recurrent_sparse, quantize=quantize) # wrapping it up writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {N}\n") writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n") - if writer.enable_binary_blob: - if input_sparse: - init_call = f'gru_init(&model->{name}, arrays, "{name}_bias", "{name}_subias", "{name}_weights", "{name + "_weights_idx"}", "{name}_recurrent_weights", {weight.shape[0]}, {weight.shape[1] // 3}, ACTIVATION_{activation}, {reset_after})' - else: - init_call = f'gru_init(&model->{name}, arrays, "{name}_bias", "{name}_subias", "{name}_weights", NULL, "{name}_recurrent_weights", {weight.shape[0]}, {weight.shape[1] // 3}, ACTIVATION_{activation}, {reset_after})' - - writer.layer_dict[name] = ('GRULayer', init_call) - - else: - - writer.source.write( -f""" - -const GRULayer {name} = {{ - {name}_bias, - {name}_subias, - {name}_weights, - {name + "_weights_idx" if input_sparse else "NULL"}, - {name}_recurrent_weights, - {weight.shape[0]}, - {weight.shape[1] // 3}, - ACTIVATION_{activation}, - {reset_after} -}}; - -""" - ) - - writer.header.write(f"\nextern const GRULayer {name};\n") - - - return N + return N \ No newline at end of file diff --git a/dnn/torch/weight-exchange/wexchange/tf/tf.py b/dnn/torch/weight-exchange/wexchange/tf/tf.py index c8f9ed2f..bebbb55a 100644 --- a/dnn/torch/weight-exchange/wexchange/tf/tf.py +++ b/dnn/torch/weight-exchange/wexchange/tf/tf.py @@ -34,7 +34,7 @@ import numpy as np from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer -def dump_tf_gru_weights(where, gru, name=None, input_sparse=False, dotp=False): +def dump_tf_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128): assert gru.activation == tf.keras.activations.tanh @@ -47,7 +47,7 @@ def dump_tf_gru_weights(where, gru, name=None, input_sparse=False, dotp=False): b_hh = gru.weights[2].numpy()[1].copy() if isinstance(where, CWriter): - return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, 'TANH', format='tf', reset_after=1, input_sparse=input_sparse, dotp=dotp) + return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='tf', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale) else: os.makedirs(where, exist_ok=True) @@ -87,7 +87,7 @@ def load_tf_gru_weights(path, gru): gru.weights[2].assign(tf.convert_to_tensor(np.vstack((b_ih, b_hh)))) -def dump_tf_dense_weights(where, dense, name=None): +def dump_tf_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False): w = dense.weights[0].numpy() if dense.bias is None: @@ -98,12 +98,7 @@ def dump_tf_dense_weights(where, dense, name=None): if isinstance(where, CWriter): - try: - activation = dense.activation.__name__.upper() - except: - activation = "LINEAR" - - return print_dense_layer(where, name, w, b, activation, format='tf') + return print_dense_layer(where, name, w, b, scale=scale, format='tf', sparse=sparse, diagonal=diagonal, quantize=quantize) else: os.makedirs(where, exist_ok=True) @@ -122,7 +117,7 @@ def load_tf_dense_weights(path, dense): dense.weights[1].assign(tf.convert_to_tensor(b)) -def dump_tf_conv1d_weights(where, conv, name=None): +def dump_tf_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): assert conv.data_format == 'channels_last' @@ -133,12 +128,7 @@ def dump_tf_conv1d_weights(where, conv, name=None): b = conv.bias.numpy() if isinstance(where, CWriter): - try: - activation = conv.activation.__name__.upper() - except: - activation = "LINEAR" - - return print_conv1d_layer(where, name, w, b, activation, format='tf') + return print_conv1d_layer(where, name, w, b, scale=scale, format='tf', quantize=quantize) else: os.makedirs(where, exist_ok=True) diff --git a/dnn/torch/weight-exchange/wexchange/torch/torch.py b/dnn/torch/weight-exchange/wexchange/torch/torch.py index 729c1bc9..4f6d7dfd 100644 --- a/dnn/torch/weight-exchange/wexchange/torch/torch.py +++ b/dnn/torch/weight-exchange/wexchange/torch/torch.py @@ -34,7 +34,7 @@ import numpy as np from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer -def dump_torch_gru_weights(where, gru, name=None, input_sparse=False, dotp=False): +def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128): assert gru.num_layers == 1 assert gru.bidirectional == False @@ -45,7 +45,7 @@ def dump_torch_gru_weights(where, gru, name=None, input_sparse=False, dotp=False b_hh = gru.bias_hh_l0.detach().cpu().numpy() if isinstance(where, CWriter): - return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, 'TANH', format='torch', reset_after=1, input_sparse=input_sparse, dotp=dotp) + return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='torch', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale) else: os.makedirs(where, exist_ok=True) @@ -73,7 +73,7 @@ def load_torch_gru_weights(where, gru): gru.bias_hh_l0.set_(torch.from_numpy(b_hh)) -def dump_torch_dense_weights(where, dense, name=None, activation="LINEAR"): +def dump_torch_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False): w = dense.weight.detach().cpu().numpy() if dense.bias is None: @@ -82,7 +82,7 @@ def dump_torch_dense_weights(where, dense, name=None, activation="LINEAR"): b = dense.bias.detach().cpu().numpy() if isinstance(where, CWriter): - return print_dense_layer(where, name, w, b, activation, format='torch') + return print_dense_layer(where, name, w, b, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize) else: os.makedirs(where, exist_ok=True) @@ -102,7 +102,7 @@ def load_torch_dense_weights(where, dense): dense.bias.set_(torch.from_numpy(b)) -def dump_torch_conv1d_weights(where, conv, name=None, activation="LINEAR"): +def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): w = conv.weight.detach().cpu().numpy() if conv.bias is None: @@ -112,7 +112,7 @@ def dump_torch_conv1d_weights(where, conv, name=None, activation="LINEAR"): if isinstance(where, CWriter): - return print_conv1d_layer(where, name, w, b, activation, format='torch') + return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize) else: os.makedirs(where, exist_ok=True) @@ -146,12 +146,12 @@ def load_torch_embedding_weights(where, emb): with torch.no_grad(): emb.weight.set_(torch.from_numpy(w)) -def dump_torch_weights(where, module, name=None, activation="LINEAR", verbose=False, **kwargs): +def dump_torch_weights(where, module, name=None, verbose=False, **kwargs): """ generic function for dumping weights of some torch.nn.Module """ if verbose and name is not None: print(f"printing layer {name} of type {type(module)}...") if isinstance(module, torch.nn.Linear): - return dump_torch_dense_weights(where, module, name, activation, **kwargs) + return dump_torch_dense_weights(where, module, name, **kwargs) elif isinstance(module, torch.nn.GRU): return dump_torch_gru_weights(where, module, name, **kwargs) elif isinstance(module, torch.nn.Conv1d):