一个Java老鸟的TensorFlow入门——从计算图到GradientTape写了20年Java突然要学TensorFlow第一反应是这东西怎么这么绕TF 1.x的计算图、Session、placeholder跟Java的思维方式完全不一样。后来TF 2.x出了GradientTape终于顺畅了。这篇记录我从零开始学TensorFlow的过程不是教程是一个老程序员的踩坑笔记。一、TF 1.x先建图再跑图第一个程序常量加法importtensorflow.compat.v1astf tf.disable_eager_execution()atf.constant(3.0,namenode1)btf.constant(4.0,namenode2)ctf.add(a,b)withtf.Session()assess:print(sess.run(c))# 7.0Java程序员的困惑为什么不直接3.0 4.0因为TF 1.x的设计思路是先画蓝图再施工。tf.constant(3.0)不是在算3.0是在图里画了一个节点。tf.add(a, b)也不是在算加法是在图里画了一条从a、b到c的边。直到sess.run(c)施工队才开始干活。这种声明式编程在Java里也有类似的东西——SQL。你写SELECT * FROM t WHERE id 1也不是在执行是在描述你想要什么数据库引擎去执行。TF 1.x的计算图也是这个意思。第二个程序变量与累加value1tf.Variable(0.0)const1tf.constant(1.0)sum1tf.Variable(0.0)new_value1tf.add(value1,const1)value1value1.assign(new_value1)sum1tf.assign_add(sum1,value1)sesstf.Session()inittf.global_variables_initializer()sess.run(init)foriinrange(10):resultsess.run([value1,sum1])print(第%d次, 累加:%d, 和:%d%(i1,result[0],result[1]))这里有两点跟Java不一样变量要初始化——tf.global_variables_initializer()不调这个变量是空的。Java里int i 0就完了TF里要显式告诉Session请初始化所有变量。赋值是操作不是语句——value1.assign(new_value1)返回的是一个操作节点不是立刻赋值。得sess.run()才生效。第三个程序占位符placeholderatf.placeholder(tf.float32,namea)btf.placeholder(tf.float32,nameb)ctf.add(a,b)dtf.multiply(a,b)withtf.Session()assess:resultsess.run([c,d],feed_dict{a:[1.0,2.0,3.0],b:[4.0,5.0,6.0]})print(result[0])# [5.0, 7.0, 9.0]print(result[1])# [4.0, 10.0, 18.0]placeholder就是方法的参数。先在图里留个坑运行的时候用feed_dict填数据。Java程序员可以理解为接口定义——你声明了参数类型调用时传具体值。还顺手把计算图写到了TensorBoard日志writertf.summary.FileWriter(e:\\log,tf.get_default_graph())打开TensorBoard可以看到可视化计算图——节点和边的拓扑结构。调试时很有用。二、TF 1.x的痛点学了三个例子之后我感觉到几个不舒服的地方所有东西都得在图里——想打个中间变量的值sess.run()。想看类型图里没有运行时类型。调试困难——图建好了跑不了断点。出错了报错信息跟图节点名相关不是Python代码行号。代码啰嗦——建图、初始化、Session、feed_dict干个加法要写一堆。这不是TF的问题是声明式编程的代价。SQL也有类似问题——复杂SQL调试起来也很难。三、TF 2.x终于像正常代码了GradientTape做多项式回归importtensorflowastfimportnumpyasnpimportmatplotlib.pyplotasplt np.random.seed(0)Xnp.linspace(-1,1,100)Y0.5*X**20.5*X2np.random.normal(0,0.05,(100,))X_train,Y_trainX[:70],Y[:70]X_test,Y_testX[70:],Y[70:]W1tf.Variable(np.random.randn())W2tf.Variable(np.random.randn())btf.Variable(np.random.randn())deflinear_regression(x):returnW1*x**2W2*xb optimizertf.optimizers.SGD(learning_rate0.01)forstepinrange(100):withtf.GradientTape()astape:predlinear_regression(X_train)losstf.reduce_mean(tf.square(pred-Y_train))gradientstape.gradient(loss,[W1,W2,b])optimizer.apply_gradients(zip(gradients,[W1,W2,b]))if(step1)%200:print(Step: %i, loss: %f, W1: %f, W2: %f, b: %f%(step1,loss,W1.numpy(),W2.numpy(),b.numpy()))对比TF 1.x变化巨大不需要Session了——直接执行像正常Python代码不需要建图了——GradientTape自动记录前向计算过程调试方便——W1.numpy()随时可以看值不需要sess.run()代码量少了一半GradientTape的核心思想用with tf.GradientTape() as tape包住前向计算TF自动记录所有操作。然后tape.gradient(loss, [参数])自动求导。不需要手写反向传播不需要理解链式法则的推导过程。Java程序员可以类比TF 1.x像JDBC手动管理连接、Statement、ResultSetTF 2.x像MyBatis框架帮你搞定底层你只写业务逻辑。四、Keras加载现成数据集fromkeras.api.datasetsimportmnist,imdb(train_images,train_labels),(test_images,test_labels)mnist.load_data()print(train_images.shape)# (60000, 28, 28)(train_datas,train_labels),(_,_)imdb.load_data()word_indeximdb.get_word_index()reverse_word_indexdict([(value,key)for(key,value)inword_index.items()])decode_view.join(reverse_word_index.get(i-3,?)foriintrain_datas[3])print(decode_view)Keras内置了常用数据集mnist.load_data()直接下载手写数字imdb.load_data()直接下载电影评论。IMDB的数据已经转成了词索引通过word_index反查可以还原原始文本。这一步没什么技术含量但省了很多数据准备的时间。学习阶段用现成数据集项目阶段用自己的数据——这个节奏是对的。五、总结一个Java老兵的TF学习路径阶段我做了什么关键收获TF 1.x常量建图、Session、run()理解声明式编程TF 1.x变量Variable、assign、初始化变量是图的一部分TF 1.x占位符placeholder、feed_dict参数化计算图TF 2.x GradientTape自动求导、多项式回归终于像正常代码了Keras数据集MNIST、IMDB加载数据准备的起点最大的体会如果你现在开始学TensorFlow直接学TF 2.x。TF 1.x的计算图概念了解一下就行很多老教程和老项目还在用但写代码用2.x。GradientTape Eager Execution学习曲线平很多。环境搭建我踩的坑Python版本用3.9-3.11太新可能TF不支持TensorFlow安装pip install tensorflowGPU版装tensorflow-gpu需要CUDA和cuDNN很折腾学习阶段CPU够用如果只是学基础CPU版就行MNIST和线性回归秒跑完相关阅读*《一个46岁架构师的AI实战经验总结》**《老鸟的JVM理解——不是背出来的是搬对象搬出来的》*